diff --git a/server/db/providers/arangodb/otp.go b/server/db/providers/arangodb/otp.go index 0b546c5..0e24095 100644 --- a/server/db/providers/arangodb/otp.go +++ b/server/db/providers/arangodb/otp.go @@ -9,40 +9,40 @@ import ( "github.com/google/uuid" ) -// AddOTP to add otp -func (p *provider) AddOTP(ctx context.Context, otp *models.OTP) (*models.OTP, error) { - if otp.ID == "" { +// UpsertOTP to add or update otp +func (p *provider) UpsertOTP(ctx context.Context, otpParam *models.OTP) (*models.OTP, error) { + otp, _ := p.GetOTPByEmail(ctx, otpParam.Email) + shouldCreate := false + if otp == nil { + shouldCreate = true otp.ID = uuid.New().String() + otp.Key = otp.ID + otp.CreatedAt = time.Now().Unix() + } else { + otp = otpParam } - otp.Key = otp.ID - otp.CreatedAt = time.Now().Unix() otp.UpdatedAt = time.Now().Unix() - otpCollection, _ := p.db.Collection(ctx, models.Collections.OTP) - _, err := otpCollection.CreateDocument(ctx, otp) - if err != nil { - return nil, err + + if shouldCreate { + _, err := otpCollection.CreateDocument(ctx, otp) + if err != nil { + return nil, err + } + } else { + meta, err := otpCollection.UpdateDocument(ctx, otp.Key, otp) + if err != nil { + return nil, err + } + + otp.Key = meta.Key + otp.ID = meta.ID.String() } return otp, nil } -// UpdateOTP to update otp for a given email address -func (p *provider) UpdateOTP(ctx context.Context, otp *models.OTP) (*models.OTP, error) { - otp.UpdatedAt = time.Now().Unix() - - otpCollection, _ := p.db.Collection(ctx, models.Collections.OTP) - meta, err := otpCollection.UpdateDocument(ctx, otp.Key, otp) - if err != nil { - return nil, err - } - - otp.Key = meta.Key - otp.ID = meta.ID.String() - return otp, nil -} - // GetOTPByEmail to get otp for a given email address func (p *provider) GetOTPByEmail(ctx context.Context, emailAddress string) (*models.OTP, error) { var otp *models.OTP diff --git a/server/db/providers/cassandradb/otp.go b/server/db/providers/cassandradb/otp.go index a057f7d..7ead206 100644 --- a/server/db/providers/cassandradb/otp.go +++ b/server/db/providers/cassandradb/otp.go @@ -2,37 +2,60 @@ package cassandradb import ( "context" + "fmt" "time" "github.com/authorizerdev/authorizer/server/db/models" + "github.com/gocql/gocql" "github.com/google/uuid" ) -// AddOTP to add otp -func (p *provider) AddOTP(ctx context.Context, otp *models.OTP) (*models.OTP, error) { - if otp.ID == "" { +// UpsertOTP to add or update otp +func (p *provider) UpsertOTP(ctx context.Context, otpParam *models.OTP) (*models.OTP, error) { + otp, _ := p.GetOTPByEmail(ctx, otpParam.Email) + shouldCreate := false + if otp == nil { + shouldCreate = true otp.ID = uuid.New().String() + otp.Key = otp.ID + otp.CreatedAt = time.Now().Unix() + } else { + otp = otpParam } - otp.Key = otp.ID - otp.CreatedAt = time.Now().Unix() otp.UpdatedAt = time.Now().Unix() + query := "" - return otp, nil -} + if shouldCreate { + query = fmt.Sprintf(`INSERT INTO %s (id, email, otp, expires_at, created_at, updated_at) VALUES ('%s', '%s', '%s', %d, %d, %d)`, KeySpace+"."+models.Collections.OTP, otp.ID, otp.Email, otp.Otp, otp.ExpiresAt, otp.CreatedAt, otp.UpdatedAt) + } else { + query = fmt.Sprintf(`UPDATE %s SET otp = '%s', expires_at = %d, updated_at = %d WHERE email = '%s'`, KeySpace+"."+models.Collections.OTP, otp.Otp, otp.ExpiresAt, otp.UpdatedAt, otp.Email) + } + err := p.db.Query(query).Exec() + if err != nil { + return nil, err + } -// UpdateOTP to update otp for a given email address -func (p *provider) UpdateOTP(ctx context.Context, otp *models.OTP) (*models.OTP, error) { - otp.UpdatedAt = time.Now().Unix() return otp, nil } // GetOTPByEmail to get otp for a given email address func (p *provider) GetOTPByEmail(ctx context.Context, emailAddress string) (*models.OTP, error) { - return nil, nil + var otp models.OTP + query := fmt.Sprintf(`SELECT id, email, otp, expires_at, created_at, updated_at FROM %s WHERE email = '%s' LIMIT 1 ALLOW FILTERING`, KeySpace+"."+models.Collections.OTP, emailAddress) + err := p.db.Query(query).Consistency(gocql.One).Scan(&otp.ID, &otp.Email, &otp.Otp, &otp.ExpiresAt, &otp.CreatedAt, &otp.UpdatedAt) + if err != nil { + return nil, err + } + return &otp, nil } // DeleteOTP to delete otp func (p *provider) DeleteOTP(ctx context.Context, otp *models.OTP) error { + query := fmt.Sprintf("DELETE FROM %s WHERE id = '%s'", KeySpace+"."+models.Collections.OTP, otp.ID) + err := p.db.Query(query).Exec() + if err != nil { + return err + } return nil } diff --git a/server/db/providers/cassandradb/provider.go b/server/db/providers/cassandradb/provider.go index 9b35767..80a9cb6 100644 --- a/server/db/providers/cassandradb/provider.go +++ b/server/db/providers/cassandradb/provider.go @@ -221,6 +221,17 @@ func NewProvider() (*provider, error) { return nil, err } + otpCollection := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (id text, email text, otp text, expires_at bigint, updated_at bigint, created_at bigint, PRIMARY KEY (id))", KeySpace, models.Collections.OTP) + err = session.Query(otpCollection).Exec() + if err != nil { + return nil, err + } + otpIndexQuery := fmt.Sprintf("CREATE INDEX IF NOT EXISTS authorizer_otp_email ON %s.%s (email)", KeySpace, models.Collections.OTP) + err = session.Query(otpIndexQuery).Exec() + if err != nil { + return nil, err + } + return &provider{ db: session, }, err diff --git a/server/db/providers/mongodb/otp.go b/server/db/providers/mongodb/otp.go index 715b02b..c3f637e 100644 --- a/server/db/providers/mongodb/otp.go +++ b/server/db/providers/mongodb/otp.go @@ -10,31 +10,20 @@ import ( "go.mongodb.org/mongo-driver/mongo/options" ) -// AddOTP to add otp -func (p *provider) AddOTP(ctx context.Context, otp *models.OTP) (*models.OTP, error) { +// UpsertOTP to add or update otp +func (p *provider) UpsertOTP(ctx context.Context, otp *models.OTP) (*models.OTP, error) { if otp.ID == "" { otp.ID = uuid.New().String() } otp.Key = otp.ID - otp.CreatedAt = time.Now().Unix() - otp.UpdatedAt = time.Now().Unix() - - otpCollection := p.db.Collection(models.Collections.OTP, options.Collection()) - _, err := otpCollection.InsertOne(ctx, otp) - if err != nil { - return nil, err + if otp.CreatedAt <= 0 { + otp.CreatedAt = time.Now().Unix() } - - return otp, nil -} - -// UpdateOTP to update otp for a given email address -func (p *provider) UpdateOTP(ctx context.Context, otp *models.OTP) (*models.OTP, error) { otp.UpdatedAt = time.Now().Unix() otpCollection := p.db.Collection(models.Collections.OTP, options.Collection()) - _, err := otpCollection.UpdateOne(ctx, bson.M{"_id": bson.M{"$eq": otp.ID}}, bson.M{"$set": otp}, options.MergeUpdateOptions()) + _, err := otpCollection.UpdateOne(ctx, bson.M{"_id": bson.M{"$eq": otp.ID}}, bson.M{"$set": otp}, options.MergeUpdateOptions().SetUpsert(true)) if err != nil { return nil, err } diff --git a/server/db/providers/provider_template/otp.go b/server/db/providers/provider_template/otp.go index e58c5eb..d8685e7 100644 --- a/server/db/providers/provider_template/otp.go +++ b/server/db/providers/provider_template/otp.go @@ -2,23 +2,13 @@ package provider_template import ( "context" - "time" "github.com/authorizerdev/authorizer/server/db/models" - "github.com/google/uuid" ) -// AddOTP to add otp -func (p *provider) AddOTP(ctx context.Context, otp *models.OTP) (*models.OTP, error) { - if otp.ID == "" { - otp.ID = uuid.New().String() - } - - otp.Key = otp.ID - otp.CreatedAt = time.Now().Unix() - otp.UpdatedAt = time.Now().Unix() - - return otp, nil +// UpsertOTP to add or update otp +func (p *provider) UpsertOTP(ctx context.Context, otp *models.OTP) (*models.OTP, error) { + return nil, nil } // GetOTPByEmail to get otp for a given email address diff --git a/server/db/providers/providers.go b/server/db/providers/providers.go index cb3493a..da72190 100644 --- a/server/db/providers/providers.go +++ b/server/db/providers/providers.go @@ -73,10 +73,8 @@ type Provider interface { // DeleteEmailTemplate to delete EmailTemplate DeleteEmailTemplate(ctx context.Context, emailTemplate *model.EmailTemplate) error - // AddOTP to add otp - AddOTP(ctx context.Context, otp *models.OTP) (*models.OTP, error) - // UpdateOTP to update otp for a given email address - UpdateOTP(ctx context.Context, otp *models.OTP) (*models.OTP, error) + // UpsertOTP to add or update otp + UpsertOTP(ctx context.Context, otp *models.OTP) (*models.OTP, error) // GetOTPByEmail to get otp for a given email address GetOTPByEmail(ctx context.Context, emailAddress string) (*models.OTP, error) // DeleteOTP to delete otp diff --git a/server/db/providers/sql/otp.go b/server/db/providers/sql/otp.go index c83394b..8fd7780 100644 --- a/server/db/providers/sql/otp.go +++ b/server/db/providers/sql/otp.go @@ -6,10 +6,11 @@ import ( "github.com/authorizerdev/authorizer/server/db/models" "github.com/google/uuid" + "gorm.io/gorm/clause" ) -// AddOTP to add otp -func (p *provider) AddOTP(ctx context.Context, otp *models.OTP) (*models.OTP, error) { +// UpsertOTP to add or update otp +func (p *provider) UpsertOTP(ctx context.Context, otp *models.OTP) (*models.OTP, error) { if otp.ID == "" { otp.ID = uuid.New().String() } @@ -18,7 +19,10 @@ func (p *provider) AddOTP(ctx context.Context, otp *models.OTP) (*models.OTP, er otp.CreatedAt = time.Now().Unix() otp.UpdatedAt = time.Now().Unix() - res := p.db.Create(&otp) + res := p.db.Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "email"}}, + DoUpdates: clause.AssignmentColumns([]string{"otp", "expires_at", "updated_at"}), + }).Create(&otp) if res.Error != nil { return nil, res.Error } @@ -26,17 +30,6 @@ func (p *provider) AddOTP(ctx context.Context, otp *models.OTP) (*models.OTP, er return otp, nil } -// UpdateOTP to update otp for a given email address -func (p *provider) UpdateOTP(ctx context.Context, otp *models.OTP) (*models.OTP, error) { - otp.UpdatedAt = time.Now().Unix() - - res := p.db.Save(&otp) - if res.Error != nil { - return nil, res.Error - } - return otp, nil -} - // GetOTPByEmail to get otp for a given email address func (p *provider) GetOTPByEmail(ctx context.Context, emailAddress string) (*models.OTP, error) { var otp *models.OTP