diff --git a/server/db/models/user.go b/server/db/models/user.go index 12709d4..46b445b 100644 --- a/server/db/models/user.go +++ b/server/db/models/user.go @@ -38,8 +38,13 @@ func (user *User) AsAPIUser() *model.User { email := user.Email createdAt := user.CreatedAt updatedAt := user.UpdatedAt + + id := user.ID + if strings.Contains(id, Collections.WebhookLog+"/") { + id = strings.TrimPrefix(id, Collections.WebhookLog+"/") + } return &model.User{ - ID: user.ID, + ID: id, Email: user.Email, EmailVerified: isEmailVerified, SignupMethods: user.SignupMethods, diff --git a/server/db/models/verification_requests.go b/server/db/models/verification_requests.go index afd9ad7..9b8edb1 100644 --- a/server/db/models/verification_requests.go +++ b/server/db/models/verification_requests.go @@ -1,6 +1,10 @@ package models -import "github.com/authorizerdev/authorizer/server/graph/model" +import ( + "strings" + + "github.com/authorizerdev/authorizer/server/graph/model" +) // Note: any change here should be reflected in providers/casandra/provider.go as it does not have model support in collection creation @@ -27,8 +31,13 @@ func (v *VerificationRequest) AsAPIVerificationRequest() *model.VerificationRequ redirectURI := v.RedirectURI expires := v.ExpiresAt identifier := v.Identifier + + id := v.ID + if strings.Contains(id, Collections.WebhookLog+"/") { + id = strings.TrimPrefix(id, Collections.WebhookLog+"/") + } return &model.VerificationRequest{ - ID: v.ID, + ID: id, Token: &token, Identifier: &identifier, Expires: &expires, diff --git a/server/db/models/webhook.go b/server/db/models/webhook.go index 643d54e..8cd5108 100644 --- a/server/db/models/webhook.go +++ b/server/db/models/webhook.go @@ -2,6 +2,7 @@ package models import ( "encoding/json" + "strings" "github.com/authorizerdev/authorizer/server/graph/model" ) @@ -23,8 +24,13 @@ type Webhook struct { func (w *Webhook) AsAPIWebhook() *model.Webhook { headersMap := make(map[string]interface{}) json.Unmarshal([]byte(w.Headers), &headersMap) + + id := w.ID + if strings.Contains(id, Collections.Webhook+"/") { + id = strings.TrimPrefix(id, Collections.Webhook+"/") + } return &model.Webhook{ - ID: w.ID, + ID: id, EventName: &w.EventName, Endpoint: &w.EndPoint, Headers: headersMap, diff --git a/server/db/models/webhook_log.go b/server/db/models/webhook_log.go index 0239bcb..f8765f3 100644 --- a/server/db/models/webhook_log.go +++ b/server/db/models/webhook_log.go @@ -1,6 +1,10 @@ package models -import "github.com/authorizerdev/authorizer/server/graph/model" +import ( + "strings" + + "github.com/authorizerdev/authorizer/server/graph/model" +) // Note: any change here should be reflected in providers/casandra/provider.go as it does not have model support in collection creation @@ -17,8 +21,12 @@ type WebhookLog struct { } func (w *WebhookLog) AsAPIWebhookLog() *model.WebhookLog { + id := w.ID + if strings.Contains(id, Collections.WebhookLog+"/") { + id = strings.TrimPrefix(id, Collections.WebhookLog+"/") + } return &model.WebhookLog{ - ID: w.ID, + ID: id, HTTPStatus: &w.HttpStatus, Response: &w.Response, Request: &w.Request, diff --git a/server/db/providers/arangodb/user.go b/server/db/providers/arangodb/user.go index a0eb32a..abc3ec0 100644 --- a/server/db/providers/arangodb/user.go +++ b/server/db/providers/arangodb/user.go @@ -63,9 +63,9 @@ func (p *provider) DeleteUser(ctx context.Context, user models.User) error { return err } - query := fmt.Sprintf(`FOR d IN %s FILTER d.user_id == @userId REMOVE { _key: d._key } IN %s`, models.Collections.Session, models.Collections.Session) + query := fmt.Sprintf(`FOR d IN %s FILTER d.user_id == @user_id REMOVE { _key: d._key } IN %s`, models.Collections.Session, models.Collections.Session) bindVars := map[string]interface{}{ - "userId": user.ID, + "user_id": user.ID, } cursor, err := p.db.Query(ctx, query, bindVars) if err != nil { diff --git a/server/db/providers/arangodb/webhook.go b/server/db/providers/arangodb/webhook.go index 1a33d76..302eb61 100644 --- a/server/db/providers/arangodb/webhook.go +++ b/server/db/providers/arangodb/webhook.go @@ -83,7 +83,7 @@ func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) // GetWebhookByID to get webhook by id func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model.Webhook, error) { var webhook models.Webhook - query := fmt.Sprintf("FOR d in %s FILTER d._id == @webhook_id RETURN d", models.Collections.Webhook) + query := fmt.Sprintf("FOR d in %s FILTER d._key == @webhook_id RETURN d", models.Collections.Webhook) bindVars := map[string]interface{}{ "webhook_id": webhookID, } @@ -146,9 +146,9 @@ func (p *provider) DeleteWebhook(ctx context.Context, webhook *model.Webhook) er return err } - query := fmt.Sprintf("FOR d in %s FILTER d.event_id == @event_id REMOVE { _key: d._key }", models.Collections.WebhookLog) + query := fmt.Sprintf("FOR d IN %s FILTER d.webhook_id == @webhook_id REMOVE { _key: d._key } IN %s", models.Collections.WebhookLog, models.Collections.WebhookLog) bindVars := map[string]interface{}{ - "event_id": webhook.ID, + "webhook_id": webhook.ID, } cursor, err := p.db.Query(ctx, query, bindVars) diff --git a/server/db/providers/arangodb/webhook_log.go b/server/db/providers/arangodb/webhook_log.go index d0337e1..bc758c4 100644 --- a/server/db/providers/arangodb/webhook_log.go +++ b/server/db/providers/arangodb/webhook_log.go @@ -37,11 +37,12 @@ func (p *provider) ListWebhookLogs(ctx context.Context, pagination model.Paginat query := fmt.Sprintf("FOR d in %s SORT d.created_at DESC LIMIT %d, %d RETURN d", models.Collections.WebhookLog, pagination.Offset, pagination.Limit) if webhookID != "" { - query = fmt.Sprintf("FOR d in %s FILTER d.webhook_id == @webhookID SORT d.created_at DESC LIMIT %d, %d RETURN d", models.Collections.WebhookLog, pagination.Offset, pagination.Limit) + query = fmt.Sprintf("FOR d in %s FILTER d.webhook_id == @webhook_id SORT d.created_at DESC LIMIT %d, %d RETURN d", models.Collections.WebhookLog, pagination.Offset, pagination.Limit) bindVariables = map[string]interface{}{ "webhook_id": webhookID, } } + sctx := driver.WithQueryFullCount(ctx) cursor, err := p.db.Query(sctx, query, bindVariables) if err != nil { diff --git a/server/db/providers/cassandradb/provider.go b/server/db/providers/cassandradb/provider.go index 3d90d2a..b47c324 100644 --- a/server/db/providers/cassandradb/provider.go +++ b/server/db/providers/cassandradb/provider.go @@ -143,6 +143,11 @@ func NewProvider() (*provider, error) { if err != nil { return nil, err } + sessionIndexQuery := fmt.Sprintf("CREATE INDEX IF NOT EXISTS authorizer_session_user_id ON %s.%s (user_id)", KeySpace, models.Collections.Session) + err = session.Query(sessionIndexQuery).Exec() + if err != nil { + return nil, err + } userCollectionQuery := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (id text, email text, email_verified_at bigint, password text, signup_methods text, given_name text, family_name text, middle_name text, nickname text, gender text, birthdate text, phone_number text, phone_number_verified_at bigint, picture text, roles text, updated_at bigint, created_at bigint, revoked_timestamp bigint, PRIMARY KEY (id))", KeySpace, models.Collections.User) err = session.Query(userCollectionQuery).Exec() @@ -177,7 +182,7 @@ func NewProvider() (*provider, error) { return nil, err } - webhookCollectionQuery := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (id text, event_name text, endpoint text, enabled boolean, updated_at bigint, created_at bigint, PRIMARY KEY (id))", KeySpace, models.Collections.Webhook) + webhookCollectionQuery := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (id text, event_name text, endpoint text, enabled boolean, headers text, updated_at bigint, created_at bigint, PRIMARY KEY (id))", KeySpace, models.Collections.Webhook) err = session.Query(webhookCollectionQuery).Exec() if err != nil { return nil, err diff --git a/server/db/providers/cassandradb/user.go b/server/db/providers/cassandradb/user.go index 730a04b..9489cdd 100644 --- a/server/db/providers/cassandradb/user.go +++ b/server/db/providers/cassandradb/user.go @@ -102,6 +102,10 @@ func (p *provider) UpdateUser(ctx context.Context, user models.User) (models.Use continue } + if key == "_key" { + continue + } + if value == nil { updateFields += fmt.Sprintf("%s = null,", key) continue @@ -135,7 +139,19 @@ func (p *provider) DeleteUser(ctx context.Context, user models.User) error { return err } - deleteSessionQuery := fmt.Sprintf("DELETE FROM %s WHERE user_id = '%s'", KeySpace+"."+models.Collections.Session, user.ID) + getSessionsQuery := fmt.Sprintf("SELECT id FROM %s WHERE user_id = '%s' ALLOW FILTERING", KeySpace+"."+models.Collections.Session, user.ID) + scanner := p.db.Query(getSessionsQuery).Iter().Scanner() + sessionIDs := "" + for scanner.Next() { + var wlID string + err = scanner.Scan(&wlID) + if err != nil { + return err + } + sessionIDs += fmt.Sprintf("'%s',", wlID) + } + sessionIDs = strings.TrimSuffix(sessionIDs, ",") + deleteSessionQuery := fmt.Sprintf("DELETE FROM %s WHERE id IN (%s)", KeySpace+"."+models.Collections.Session, sessionIDs) err = p.db.Query(deleteSessionQuery).Exec() if err != nil { return err @@ -181,7 +197,7 @@ func (p *provider) ListUsers(ctx context.Context, pagination model.Pagination) ( // GetUserByEmail to get user information from database using email address func (p *provider) GetUserByEmail(ctx context.Context, email string) (models.User, error) { var user models.User - query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, created_at, updated_at FROM %s WHERE email = '%s' LIMIT 1", KeySpace+"."+models.Collections.User, email) + query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, created_at, updated_at FROM %s WHERE email = '%s' LIMIT 1 ALLOW FILTERING", KeySpace+"."+models.Collections.User, email) err := p.db.Query(query).Consistency(gocql.One).Scan(&user.ID, &user.Email, &user.EmailVerifiedAt, &user.Password, &user.SignupMethods, &user.GivenName, &user.FamilyName, &user.MiddleName, &user.Nickname, &user.Birthdate, &user.PhoneNumber, &user.PhoneNumberVerifiedAt, &user.Picture, &user.Roles, &user.RevokedTimestamp, &user.CreatedAt, &user.UpdatedAt) if err != nil { return user, err diff --git a/server/db/providers/cassandradb/webhook.go b/server/db/providers/cassandradb/webhook.go index 7f7a123..1954052 100644 --- a/server/db/providers/cassandradb/webhook.go +++ b/server/db/providers/cassandradb/webhook.go @@ -24,6 +24,11 @@ func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*mod webhook.CreatedAt = time.Now().Unix() webhook.UpdatedAt = time.Now().Unix() + existingHook, _ := p.GetWebhookByEventName(ctx, webhook.EventName) + if existingHook != nil { + return nil, fmt.Errorf("Webhook with %s event_name already exists", webhook.EventName) + } + insertQuery := fmt.Sprintf("INSERT INTO %s (id, event_name, endpoint, headers, enabled, created_at, updated_at) VALUES ('%s', '%s', '%s', '%s', %t, %d, %d)", KeySpace+"."+models.Collections.Webhook, webhook.ID, webhook.EventName, webhook.EndPoint, webhook.Headers, webhook.Enabled, webhook.CreatedAt, webhook.UpdatedAt) err := p.db.Query(insertQuery).Exec() if err != nil { @@ -56,6 +61,10 @@ func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (* continue } + if key == "_key" { + continue + } + if value == nil { updateFields += fmt.Sprintf("%s = null,", key) continue @@ -72,7 +81,6 @@ func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (* updateFields = strings.TrimSuffix(updateFields, ",") query := fmt.Sprintf("UPDATE %s SET %s WHERE id = '%s'", KeySpace+"."+models.Collections.Webhook, updateFields, webhook.ID) - err = p.db.Query(query).Exec() if err != nil { return nil, err @@ -130,7 +138,7 @@ func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model // GetWebhookByEventName to get webhook by event_name func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) (*model.Webhook, error) { var webhook models.Webhook - query := fmt.Sprintf(`SELECT id, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s WHERE event_name = '%s' LIMIT 1`, KeySpace+"."+models.Collections.Webhook, eventName) + query := fmt.Sprintf(`SELECT id, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s WHERE event_name = '%s' LIMIT 1 ALLOW FILTERING`, KeySpace+"."+models.Collections.Webhook, eventName) err := p.db.Query(query).Consistency(gocql.One).Scan(&webhook.ID, &webhook.EventName, &webhook.EndPoint, &webhook.Headers, &webhook.Enabled, &webhook.CreatedAt, &webhook.UpdatedAt) if err != nil { return nil, err @@ -146,7 +154,19 @@ func (p *provider) DeleteWebhook(ctx context.Context, webhook *model.Webhook) er return err } - query = fmt.Sprintf("DELETE FROM %s WHERE webhook_id = '%s'", KeySpace+"."+models.Collections.WebhookLog, webhook.ID) + getWebhookLogQuery := fmt.Sprintf("SELECT id FROM %s WHERE webhook_id = '%s' ALLOW FILTERING", KeySpace+"."+models.Collections.WebhookLog, webhook.ID) + scanner := p.db.Query(getWebhookLogQuery).Iter().Scanner() + webhookLogIDs := "" + for scanner.Next() { + var wlID string + err = scanner.Scan(&wlID) + if err != nil { + return err + } + webhookLogIDs += fmt.Sprintf("'%s',", wlID) + } + webhookLogIDs = strings.TrimSuffix(webhookLogIDs, ",") + query = fmt.Sprintf("DELETE FROM %s WHERE id IN (%s)", KeySpace+"."+models.Collections.WebhookLog, webhookLogIDs) err = p.db.Query(query).Exec() return err } diff --git a/server/db/providers/cassandradb/webhook_log.go b/server/db/providers/cassandradb/webhook_log.go index ab979f7..9ecf939 100644 --- a/server/db/providers/cassandradb/webhook_log.go +++ b/server/db/providers/cassandradb/webhook_log.go @@ -40,8 +40,8 @@ func (p *provider) ListWebhookLogs(ctx context.Context, pagination model.Paginat query := fmt.Sprintf("SELECT id, http_status, response, request, webhook_id, created_at, updated_at FROM %s LIMIT %d", KeySpace+"."+models.Collections.WebhookLog, pagination.Limit+pagination.Offset) if webhookID != "" { - totalCountQuery = fmt.Sprintf(`SELECT COUNT(*) FROM %s WHERE webhook_id='%s'`, KeySpace+"."+models.Collections.WebhookLog, webhookID) - query = fmt.Sprintf("SELECT id, http_status, response, request, webhook_id, created_at, updated_at FROM %s WHERE webhook_id = '%s' LIMIT %d", KeySpace+"."+models.Collections.WebhookLog, webhookID, pagination.Limit+pagination.Offset) + totalCountQuery = fmt.Sprintf(`SELECT COUNT(*) FROM %s WHERE webhook_id='%s' ALLOW FILTERING`, KeySpace+"."+models.Collections.WebhookLog, webhookID) + query = fmt.Sprintf("SELECT id, http_status, response, request, webhook_id, created_at, updated_at FROM %s WHERE webhook_id = '%s' LIMIT %d ALLOW FILTERING", KeySpace+"."+models.Collections.WebhookLog, webhookID, pagination.Limit+pagination.Offset) } err := p.db.Query(totalCountQuery).Consistency(gocql.One).Scan(&paginationClone.Total) diff --git a/server/db/providers/mongodb/webhook.go b/server/db/providers/mongodb/webhook.go index b0c2c32..7b29398 100644 --- a/server/db/providers/mongodb/webhook.go +++ b/server/db/providers/mongodb/webhook.go @@ -111,7 +111,7 @@ func (p *provider) DeleteWebhook(ctx context.Context, webhook *model.Webhook) er } webhookLogCollection := p.db.Collection(models.Collections.WebhookLog, options.Collection()) - _, err = webhookLogCollection.DeleteOne(nil, bson.M{"webhook_id": webhook.ID}, options.Delete()) + _, err = webhookLogCollection.DeleteMany(nil, bson.M{"webhook_id": webhook.ID}, options.Delete()) if err != nil { return err } diff --git a/server/env/persist_env.go b/server/env/persist_env.go index 381cfad..e9b849d 100644 --- a/server/env/persist_env.go +++ b/server/env/persist_env.go @@ -113,7 +113,7 @@ func PersistEnv() error { ctx := context.Background() env, err := db.Provider.GetEnv(ctx) // config not found in db - if err != nil { + if err != nil || env.EnvData == "" { // AES encryption needs 32 bit key only, so we chop off last 4 characters from 36 bit uuid hash := uuid.New().String()[:36-4] err := memorystore.Provider.UpdateEnvVariable(constants.EnvKeyEncryptionKey, hash) @@ -174,7 +174,7 @@ func PersistEnv() error { err = json.Unmarshal(decryptedConfigs, &storeData) if err != nil { - log.Debug("Error while unmarshalling env data: ", err) + log.Debug("Error while un-marshalling env data: ", err) return err } diff --git a/server/memorystore/providers/inmemory/stores/env_store.go b/server/memorystore/providers/inmemory/stores/env_store.go index d1a3feb..d6ffe6a 100644 --- a/server/memorystore/providers/inmemory/stores/env_store.go +++ b/server/memorystore/providers/inmemory/stores/env_store.go @@ -1,10 +1,7 @@ package stores import ( - "os" "sync" - - "github.com/authorizerdev/authorizer/server/constants" ) // EnvStore struct to store the env variables @@ -23,12 +20,10 @@ func NewEnvStore() *EnvStore { // UpdateEnvStore to update the whole env store object func (e *EnvStore) UpdateStore(store map[string]interface{}) { - if os.Getenv("ENV") != constants.TestEnv { - e.mutex.Lock() - defer e.mutex.Unlock() - } - // just override the keys + new keys + e.mutex.Lock() + defer e.mutex.Unlock() + // just override the keys + new keys for key, value := range store { e.store[key] = value } @@ -46,9 +41,8 @@ func (e *EnvStore) Get(key string) interface{} { // Set sets the value of the key in env store func (e *EnvStore) Set(key string, value interface{}) { - if os.Getenv("ENV") != constants.TestEnv { - e.mutex.Lock() - defer e.mutex.Unlock() - } + e.mutex.Lock() + defer e.mutex.Unlock() + e.store[key] = value } diff --git a/server/memorystore/providers/inmemory/stores/session_store.go b/server/memorystore/providers/inmemory/stores/session_store.go index d702fa0..ad617af 100644 --- a/server/memorystore/providers/inmemory/stores/session_store.go +++ b/server/memorystore/providers/inmemory/stores/session_store.go @@ -1,11 +1,8 @@ package stores import ( - "os" "strings" "sync" - - "github.com/authorizerdev/authorizer/server/constants" ) // SessionStore struct to store the env variables @@ -29,10 +26,9 @@ func (s *SessionStore) Get(key, subKey string) string { // Set sets the value of the key in state store func (s *SessionStore) Set(key string, subKey, value string) { - if os.Getenv("ENV") != constants.TestEnv { - s.mutex.Lock() - defer s.mutex.Unlock() - } + s.mutex.Lock() + defer s.mutex.Unlock() + if _, ok := s.store[key]; !ok { s.store[key] = make(map[string]string) } @@ -41,19 +37,15 @@ func (s *SessionStore) Set(key string, subKey, value string) { // RemoveAll all values for given key func (s *SessionStore) RemoveAll(key string) { - if os.Getenv("ENV") != constants.TestEnv { - s.mutex.Lock() - defer s.mutex.Unlock() - } + s.mutex.Lock() + defer s.mutex.Unlock() delete(s.store, key) } // Remove value for given key and subkey func (s *SessionStore) Remove(key, subKey string) { - if os.Getenv("ENV") != constants.TestEnv { - s.mutex.Lock() - defer s.mutex.Unlock() - } + s.mutex.Lock() + defer s.mutex.Unlock() if _, ok := s.store[key]; ok { delete(s.store[key], subKey) } @@ -69,11 +61,8 @@ func (s *SessionStore) GetAll(key string) map[string]string { // RemoveByNamespace to delete session for a given namespace example google,github func (s *SessionStore) RemoveByNamespace(namespace string) error { - if os.Getenv("ENV") != constants.TestEnv { - s.mutex.Lock() - defer s.mutex.Unlock() - } - + s.mutex.Lock() + defer s.mutex.Unlock() for key := range s.store { if strings.Contains(key, namespace+":") { delete(s.store, key) diff --git a/server/memorystore/providers/inmemory/stores/state_store.go b/server/memorystore/providers/inmemory/stores/state_store.go index 5e66b6e..2ba8417 100644 --- a/server/memorystore/providers/inmemory/stores/state_store.go +++ b/server/memorystore/providers/inmemory/stores/state_store.go @@ -1,10 +1,7 @@ package stores import ( - "os" "sync" - - "github.com/authorizerdev/authorizer/server/constants" ) // StateStore struct to store the env variables @@ -28,19 +25,16 @@ func (s *StateStore) Get(key string) string { // Set sets the value of the key in state store func (s *StateStore) Set(key string, value string) { - if os.Getenv("ENV") != constants.TestEnv { - s.mutex.Lock() - defer s.mutex.Unlock() - } + s.mutex.Lock() + defer s.mutex.Unlock() + s.store[key] = value } // Remove removes the key from state store func (s *StateStore) Remove(key string) { - if os.Getenv("ENV") != constants.TestEnv { - s.mutex.Lock() - defer s.mutex.Unlock() - } + s.mutex.Lock() + defer s.mutex.Unlock() delete(s.store, key) } diff --git a/server/test/delete_webhook_test.go b/server/test/delete_webhook_test.go index 471a5f5..55df1ae 100644 --- a/server/test/delete_webhook_test.go +++ b/server/test/delete_webhook_test.go @@ -24,7 +24,11 @@ func deleteWebhookTest(t *testing.T, s TestSetup) { req.Header.Set("Cookie", fmt.Sprintf("%s=%s", constants.AdminCookieName, h)) // get all webhooks - webhooks, err := db.Provider.ListWebhook(ctx, model.Pagination{}) + webhooks, err := db.Provider.ListWebhook(ctx, model.Pagination{ + Limit: 10, + Page: 1, + Offset: 0, + }) assert.NoError(t, err) for _, w := range webhooks.Webhooks { @@ -37,12 +41,17 @@ func deleteWebhookTest(t *testing.T, s TestSetup) { assert.NotEmpty(t, res.Message) } - webhooks, err = db.Provider.ListWebhook(ctx, model.Pagination{}) + webhooks, err = db.Provider.ListWebhook(ctx, model.Pagination{ + Limit: 10, + Page: 1, + Offset: 0, + }) assert.NoError(t, err) assert.Len(t, webhooks.Webhooks, 0) - webhookLogs, err := db.Provider.ListWebhookLogs(ctx, model.Pagination{ - Limit: 10, + Limit: 100, + Page: 1, + Offset: 0, }, "") assert.NoError(t, err) assert.Len(t, webhookLogs.WebhookLogs, 0) diff --git a/server/test/resolvers_test.go b/server/test/resolvers_test.go index 29ff0c4..399ae84 100644 --- a/server/test/resolvers_test.go +++ b/server/test/resolvers_test.go @@ -2,8 +2,8 @@ package test import ( "context" + "os" "testing" - "time" "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db" @@ -13,31 +13,45 @@ import ( func TestResolvers(t *testing.T) { databases := map[string]string{ - constants.DbTypeSqlite: "../../data.db", - // constants.DbTypeArangodb: "http://localhost:8529", - // constants.DbTypeMongodb: "mongodb://localhost:27017", - // constants.DbTypeCassandraDB: "127.0.0.1:9042", + // constants.DbTypeSqlite: "../../data.db", + // constants.DbTypeArangodb: "http://localhost:8529", + // constants.DbTypeMongodb: "mongodb://localhost:27017", + constants.DbTypeScyllaDB: "127.0.0.1:9042", } + testDb := "authorizer_test" + s := testSetup() + defer s.Server.Close() + for dbType, dbURL := range databases { - s := testSetup() - defer s.Server.Close() ctx := context.Background() memorystore.Provider.UpdateEnvVariable(constants.EnvKeyDatabaseURL, dbURL) memorystore.Provider.UpdateEnvVariable(constants.EnvKeyDatabaseType, dbType) + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyDatabaseName, testDb) + os.Setenv(constants.EnvKeyDatabaseURL, dbURL) + os.Setenv(constants.EnvKeyDatabaseType, dbType) + os.Setenv(constants.EnvKeyDatabaseName, testDb) + memorystore.InitRequiredEnv() + err := db.InitDB() if err != nil { - t.Errorf("Error initializing database: %s", err) + t.Errorf("Error initializing database: %s", err.Error()) } // clean the persisted config for test to use fresh config envData, err := db.Provider.GetEnv(ctx) if err == nil { envData.EnvData = "" - db.Provider.UpdateEnv(ctx, envData) + _, err = db.Provider.UpdateEnv(ctx, envData) + if err != nil { + t.Errorf("Error updating env: %s", err.Error()) + } + } + err = env.PersistEnv() + if err != nil { + t.Errorf("Error persisting env: %s", err.Error()) } - env.PersistEnv() memorystore.Provider.UpdateEnvVariable(constants.EnvKeyEnv, "test") memorystore.Provider.UpdateEnvVariable(constants.EnvKeyIsProd, false) @@ -78,9 +92,8 @@ func TestResolvers(t *testing.T) { inviteUserTest(t, s) validateJwtTokenTest(t, s) - time.Sleep(5 * time.Second) // add sleep for webhooklogs to get generated as they are async - webhookLogsTest(t, s) // get logs after above resolver tests are done - deleteWebhookTest(t, s) // delete webhooks (admin resolver) + webhookLogsTest(t, s) // get logs after above resolver tests are done + deleteWebhookTest(t, s) // delete webhooks (admin resolver) }) } } diff --git a/server/test/webhook_logs_test.go b/server/test/webhook_logs_test.go index 4be6889..97b69de 100644 --- a/server/test/webhook_logs_test.go +++ b/server/test/webhook_logs_test.go @@ -3,6 +3,7 @@ package test import ( "fmt" "testing" + "time" "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/crypto" @@ -14,6 +15,7 @@ import ( ) func webhookLogsTest(t *testing.T, s TestSetup) { + time.Sleep(30 * time.Second) // add sleep for webhooklogs to get generated as they are async t.Helper() t.Run("should get webhook logs", func(t *testing.T) { req, ctx := createContext(s) @@ -23,23 +25,25 @@ func webhookLogsTest(t *testing.T, s TestSetup) { assert.NoError(t, err) req.Header.Set("Cookie", fmt.Sprintf("%s=%s", constants.AdminCookieName, h)) - webhooks, err := resolvers.WebhooksResolver(ctx, nil) - assert.NoError(t, err) - assert.NotEmpty(t, webhooks) - webhookLogs, err := resolvers.WebhookLogsResolver(ctx, nil) assert.NoError(t, err) assert.Greater(t, len(webhookLogs.WebhookLogs), 1) + webhooks, err := resolvers.WebhooksResolver(ctx, nil) + assert.NoError(t, err) + assert.NotEmpty(t, webhooks) + for _, w := range webhooks.Webhooks { - webhookLogs, err := resolvers.WebhookLogsResolver(ctx, &model.ListWebhookLogRequest{ - WebhookID: &w.ID, + t.Run(fmt.Sprintf("should get webhook for webhook_id:%s", w.ID), func(t *testing.T) { + webhookLogs, err := resolvers.WebhookLogsResolver(ctx, &model.ListWebhookLogRequest{ + WebhookID: &w.ID, + }) + assert.NoError(t, err) + assert.GreaterOrEqual(t, len(webhookLogs.WebhookLogs), 1) + for _, wl := range webhookLogs.WebhookLogs { + assert.Equal(t, utils.StringValue(wl.WebhookID), w.ID) + } }) - assert.NoError(t, err) - assert.GreaterOrEqual(t, len(webhookLogs.WebhookLogs), 1) - for _, wl := range webhookLogs.WebhookLogs { - assert.Equal(t, utils.StringValue(wl.WebhookID), w.ID) - } } }) }