diff --git a/.env.test b/.env.test index 99c8081..293248a 100644 --- a/.env.test +++ b/.env.test @@ -7,5 +7,9 @@ SMTP_PORT=2525 SMTP_USERNAME=test SMTP_PASSWORD=test SENDER_EMAIL="info@authorizer.dev" +TWILIO_API_KEY=test +TWILIO_API_SECRET=test +TWILIO_ACCOUNT_SID=ACXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX +TWILIO_SENDER=909921212112 SENDER_NAME="Authorizer" AWS_REGION=ap-south-1 \ No newline at end of file diff --git a/app/package-lock.json b/app/package-lock.json index cf332cb..1b7f5a4 100644 --- a/app/package-lock.json +++ b/app/package-lock.json @@ -9,7 +9,7 @@ "version": "1.0.0", "license": "ISC", "dependencies": { - "@authorizerdev/authorizer-react": "^1.1.11", + "@authorizerdev/authorizer-react": "^1.1.13", "@types/react": "^17.0.15", "@types/react-dom": "^17.0.9", "esbuild": "^0.12.17", @@ -27,9 +27,9 @@ } }, "node_modules/@authorizerdev/authorizer-js": { - "version": "1.2.3", - "resolved": "https://registry.npmjs.org/@authorizerdev/authorizer-js/-/authorizer-js-1.2.3.tgz", - "integrity": "sha512-rk/fMRIsqbp+fsy2y09etVjf7CY9/4mG6hf0RKgXgRRfxtAQa1jdkt/De23hBTNeEwAWu6hP/9BQZjcrln6KtA==", + "version": "1.2.6", + "resolved": "https://registry.npmjs.org/@authorizerdev/authorizer-js/-/authorizer-js-1.2.6.tgz", + "integrity": "sha512-9+9phHUMF+AeDM0y+XQvIRDoerOXnQ1vfTfYN6KxWN1apdrkAd9nzS1zUsA2uJSnX3fFZOErn83GjbYYCYF1BA==", "dependencies": { "cross-fetch": "^3.1.5" }, @@ -41,11 +41,11 @@ } }, "node_modules/@authorizerdev/authorizer-react": { - "version": "1.1.11", - "resolved": "https://registry.npmjs.org/@authorizerdev/authorizer-react/-/authorizer-react-1.1.11.tgz", - "integrity": "sha512-tSI/yjsoeK/RvCOMiHSf1QGOeSpaLYQZEM864LFLndKoJwk7UWCJ86qg1w6ge7B00PmZSNWqST/w5JTcQaVNpw==", + "version": "1.1.13", + "resolved": "https://registry.npmjs.org/@authorizerdev/authorizer-react/-/authorizer-react-1.1.13.tgz", + "integrity": "sha512-LmpzyfR0+nEn+bjUrb/QU9b3kiVoYzMBIvcQ1nV4TNvrvVSqbLPKk+GmoIPkiBEtfy/QSM6XFLkiGNGD9BRP+g==", "dependencies": { - "@authorizerdev/authorizer-js": "^1.2.3" + "@authorizerdev/authorizer-js": "^1.2.6" }, "engines": { "node": ">=10" @@ -406,11 +406,11 @@ "integrity": "sha1-p9BVi9icQveV3UIyj3QIMcpTvCU=" }, "node_modules/cross-fetch": { - "version": "3.1.5", - "resolved": "https://registry.npmjs.org/cross-fetch/-/cross-fetch-3.1.5.tgz", - "integrity": "sha512-lvb1SBsI0Z7GDwmuid+mU3kWVBwTVUbe7S0H52yaaAdQOXq2YktTCZdlAcNKFzE6QtRz0snpw9bNiPeOIkkQvw==", + "version": "3.1.8", + "resolved": "https://registry.npmjs.org/cross-fetch/-/cross-fetch-3.1.8.tgz", + "integrity": "sha512-cvA+JwZoU0Xq+h6WkMvAUqPEYy92Obet6UdKLfW60qn99ftItKjB5T+BkyWOFWe2pUyfQ+IJHmpOTznqk1M6Kg==", "dependencies": { - "node-fetch": "2.6.7" + "node-fetch": "^2.6.12" } }, "node_modules/css-color-keywords": { @@ -567,9 +567,9 @@ "integrity": "sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==" }, "node_modules/node-fetch": { - "version": "2.6.7", - "resolved": "https://registry.npmjs.org/node-fetch/-/node-fetch-2.6.7.tgz", - "integrity": "sha512-ZjMPFEfVx5j+y2yF35Kzx5sF7kDzxuDj6ziH4FFbOp87zKDZNx8yExJIb05OGF4Nlt9IHFIMBkRl41VdvcNdbQ==", + "version": "2.6.12", + "resolved": "https://registry.npmjs.org/node-fetch/-/node-fetch-2.6.12.tgz", + "integrity": "sha512-C/fGU2E8ToujUivIO0H+tpQ6HWo4eEmchoPIoXtxCrVghxdKq+QOHqEZW7tuP3KlV3bC8FRMO5nMCC7Zm1VP6g==", "dependencies": { "whatwg-url": "^5.0.0" }, @@ -837,19 +837,19 @@ }, "dependencies": { "@authorizerdev/authorizer-js": { - "version": "1.2.3", - "resolved": "https://registry.npmjs.org/@authorizerdev/authorizer-js/-/authorizer-js-1.2.3.tgz", - "integrity": "sha512-rk/fMRIsqbp+fsy2y09etVjf7CY9/4mG6hf0RKgXgRRfxtAQa1jdkt/De23hBTNeEwAWu6hP/9BQZjcrln6KtA==", + "version": "1.2.6", + "resolved": "https://registry.npmjs.org/@authorizerdev/authorizer-js/-/authorizer-js-1.2.6.tgz", + "integrity": "sha512-9+9phHUMF+AeDM0y+XQvIRDoerOXnQ1vfTfYN6KxWN1apdrkAd9nzS1zUsA2uJSnX3fFZOErn83GjbYYCYF1BA==", "requires": { "cross-fetch": "^3.1.5" } }, "@authorizerdev/authorizer-react": { - "version": "1.1.11", - "resolved": "https://registry.npmjs.org/@authorizerdev/authorizer-react/-/authorizer-react-1.1.11.tgz", - "integrity": "sha512-tSI/yjsoeK/RvCOMiHSf1QGOeSpaLYQZEM864LFLndKoJwk7UWCJ86qg1w6ge7B00PmZSNWqST/w5JTcQaVNpw==", + "version": "1.1.13", + "resolved": "https://registry.npmjs.org/@authorizerdev/authorizer-react/-/authorizer-react-1.1.13.tgz", + "integrity": "sha512-LmpzyfR0+nEn+bjUrb/QU9b3kiVoYzMBIvcQ1nV4TNvrvVSqbLPKk+GmoIPkiBEtfy/QSM6XFLkiGNGD9BRP+g==", "requires": { - "@authorizerdev/authorizer-js": "^1.2.3" + "@authorizerdev/authorizer-js": "^1.2.6" } }, "@babel/code-frame": { @@ -1144,11 +1144,11 @@ "integrity": "sha1-p9BVi9icQveV3UIyj3QIMcpTvCU=" }, "cross-fetch": { - "version": "3.1.5", - "resolved": "https://registry.npmjs.org/cross-fetch/-/cross-fetch-3.1.5.tgz", - "integrity": "sha512-lvb1SBsI0Z7GDwmuid+mU3kWVBwTVUbe7S0H52yaaAdQOXq2YktTCZdlAcNKFzE6QtRz0snpw9bNiPeOIkkQvw==", + "version": "3.1.8", + "resolved": "https://registry.npmjs.org/cross-fetch/-/cross-fetch-3.1.8.tgz", + "integrity": "sha512-cvA+JwZoU0Xq+h6WkMvAUqPEYy92Obet6UdKLfW60qn99ftItKjB5T+BkyWOFWe2pUyfQ+IJHmpOTznqk1M6Kg==", "requires": { - "node-fetch": "2.6.7" + "node-fetch": "^2.6.12" } }, "css-color-keywords": { @@ -1270,9 +1270,9 @@ "integrity": "sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==" }, "node-fetch": { - "version": "2.6.7", - "resolved": "https://registry.npmjs.org/node-fetch/-/node-fetch-2.6.7.tgz", - "integrity": "sha512-ZjMPFEfVx5j+y2yF35Kzx5sF7kDzxuDj6ziH4FFbOp87zKDZNx8yExJIb05OGF4Nlt9IHFIMBkRl41VdvcNdbQ==", + "version": "2.6.12", + "resolved": "https://registry.npmjs.org/node-fetch/-/node-fetch-2.6.12.tgz", + "integrity": "sha512-C/fGU2E8ToujUivIO0H+tpQ6HWo4eEmchoPIoXtxCrVghxdKq+QOHqEZW7tuP3KlV3bC8FRMO5nMCC7Zm1VP6g==", "requires": { "whatwg-url": "^5.0.0" } diff --git a/app/package.json b/app/package.json index 1225346..2406108 100644 --- a/app/package.json +++ b/app/package.json @@ -12,7 +12,7 @@ "author": "Lakhan Samani", "license": "ISC", "dependencies": { - "@authorizerdev/authorizer-react": "^1.1.11", + "@authorizerdev/authorizer-react": "^1.1.13", "@types/react": "^17.0.15", "@types/react-dom": "^17.0.9", "esbuild": "^0.12.17", diff --git a/app/yarn.lock b/app/yarn.lock index 380f761..2be982c 100644 --- a/app/yarn.lock +++ b/app/yarn.lock @@ -2,19 +2,19 @@ # yarn lockfile v1 -"@authorizerdev/authorizer-js@^1.2.3": - "integrity" "sha512-rk/fMRIsqbp+fsy2y09etVjf7CY9/4mG6hf0RKgXgRRfxtAQa1jdkt/De23hBTNeEwAWu6hP/9BQZjcrln6KtA==" - "resolved" "https://registry.npmjs.org/@authorizerdev/authorizer-js/-/authorizer-js-1.2.3.tgz" - "version" "1.2.3" +"@authorizerdev/authorizer-js@^1.2.6": + "integrity" "sha512-9+9phHUMF+AeDM0y+XQvIRDoerOXnQ1vfTfYN6KxWN1apdrkAd9nzS1zUsA2uJSnX3fFZOErn83GjbYYCYF1BA==" + "resolved" "https://registry.npmjs.org/@authorizerdev/authorizer-js/-/authorizer-js-1.2.6.tgz" + "version" "1.2.6" dependencies: "cross-fetch" "^3.1.5" -"@authorizerdev/authorizer-react@^1.1.11": - "integrity" "sha512-tSI/yjsoeK/RvCOMiHSf1QGOeSpaLYQZEM864LFLndKoJwk7UWCJ86qg1w6ge7B00PmZSNWqST/w5JTcQaVNpw==" - "resolved" "https://registry.npmjs.org/@authorizerdev/authorizer-react/-/authorizer-react-1.1.11.tgz" - "version" "1.1.11" +"@authorizerdev/authorizer-react@^1.1.13": + "integrity" "sha512-LmpzyfR0+nEn+bjUrb/QU9b3kiVoYzMBIvcQ1nV4TNvrvVSqbLPKk+GmoIPkiBEtfy/QSM6XFLkiGNGD9BRP+g==" + "resolved" "https://registry.npmjs.org/@authorizerdev/authorizer-react/-/authorizer-react-1.1.13.tgz" + "version" "1.1.13" dependencies: - "@authorizerdev/authorizer-js" "^1.2.3" + "@authorizerdev/authorizer-js" "^1.2.6" "@babel/code-frame@^7.16.7": "integrity" "sha512-iAXqUn8IIeBTNd72xsFlgaXHkMBMt6y4HJp1tIaK465CWLT/fG1aqB7ykr95gHHmlBdGbFeWWfyB4NJJ0nmeIg==" @@ -278,11 +278,11 @@ "version" "1.1.3" "cross-fetch@^3.1.5": - "integrity" "sha512-lvb1SBsI0Z7GDwmuid+mU3kWVBwTVUbe7S0H52yaaAdQOXq2YktTCZdlAcNKFzE6QtRz0snpw9bNiPeOIkkQvw==" - "resolved" "https://registry.npmjs.org/cross-fetch/-/cross-fetch-3.1.5.tgz" - "version" "3.1.5" + "integrity" "sha512-cvA+JwZoU0Xq+h6WkMvAUqPEYy92Obet6UdKLfW60qn99ftItKjB5T+BkyWOFWe2pUyfQ+IJHmpOTznqk1M6Kg==" + "resolved" "https://registry.npmjs.org/cross-fetch/-/cross-fetch-3.1.8.tgz" + "version" "3.1.8" dependencies: - "node-fetch" "2.6.7" + "node-fetch" "^2.6.12" "css-color-keywords@^1.0.0": "integrity" "sha1-/qJhbcZ2spYmhrOvjb2+GAskTgU=" @@ -389,10 +389,10 @@ "resolved" "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz" "version" "2.1.2" -"node-fetch@2.6.7": - "integrity" "sha512-ZjMPFEfVx5j+y2yF35Kzx5sF7kDzxuDj6ziH4FFbOp87zKDZNx8yExJIb05OGF4Nlt9IHFIMBkRl41VdvcNdbQ==" - "resolved" "https://registry.npmjs.org/node-fetch/-/node-fetch-2.6.7.tgz" - "version" "2.6.7" +"node-fetch@^2.6.12": + "integrity" "sha512-C/fGU2E8ToujUivIO0H+tpQ6HWo4eEmchoPIoXtxCrVghxdKq+QOHqEZW7tuP3KlV3bC8FRMO5nMCC7Zm1VP6g==" + "resolved" "https://registry.npmjs.org/node-fetch/-/node-fetch-2.6.12.tgz" + "version" "2.6.12" dependencies: "whatwg-url" "^5.0.0" diff --git a/dashboard/src/pages/Webhooks.tsx b/dashboard/src/pages/Webhooks.tsx index a4484b5..9d90378 100644 --- a/dashboard/src/pages/Webhooks.tsx +++ b/dashboard/src/pages/Webhooks.tsx @@ -118,7 +118,6 @@ const Webhooks = () => { useEffect(() => { fetchWebookData(); }, [paginationProps.page, paginationProps.limit]); - console.log({ webhookData }); return ( diff --git a/dashboard/yarn.lock b/dashboard/yarn.lock index 79e25d9..e1239de 100644 --- a/dashboard/yarn.lock +++ b/dashboard/yarn.lock @@ -1222,9 +1222,9 @@ dependencies: "is-arrayish" "^0.2.1" -"esbuild-linux-64@0.14.9": - "integrity" "sha512-WoEI+R6/PLZAxS7XagfQMFgRtLUi5cjqqU9VCfo3tnWmAXh/wt8QtUfCVVCcXVwZLS/RNvI19CtfjlrJU61nOg==" - "resolved" "https://registry.npmjs.org/esbuild-linux-64/-/esbuild-linux-64-0.14.9.tgz" +"esbuild-darwin-arm64@0.14.9": + "integrity" "sha512-3ue+1T4FR5TaAu4/V1eFMG8Uwn0pgAwQZb/WwL1X78d5Cy8wOVQ67KNH1lsjU+y/9AcwMKZ9x0GGNxBB4a1Rbw==" + "resolved" "https://registry.npmjs.org/esbuild-darwin-arm64/-/esbuild-darwin-arm64-0.14.9.tgz" "version" "0.14.9" "esbuild@^0.14.9": diff --git a/server/constants/auth_methods.go b/server/constants/auth_methods.go index 9e7b9fa..dbe5175 100644 --- a/server/constants/auth_methods.go +++ b/server/constants/auth_methods.go @@ -7,6 +7,8 @@ const ( AuthRecipeMethodMobileBasicAuth = "mobile_basic_auth" // AuthRecipeMethodMagicLinkLogin is the magic_link_login auth method AuthRecipeMethodMagicLinkLogin = "magic_link_login" + // AuthRecipeMethodMobileOTP is the mobile_otp auth method + AuthRecipeMethodMobileOTP = "mobile_otp" // AuthRecipeMethodGoogle is the google auth method AuthRecipeMethodGoogle = "google" // AuthRecipeMethodGithub is the github auth method diff --git a/server/constants/cookie.go b/server/constants/cookie.go index 71320a9..8f6399b 100644 --- a/server/constants/cookie.go +++ b/server/constants/cookie.go @@ -5,4 +5,6 @@ const ( AppCookieName = "cookie" // AdminCookieName is the name of the cookie that is used to store the admin token AdminCookieName = "authorizer-admin" + // MfaCookieName is the name of the cookie that is used to store the mfa session + MfaCookieName = "mfa" ) diff --git a/server/constants/env.go b/server/constants/env.go index d8382fb..74b9d36 100644 --- a/server/constants/env.go +++ b/server/constants/env.go @@ -66,6 +66,8 @@ const ( EnvKeySenderName = "SENDER_NAME" // EnvKeyIsEmailServiceEnabled key for env variable IS_EMAIL_SERVICE_ENABLED EnvKeyIsEmailServiceEnabled = "IS_EMAIL_SERVICE_ENABLED" + // EnvKeyIsSMSServiceEnabled key for env variable IS_SMS_SERVICE_ENABLED + EnvKeyIsSMSServiceEnabled = "IS_SMS_SERVICE_ENABLED" // EnvKeyAppCookieSecure key for env variable APP_COOKIE_SECURE EnvKeyAppCookieSecure = "APP_COOKIE_SECURE" // EnvKeyAdminCookieSecure key for env variable ADMIN_COOKIE_SECURE @@ -158,6 +160,9 @@ const ( // EnvKeyDisableMultiFactorAuthentication is key for env variable DISABLE_MULTI_FACTOR_AUTHENTICATION // this variable is used to completely disable multi factor authentication. It will have no effect on profile preference EnvKeyDisableMultiFactorAuthentication = "DISABLE_MULTI_FACTOR_AUTHENTICATION" + // EnvKeyDisablePhoneVerification is key for env variable DISABLE_PHONE_VERIFICATION + // this variable is used to disable phone verification + EnvKeyDisablePhoneVerification = "DISABLE_PHONE_VERIFICATION" // Slice variables // EnvKeyRoles key for env variable ROLES @@ -177,12 +182,13 @@ const ( // This env is used for setting default response mode in authorize handler EnvKeyDefaultAuthorizeResponseMode = "DEFAULT_AUTHORIZE_RESPONSE_MODE" - // Phone verification setting - EnvKeyDisablePhoneVerification = "DISABLE_PHONE_VERIFICATION" - // Twilio env variables + // EnvKeyTwilioAPIKey key for env variable TWILIO_API_KEY EnvKeyTwilioAPIKey = "TWILIO_API_KEY" + // EnvKeyTwilioAPISecret key for env variable TWILIO_API_SECRET EnvKeyTwilioAPISecret = "TWILIO_API_SECRET" + // EnvKeyTwilioAccountSID key for env variable TWILIO_ACCOUNT_SID EnvKeyTwilioAccountSID = "TWILIO_ACCOUNT_SID" - EnvKeyTwilioSenderFrom = "TWILIO_SENDER_FROM" + // EnvKeyTwilioSender key for env variable TWILIO_SENDER + EnvKeyTwilioSender = "TWILIO_SENDER" ) diff --git a/server/cookie/mfa_session.go b/server/cookie/mfa_session.go new file mode 100644 index 0000000..3fdcaac --- /dev/null +++ b/server/cookie/mfa_session.go @@ -0,0 +1,89 @@ +package cookie + +import ( + "net/http" + "net/url" + + log "github.com/sirupsen/logrus" + + "github.com/authorizerdev/authorizer/server/constants" + "github.com/authorizerdev/authorizer/server/memorystore" + "github.com/authorizerdev/authorizer/server/parsers" + "github.com/gin-gonic/gin" +) + +// SetMfaSession sets the mfa session cookie in the response +func SetMfaSession(gc *gin.Context, sessionID string) { + appCookieSecure, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyAppCookieSecure) + if err != nil { + log.Debug("Error while getting app cookie secure from env variable: %v", err) + appCookieSecure = true + } + + secure := appCookieSecure + httpOnly := appCookieSecure + hostname := parsers.GetHost(gc) + host, _ := parsers.GetHostParts(hostname) + domain := parsers.GetDomainName(hostname) + if domain != "localhost" { + domain = "." + domain + } + + // Since app cookie can come from cross site it becomes important to set this in lax mode when insecure. + // Example person using custom UI on their app domain and making request to authorizer domain. + // For more information check: + // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie/SameSite + // https://github.com/gin-gonic/gin/blob/master/context.go#L86 + // TODO add ability to sameSite = none / strict from dashboard + if !appCookieSecure { + gc.SetSameSite(http.SameSiteLaxMode) + } else { + gc.SetSameSite(http.SameSiteNoneMode) + } + // TODO allow configuring from dashboard + age := 60 + + gc.SetCookie(constants.MfaCookieName+"_session", sessionID, age, "/", host, secure, httpOnly) + gc.SetCookie(constants.MfaCookieName+"_session_domain", sessionID, age, "/", domain, secure, httpOnly) +} + +// DeleteMfaSession deletes the mfa session cookies to expire +func DeleteMfaSession(gc *gin.Context) { + appCookieSecure, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyAppCookieSecure) + if err != nil { + log.Debug("Error while getting app cookie secure from env variable: %v", err) + appCookieSecure = true + } + + secure := appCookieSecure + httpOnly := appCookieSecure + hostname := parsers.GetHost(gc) + host, _ := parsers.GetHostParts(hostname) + domain := parsers.GetDomainName(hostname) + if domain != "localhost" { + domain = "." + domain + } + + gc.SetSameSite(http.SameSiteNoneMode) + gc.SetCookie(constants.MfaCookieName+"_session", "", -1, "/", host, secure, httpOnly) + gc.SetCookie(constants.MfaCookieName+"_session_domain", "", -1, "/", domain, secure, httpOnly) +} + +// GetMfaSession gets the mfa session cookie from context +func GetMfaSession(gc *gin.Context) (string, error) { + var cookie *http.Cookie + var err error + cookie, err = gc.Request.Cookie(constants.MfaCookieName + "_session") + if err != nil { + cookie, err = gc.Request.Cookie(constants.MfaCookieName + "_session_domain") + if err != nil { + return "", err + } + } + + decodedValue, err := url.PathUnescape(cookie.Value) + if err != nil { + return "", err + } + return decodedValue, nil +} diff --git a/server/db/models/model.go b/server/db/models/model.go index a0d5763..5061c41 100644 --- a/server/db/models/model.go +++ b/server/db/models/model.go @@ -2,14 +2,14 @@ package models // Collections / Tables available for authorizer in the database type CollectionList struct { - User string - VerificationRequest string - Session string - Env string - Webhook string - WebhookLog string - EmailTemplate string - OTP string + User string + VerificationRequest string + Session string + Env string + Webhook string + WebhookLog string + EmailTemplate string + OTP string SMSVerificationRequest string } @@ -18,14 +18,14 @@ var ( Prefix = "authorizer_" // Collections / Tables available for authorizer in the database (used for dbs other than gorm) Collections = CollectionList{ - User: Prefix + "users", - VerificationRequest: Prefix + "verification_requests", - Session: Prefix + "sessions", - Env: Prefix + "env", - Webhook: Prefix + "webhooks", - WebhookLog: Prefix + "webhook_logs", - EmailTemplate: Prefix + "email_templates", - OTP: Prefix + "otps", - SMSVerificationRequest: Prefix + "sms_verification_requests", + User: Prefix + "users", + VerificationRequest: Prefix + "verification_requests", + Session: Prefix + "sessions", + Env: Prefix + "env", + Webhook: Prefix + "webhooks", + WebhookLog: Prefix + "webhook_logs", + EmailTemplate: Prefix + "email_templates", + OTP: Prefix + "otps", + SMSVerificationRequest: Prefix + "sms_verification_requests", } ) diff --git a/server/db/models/otp.go b/server/db/models/otp.go index ac9732b..bd0b41c 100644 --- a/server/db/models/otp.go +++ b/server/db/models/otp.go @@ -1,14 +1,22 @@ package models +const ( + // FieldName email is the field name for email + FieldNameEmail = "email" + // FieldNamePhoneNumber is the field name for phone number + FieldNamePhoneNumber = "phone_number" +) + // OTP model for database type OTP struct { - Key string `json:"_key,omitempty" bson:"_key,omitempty" cql:"_key,omitempty" dynamo:"key,omitempty"` // for arangodb - ID string `gorm:"primaryKey;type:char(36)" json:"_id" bson:"_id" cql:"id" dynamo:"id,hash"` - Email string `gorm:"unique" json:"email" bson:"email" cql:"email" dynamo:"email" index:"email,hash"` - Otp string `json:"otp" bson:"otp" cql:"otp" dynamo:"otp"` - ExpiresAt int64 `json:"expires_at" bson:"expires_at" cql:"expires_at" dynamo:"expires_at"` - CreatedAt int64 `json:"created_at" bson:"created_at" cql:"created_at" dynamo:"created_at"` - UpdatedAt int64 `json:"updated_at" bson:"updated_at" cql:"updated_at" dynamo:"updated_at"` + Key string `json:"_key,omitempty" bson:"_key,omitempty" cql:"_key,omitempty" dynamo:"key,omitempty"` // for arangodb + ID string `gorm:"primaryKey;type:char(36)" json:"_id" bson:"_id" cql:"id" dynamo:"id,hash"` + Email string `gorm:"unique" json:"email" bson:"email" cql:"email" dynamo:"email" index:"email,hash"` + PhoneNumber string `gorm:"index:unique_index_phone_number,unique" json:"phone_number" bson:"phone_number" cql:"phone_number" dynamo:"phone_number"` + Otp string `json:"otp" bson:"otp" cql:"otp" dynamo:"otp"` + ExpiresAt int64 `json:"expires_at" bson:"expires_at" cql:"expires_at" dynamo:"expires_at"` + CreatedAt int64 `json:"created_at" bson:"created_at" cql:"created_at" dynamo:"created_at"` + UpdatedAt int64 `json:"updated_at" bson:"updated_at" cql:"updated_at" dynamo:"updated_at"` } type Paging struct { diff --git a/server/db/models/sms_verification_requests.go b/server/db/models/sms_verification_requests.go deleted file mode 100644 index 2a70d5e..0000000 --- a/server/db/models/sms_verification_requests.go +++ /dev/null @@ -1,11 +0,0 @@ -package models - -// SMS verification requests model for database -type SMSVerificationRequest struct { - ID string `gorm:"primaryKey;type:char(36)" json:"_id" bson:"_id" cql:"id" dynamo:"id,hash"` - PhoneNumber string `gorm:"unique" json:"phone_number" bson:"phone_number" cql:"phone_number" dynamo:"phone_number" index:"phone_number,hash"` - Code string `json:"code" bson:"code" cql:"code" dynamo:"code"` - CodeExpiresAt int64 `json:"code_expires_at" bson:"code_expires_at" cql:"code_expires_at" dynamo:"code_expires_at"` - CreatedAt int64 `json:"created_at" bson:"created_at" cql:"created_at" dynamo:"created_at"` - UpdatedAt int64 `json:"updated_at" bson:"updated_at" cql:"updated_at" dynamo:"updated_at"` -} diff --git a/server/db/models/user.go b/server/db/models/user.go index 4628359..a262823 100644 --- a/server/db/models/user.go +++ b/server/db/models/user.go @@ -33,12 +33,14 @@ type User struct { IsMultiFactorAuthEnabled *bool `json:"is_multi_factor_auth_enabled" bson:"is_multi_factor_auth_enabled" cql:"is_multi_factor_auth_enabled" dynamo:"is_multi_factor_auth_enabled"` UpdatedAt int64 `json:"updated_at" bson:"updated_at" cql:"updated_at" dynamo:"updated_at"` CreatedAt int64 `json:"created_at" bson:"created_at" cql:"created_at" dynamo:"created_at"` + AppData *string `json:"app_data" bson:"app_data" cql:"app_data" dynamo:"app_data"` } func (user *User) AsAPIUser() *model.User { isEmailVerified := user.EmailVerifiedAt != nil isPhoneVerified := user.PhoneNumberVerifiedAt != nil - + appDataMap := make(map[string]interface{}) + json.Unmarshal([]byte(refs.StringValue(user.AppData)), &appDataMap) // id := user.ID // if strings.Contains(id, Collections.User+"/") { // id = strings.TrimPrefix(id, Collections.User+"/") @@ -63,6 +65,7 @@ func (user *User) AsAPIUser() *model.User { IsMultiFactorAuthEnabled: user.IsMultiFactorAuthEnabled, CreatedAt: refs.NewInt64Ref(user.CreatedAt), UpdatedAt: refs.NewInt64Ref(user.UpdatedAt), + AppData: appDataMap, } } diff --git a/server/db/providers/arangodb/email_template.go b/server/db/providers/arangodb/email_template.go index 8134cbe..30c0fd0 100644 --- a/server/db/providers/arangodb/email_template.go +++ b/server/db/providers/arangodb/email_template.go @@ -12,16 +12,14 @@ import ( ) // AddEmailTemplate to add EmailTemplate -func (p *provider) AddEmailTemplate(ctx context.Context, emailTemplate models.EmailTemplate) (*model.EmailTemplate, error) { +func (p *provider) AddEmailTemplate(ctx context.Context, emailTemplate *models.EmailTemplate) (*model.EmailTemplate, error) { if emailTemplate.ID == "" { emailTemplate.ID = uuid.New().String() emailTemplate.Key = emailTemplate.ID } - emailTemplate.Key = emailTemplate.ID emailTemplate.CreatedAt = time.Now().Unix() emailTemplate.UpdatedAt = time.Now().Unix() - emailTemplateCollection, _ := p.db.Collection(ctx, models.Collections.EmailTemplate) _, err := emailTemplateCollection.CreateDocument(ctx, emailTemplate) if err != nil { @@ -31,74 +29,63 @@ func (p *provider) AddEmailTemplate(ctx context.Context, emailTemplate models.Em } // UpdateEmailTemplate to update EmailTemplate -func (p *provider) UpdateEmailTemplate(ctx context.Context, emailTemplate models.EmailTemplate) (*model.EmailTemplate, error) { +func (p *provider) UpdateEmailTemplate(ctx context.Context, emailTemplate *models.EmailTemplate) (*model.EmailTemplate, error) { emailTemplate.UpdatedAt = time.Now().Unix() - emailTemplateCollection, _ := p.db.Collection(ctx, models.Collections.EmailTemplate) meta, err := emailTemplateCollection.UpdateDocument(ctx, emailTemplate.Key, emailTemplate) if err != nil { return nil, err } - emailTemplate.Key = meta.Key emailTemplate.ID = meta.ID.String() return emailTemplate.AsAPIEmailTemplate(), nil } // ListEmailTemplates to list EmailTemplate -func (p *provider) ListEmailTemplate(ctx context.Context, pagination model.Pagination) (*model.EmailTemplates, error) { +func (p *provider) ListEmailTemplate(ctx context.Context, pagination *model.Pagination) (*model.EmailTemplates, error) { emailTemplates := []*model.EmailTemplate{} - query := fmt.Sprintf("FOR d in %s SORT d.created_at DESC LIMIT %d, %d RETURN d", models.Collections.EmailTemplate, pagination.Offset, pagination.Limit) - sctx := arangoDriver.WithQueryFullCount(ctx) cursor, err := p.db.Query(sctx, query, nil) if err != nil { return nil, err } defer cursor.Close() - paginationClone := pagination paginationClone.Total = cursor.Statistics().FullCount() - for { - var emailTemplate models.EmailTemplate + var emailTemplate *models.EmailTemplate meta, err := cursor.ReadDocument(ctx, &emailTemplate) - if arangoDriver.IsNoMoreDocuments(err) { break } else if err != nil { return nil, err } - if meta.Key != "" { emailTemplates = append(emailTemplates, emailTemplate.AsAPIEmailTemplate()) } } - return &model.EmailTemplates{ - Pagination: &paginationClone, + Pagination: paginationClone, EmailTemplates: emailTemplates, }, nil } // GetEmailTemplateByID to get EmailTemplate by id func (p *provider) GetEmailTemplateByID(ctx context.Context, emailTemplateID string) (*model.EmailTemplate, error) { - var emailTemplate models.EmailTemplate + var emailTemplate *models.EmailTemplate query := fmt.Sprintf("FOR d in %s FILTER d._key == @email_template_id RETURN d", models.Collections.EmailTemplate) bindVars := map[string]interface{}{ "email_template_id": emailTemplateID, } - cursor, err := p.db.Query(ctx, query, bindVars) if err != nil { return nil, err } defer cursor.Close() - for { if !cursor.HasMore() { - if emailTemplate.Key == "" { + if emailTemplate == nil { return nil, fmt.Errorf("email template not found") } break @@ -113,21 +100,19 @@ func (p *provider) GetEmailTemplateByID(ctx context.Context, emailTemplateID str // GetEmailTemplateByEventName to get EmailTemplate by event_name func (p *provider) GetEmailTemplateByEventName(ctx context.Context, eventName string) (*model.EmailTemplate, error) { - var emailTemplate models.EmailTemplate + var emailTemplate *models.EmailTemplate query := fmt.Sprintf("FOR d in %s FILTER d.event_name == @event_name RETURN d", models.Collections.EmailTemplate) bindVars := map[string]interface{}{ "event_name": eventName, } - cursor, err := p.db.Query(ctx, query, bindVars) if err != nil { return nil, err } defer cursor.Close() - for { if !cursor.HasMore() { - if emailTemplate.Key == "" { + if emailTemplate == nil { return nil, fmt.Errorf("email template not found") } break diff --git a/server/db/providers/arangodb/env.go b/server/db/providers/arangodb/env.go index 29687a8..bb4610a 100644 --- a/server/db/providers/arangodb/env.go +++ b/server/db/providers/arangodb/env.go @@ -12,7 +12,7 @@ import ( ) // AddEnv to save environment information in database -func (p *provider) AddEnv(ctx context.Context, env models.Env) (models.Env, error) { +func (p *provider) AddEnv(ctx context.Context, env *models.Env) (*models.Env, error) { if env.ID == "" { env.ID = uuid.New().String() env.Key = env.ID @@ -31,7 +31,7 @@ func (p *provider) AddEnv(ctx context.Context, env models.Env) (models.Env, erro } // UpdateEnv to update environment information in database -func (p *provider) UpdateEnv(ctx context.Context, env models.Env) (models.Env, error) { +func (p *provider) UpdateEnv(ctx context.Context, env *models.Env) (*models.Env, error) { env.UpdatedAt = time.Now().Unix() collection, _ := p.db.Collection(ctx, models.Collections.Env) meta, err := collection.UpdateDocument(ctx, env.Key, env) @@ -45,19 +45,17 @@ func (p *provider) UpdateEnv(ctx context.Context, env models.Env) (models.Env, e } // GetEnv to get environment information from database -func (p *provider) GetEnv(ctx context.Context) (models.Env, error) { - var env models.Env +func (p *provider) GetEnv(ctx context.Context) (*models.Env, error) { + var env *models.Env query := fmt.Sprintf("FOR d in %s RETURN d", models.Collections.Env) - cursor, err := p.db.Query(ctx, query, nil) if err != nil { return env, err } defer cursor.Close() - for { if !cursor.HasMore() { - if env.Key == "" { + if env == nil { return env, fmt.Errorf("config not found") } break diff --git a/server/db/providers/arangodb/otp.go b/server/db/providers/arangodb/otp.go index 29f265a..3f8f464 100644 --- a/server/db/providers/arangodb/otp.go +++ b/server/db/providers/arangodb/otp.go @@ -2,6 +2,7 @@ package arangodb import ( "context" + "errors" "fmt" "time" @@ -12,27 +13,39 @@ import ( // UpsertOTP to add or update otp func (p *provider) UpsertOTP(ctx context.Context, otpParam *models.OTP) (*models.OTP, error) { - otp, _ := p.GetOTPByEmail(ctx, otpParam.Email) + // check if email or phone number is present + if otpParam.Email == "" && otpParam.PhoneNumber == "" { + return nil, errors.New("email or phone_number is required") + } + uniqueField := models.FieldNameEmail + if otpParam.Email == "" && otpParam.PhoneNumber != "" { + uniqueField = models.FieldNamePhoneNumber + } + var otp *models.OTP + if uniqueField == models.FieldNameEmail { + otp, _ = p.GetOTPByEmail(ctx, otpParam.Email) + } else { + otp, _ = p.GetOTPByPhoneNumber(ctx, otpParam.PhoneNumber) + } shouldCreate := false if otp == nil { id := uuid.NewString() otp = &models.OTP{ - ID: id, - Key: id, - Otp: otpParam.Otp, - Email: otpParam.Email, - ExpiresAt: otpParam.ExpiresAt, - CreatedAt: time.Now().Unix(), + ID: id, + Key: id, + Otp: otpParam.Otp, + Email: otpParam.Email, + PhoneNumber: otpParam.PhoneNumber, + ExpiresAt: otpParam.ExpiresAt, + CreatedAt: time.Now().Unix(), } shouldCreate = true } else { otp.Otp = otpParam.Otp otp.ExpiresAt = otpParam.ExpiresAt } - otp.UpdatedAt = time.Now().Unix() otpCollection, _ := p.db.Collection(ctx, models.Collections.OTP) - var meta driver.DocumentMeta var err error if shouldCreate { @@ -40,11 +53,9 @@ func (p *provider) UpsertOTP(ctx context.Context, otpParam *models.OTP) (*models } else { meta, err = otpCollection.UpdateDocument(ctx, otp.Key, otp) } - if err != nil { return nil, err } - otp.Key = meta.Key otp.ID = meta.ID.String() return otp, nil @@ -52,22 +63,20 @@ func (p *provider) UpsertOTP(ctx context.Context, otpParam *models.OTP) (*models // GetOTPByEmail to get otp for a given email address func (p *provider) GetOTPByEmail(ctx context.Context, emailAddress string) (*models.OTP, error) { - var otp models.OTP + var otp *models.OTP query := fmt.Sprintf("FOR d in %s FILTER d.email == @email RETURN d", models.Collections.OTP) bindVars := map[string]interface{}{ "email": emailAddress, } - cursor, err := p.db.Query(ctx, query, bindVars) if err != nil { return nil, err } defer cursor.Close() - for { if !cursor.HasMore() { - if otp.Key == "" { - return nil, fmt.Errorf("email template not found") + if otp == nil { + return nil, fmt.Errorf("otp with given email not found") } break } @@ -76,8 +85,34 @@ func (p *provider) GetOTPByEmail(ctx context.Context, emailAddress string) (*mod return nil, err } } + return otp, nil +} - return &otp, nil +// GetOTPByPhoneNumber to get otp for a given phone number +func (p *provider) GetOTPByPhoneNumber(ctx context.Context, phoneNumber string) (*models.OTP, error) { + var otp *models.OTP + query := fmt.Sprintf("FOR d in %s FILTER d.phone_number == @phone_number RETURN d", models.Collections.OTP) + bindVars := map[string]interface{}{ + "phone_number": phoneNumber, + } + cursor, err := p.db.Query(ctx, query, bindVars) + if err != nil { + return nil, err + } + defer cursor.Close() + for { + if !cursor.HasMore() { + if otp == nil { + return nil, fmt.Errorf("otp with given phone_number not found") + } + break + } + _, err := cursor.ReadDocument(ctx, &otp) + if err != nil { + return nil, err + } + } + return otp, nil } // DeleteOTP to delete otp @@ -87,6 +122,5 @@ func (p *provider) DeleteOTP(ctx context.Context, otp *models.OTP) error { if err != nil { return err } - return nil } diff --git a/server/db/providers/arangodb/provider.go b/server/db/providers/arangodb/provider.go index 5488428..507a938 100644 --- a/server/db/providers/arangodb/provider.go +++ b/server/db/providers/arangodb/provider.go @@ -61,7 +61,6 @@ func NewProvider() (*provider, error) { if err != nil { return nil, err } - var arangodb arangoDriver.Database dbName := memorystore.RequiredEnvStoreObj.GetRequiredEnv().DatabaseName arangodb_exists, err := arangoClient.DatabaseExists(ctx, dbName) @@ -79,7 +78,6 @@ func NewProvider() (*provider, error) { return nil, err } } - userCollectionExists, err := arangodb.CollectionExists(ctx, models.Collections.User) if err != nil { return nil, err @@ -113,7 +111,6 @@ func NewProvider() (*provider, error) { return nil, err } } - verificationRequestCollection, err := arangodb.Collection(ctx, models.Collections.VerificationRequest) if err != nil { return nil, err @@ -136,7 +133,6 @@ func NewProvider() (*provider, error) { return nil, err } } - sessionCollection, err := arangodb.Collection(ctx, models.Collections.Session) if err != nil { return nil, err @@ -144,7 +140,6 @@ func NewProvider() (*provider, error) { sessionCollection.EnsureHashIndex(ctx, []string{"user_id"}, &arangoDriver.EnsureHashIndexOptions{ Sparse: true, }) - envCollectionExists, err := arangodb.CollectionExists(ctx, models.Collections.Env) if err != nil { return nil, err @@ -155,7 +150,6 @@ func NewProvider() (*provider, error) { return nil, err } } - webhookCollectionExists, err := arangodb.CollectionExists(ctx, models.Collections.Webhook) if err != nil { return nil, err @@ -166,7 +160,6 @@ func NewProvider() (*provider, error) { return nil, err } } - webhookCollection, err := arangodb.Collection(ctx, models.Collections.Webhook) if err != nil { return nil, err @@ -186,7 +179,6 @@ func NewProvider() (*provider, error) { return nil, err } } - webhookLogCollection, err := arangodb.Collection(ctx, models.Collections.WebhookLog) if err != nil { return nil, err @@ -194,7 +186,6 @@ func NewProvider() (*provider, error) { webhookLogCollection.EnsureHashIndex(ctx, []string{"webhook_id"}, &arangoDriver.EnsureHashIndexOptions{ Sparse: true, }) - emailTemplateCollectionExists, err := arangodb.CollectionExists(ctx, models.Collections.EmailTemplate) if err != nil { return nil, err @@ -205,7 +196,6 @@ func NewProvider() (*provider, error) { return nil, err } } - emailTemplateCollection, err := arangodb.Collection(ctx, models.Collections.EmailTemplate) if err != nil { return nil, err @@ -214,7 +204,6 @@ func NewProvider() (*provider, error) { Unique: true, Sparse: true, }) - otpCollectionExists, err := arangodb.CollectionExists(ctx, models.Collections.OTP) if err != nil { return nil, err @@ -225,16 +214,14 @@ func NewProvider() (*provider, error) { return nil, err } } - otpCollection, err := arangodb.Collection(ctx, models.Collections.OTP) if err != nil { return nil, err } - otpCollection.EnsureHashIndex(ctx, []string{"email"}, &arangoDriver.EnsureHashIndexOptions{ + otpCollection.EnsureHashIndex(ctx, []string{models.FieldNameEmail, models.FieldNamePhoneNumber}, &arangoDriver.EnsureHashIndexOptions{ Unique: true, Sparse: true, }) - return &provider{ db: arangodb, }, err diff --git a/server/db/providers/arangodb/session.go b/server/db/providers/arangodb/session.go index 9bc46ca..5dc981d 100644 --- a/server/db/providers/arangodb/session.go +++ b/server/db/providers/arangodb/session.go @@ -9,12 +9,11 @@ import ( ) // AddSession to save session information in database -func (p *provider) AddSession(ctx context.Context, session models.Session) error { +func (p *provider) AddSession(ctx context.Context, session *models.Session) error { if session.ID == "" { session.ID = uuid.New().String() session.Key = session.ID } - session.CreatedAt = time.Now().Unix() session.UpdatedAt = time.Now().Unix() sessionCollection, _ := p.db.Collection(ctx, models.Collections.Session) @@ -24,3 +23,8 @@ func (p *provider) AddSession(ctx context.Context, session models.Session) error } return nil } + +// DeleteSession to delete session information from database +func (p *provider) DeleteSession(ctx context.Context, userId string) error { + return nil +} diff --git a/server/db/providers/arangodb/sms_verification_requests.go b/server/db/providers/arangodb/sms_verification_requests.go deleted file mode 100644 index 4dee5bd..0000000 --- a/server/db/providers/arangodb/sms_verification_requests.go +++ /dev/null @@ -1,23 +0,0 @@ -package arangodb - -import ( - "context" - - "github.com/authorizerdev/authorizer/server/db/models" - -) - -// SMS verification Request -func (p *provider) UpsertSMSRequest(ctx context.Context, sms_code *models.SMSVerificationRequest) (*models.SMSVerificationRequest, error) { - return sms_code, nil -} - -func (p *provider) GetCodeByPhone(ctx context.Context, phoneNumber string) (*models.SMSVerificationRequest, error) { - var sms_verification_request models.SMSVerificationRequest - - return &sms_verification_request, nil -} - -func(p *provider) DeleteSMSRequest(ctx context.Context, smsRequest *models.SMSVerificationRequest) error { - return nil -} diff --git a/server/db/providers/arangodb/user.go b/server/db/providers/arangodb/user.go index cccbf94..926cdb9 100644 --- a/server/db/providers/arangodb/user.go +++ b/server/db/providers/arangodb/user.go @@ -18,7 +18,7 @@ import ( ) // AddUser to save user information in database -func (p *provider) AddUser(ctx context.Context, user models.User) (models.User, error) { +func (p *provider) AddUser(ctx context.Context, user *models.User) (*models.User, error) { if user.ID == "" { user.ID = uuid.New().String() user.Key = user.ID @@ -52,7 +52,7 @@ func (p *provider) AddUser(ctx context.Context, user models.User) (models.User, } // UpdateUser to update user information in database -func (p *provider) UpdateUser(ctx context.Context, user models.User) (models.User, error) { +func (p *provider) UpdateUser(ctx context.Context, user *models.User) (*models.User, error) { user.UpdatedAt = time.Now().Unix() collection, _ := p.db.Collection(ctx, models.Collections.User) @@ -67,13 +67,12 @@ func (p *provider) UpdateUser(ctx context.Context, user models.User) (models.Use } // DeleteUser to delete user information from database -func (p *provider) DeleteUser(ctx context.Context, user models.User) error { +func (p *provider) DeleteUser(ctx context.Context, user *models.User) error { collection, _ := p.db.Collection(ctx, models.Collections.User) _, err := collection.RemoveDocument(ctx, user.Key) if err != nil { return err } - query := fmt.Sprintf(`FOR d IN %s FILTER d.user_id == @user_id REMOVE { _key: d._key } IN %s`, models.Collections.Session, models.Collections.Session) bindVars := map[string]interface{}{ "user_id": user.Key, @@ -83,65 +82,55 @@ func (p *provider) DeleteUser(ctx context.Context, user models.User) error { return err } defer cursor.Close() - return nil } // ListUsers to get list of users from database -func (p *provider) ListUsers(ctx context.Context, pagination model.Pagination) (*model.Users, error) { +func (p *provider) ListUsers(ctx context.Context, pagination *model.Pagination) (*model.Users, error) { var users []*model.User sctx := arangoDriver.WithQueryFullCount(ctx) query := fmt.Sprintf("FOR d in %s SORT d.created_at DESC LIMIT %d, %d RETURN d", models.Collections.User, pagination.Offset, pagination.Limit) - cursor, err := p.db.Query(sctx, query, nil) if err != nil { return nil, err } defer cursor.Close() - paginationClone := pagination paginationClone.Total = cursor.Statistics().FullCount() - for { - var user models.User + var user *models.User meta, err := cursor.ReadDocument(ctx, &user) - if arangoDriver.IsNoMoreDocuments(err) { break } else if err != nil { return nil, err } - if meta.Key != "" { users = append(users, user.AsAPIUser()) } } - return &model.Users{ - Pagination: &paginationClone, + Pagination: paginationClone, Users: users, }, nil } // GetUserByEmail to get user information from database using email address -func (p *provider) GetUserByEmail(ctx context.Context, email string) (models.User, error) { - var user models.User - +func (p *provider) GetUserByEmail(ctx context.Context, email string) (*models.User, error) { + var user *models.User query := fmt.Sprintf("FOR d in %s FILTER d.email == @email RETURN d", models.Collections.User) bindVars := map[string]interface{}{ "email": email, } - cursor, err := p.db.Query(ctx, query, bindVars) if err != nil { return user, err } defer cursor.Close() - for { if !cursor.HasMore() { - if user.Key == "" { + if user == nil { return user, fmt.Errorf("user not found") } break @@ -151,28 +140,24 @@ func (p *provider) GetUserByEmail(ctx context.Context, email string) (models.Use return user, err } } - return user, nil } // GetUserByID to get user information from database using user ID -func (p *provider) GetUserByID(ctx context.Context, id string) (models.User, error) { - var user models.User - +func (p *provider) GetUserByID(ctx context.Context, id string) (*models.User, error) { + var user *models.User query := fmt.Sprintf("FOR d in %s FILTER d._id == @id LIMIT 1 RETURN d", models.Collections.User) bindVars := map[string]interface{}{ "id": id, } - cursor, err := p.db.Query(ctx, query, bindVars) if err != nil { return user, err } defer cursor.Close() - for { if !cursor.HasMore() { - if user.Key == "" { + if user == nil { return user, fmt.Errorf("user not found") } break @@ -182,7 +167,6 @@ func (p *provider) GetUserByID(ctx context.Context, id string) (models.User, err return user, err } } - return user, nil } @@ -191,12 +175,10 @@ func (p *provider) GetUserByID(ctx context.Context, id string) (models.User, err func (p *provider) UpdateUsers(ctx context.Context, data map[string]interface{}, ids []string) error { // set updated_at time for all users data["updated_at"] = time.Now().Unix() - userInfoBytes, err := json.Marshal(data) if err != nil { return err } - query := "" if len(ids) > 0 { keysArray := "" @@ -209,33 +191,28 @@ func (p *provider) UpdateUsers(ctx context.Context, data map[string]interface{}, } else { query = fmt.Sprintf("FOR u IN %s UPDATE u._key with %s IN %s", models.Collections.User, string(userInfoBytes), models.Collections.User) } - _, err = p.db.Query(ctx, query, nil) if err != nil { return err } - return nil } // GetUserByPhoneNumber to get user information from database using phone number func (p *provider) GetUserByPhoneNumber(ctx context.Context, phoneNumber string) (*models.User, error) { - var user models.User - + var user *models.User query := fmt.Sprintf("FOR d in %s FILTER d.phone_number == @phone_number RETURN d", models.Collections.User) bindVars := map[string]interface{}{ "phone_number": phoneNumber, } - cursor, err := p.db.Query(ctx, query, bindVars) if err != nil { return nil, err } defer cursor.Close() - for { if !cursor.HasMore() { - if user.Key == "" { + if user == nil { return nil, fmt.Errorf("user not found") } break @@ -245,6 +222,5 @@ func (p *provider) GetUserByPhoneNumber(ctx context.Context, phoneNumber string) return nil, err } } - - return &user, nil + return user, nil } diff --git a/server/db/providers/arangodb/verification_requests.go b/server/db/providers/arangodb/verification_requests.go index f69bcb0..05a8186 100644 --- a/server/db/providers/arangodb/verification_requests.go +++ b/server/db/providers/arangodb/verification_requests.go @@ -12,12 +12,11 @@ import ( ) // AddVerification to save verification request in database -func (p *provider) AddVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) (models.VerificationRequest, error) { +func (p *provider) AddVerificationRequest(ctx context.Context, verificationRequest *models.VerificationRequest) (*models.VerificationRequest, error) { if verificationRequest.ID == "" { verificationRequest.ID = uuid.New().String() verificationRequest.Key = verificationRequest.ID } - verificationRequest.CreatedAt = time.Now().Unix() verificationRequest.UpdatedAt = time.Now().Unix() verificationRequestRequestCollection, _ := p.db.Collection(ctx, models.Collections.VerificationRequest) @@ -27,27 +26,24 @@ func (p *provider) AddVerificationRequest(ctx context.Context, verificationReque } verificationRequest.Key = meta.Key verificationRequest.ID = meta.ID.String() - return verificationRequest, nil } // GetVerificationRequestByToken to get verification request from database using token -func (p *provider) GetVerificationRequestByToken(ctx context.Context, token string) (models.VerificationRequest, error) { - var verificationRequest models.VerificationRequest +func (p *provider) GetVerificationRequestByToken(ctx context.Context, token string) (*models.VerificationRequest, error) { + var verificationRequest *models.VerificationRequest query := fmt.Sprintf("FOR d in %s FILTER d.token == @token LIMIT 1 RETURN d", models.Collections.VerificationRequest) bindVars := map[string]interface{}{ "token": token, } - cursor, err := p.db.Query(ctx, query, bindVars) if err != nil { return verificationRequest, err } defer cursor.Close() - for { if !cursor.HasMore() { - if verificationRequest.Key == "" { + if verificationRequest == nil { return verificationRequest, fmt.Errorf("verification request not found") } break @@ -57,29 +53,25 @@ func (p *provider) GetVerificationRequestByToken(ctx context.Context, token stri return verificationRequest, err } } - return verificationRequest, nil } // GetVerificationRequestByEmail to get verification request by email from database -func (p *provider) GetVerificationRequestByEmail(ctx context.Context, email string, identifier string) (models.VerificationRequest, error) { - var verificationRequest models.VerificationRequest - +func (p *provider) GetVerificationRequestByEmail(ctx context.Context, email string, identifier string) (*models.VerificationRequest, error) { + var verificationRequest *models.VerificationRequest query := fmt.Sprintf("FOR d in %s FILTER d.email == @email FILTER d.identifier == @identifier LIMIT 1 RETURN d", models.Collections.VerificationRequest) bindVars := map[string]interface{}{ "email": email, "identifier": identifier, } - cursor, err := p.db.Query(ctx, query, bindVars) if err != nil { return verificationRequest, err } defer cursor.Close() - for { if !cursor.HasMore() { - if verificationRequest.Key == "" { + if verificationRequest == nil { return verificationRequest, fmt.Errorf("verification request not found") } break @@ -89,27 +81,23 @@ func (p *provider) GetVerificationRequestByEmail(ctx context.Context, email stri return verificationRequest, err } } - return verificationRequest, nil } // ListVerificationRequests to get list of verification requests from database -func (p *provider) ListVerificationRequests(ctx context.Context, pagination model.Pagination) (*model.VerificationRequests, error) { +func (p *provider) ListVerificationRequests(ctx context.Context, pagination *model.Pagination) (*model.VerificationRequests, error) { var verificationRequests []*model.VerificationRequest sctx := arangoDriver.WithQueryFullCount(ctx) query := fmt.Sprintf("FOR d in %s SORT d.created_at DESC LIMIT %d, %d RETURN d", models.Collections.VerificationRequest, pagination.Offset, pagination.Limit) - cursor, err := p.db.Query(sctx, query, nil) if err != nil { return nil, err } defer cursor.Close() - paginationClone := pagination paginationClone.Total = cursor.Statistics().FullCount() - for { - var verificationRequest models.VerificationRequest + var verificationRequest *models.VerificationRequest meta, err := cursor.ReadDocument(ctx, &verificationRequest) if arangoDriver.IsNoMoreDocuments(err) { @@ -123,15 +111,14 @@ func (p *provider) ListVerificationRequests(ctx context.Context, pagination mode } } - return &model.VerificationRequests{ VerificationRequests: verificationRequests, - Pagination: &paginationClone, + Pagination: paginationClone, }, nil } // DeleteVerificationRequest to delete verification request from database -func (p *provider) DeleteVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) error { +func (p *provider) DeleteVerificationRequest(ctx context.Context, verificationRequest *models.VerificationRequest) error { collection, _ := p.db.Collection(ctx, models.Collections.VerificationRequest) _, err := collection.RemoveDocument(ctx, verificationRequest.Key) if err != nil { diff --git a/server/db/providers/arangodb/webhook.go b/server/db/providers/arangodb/webhook.go index 73cefad..dbdc9e4 100644 --- a/server/db/providers/arangodb/webhook.go +++ b/server/db/providers/arangodb/webhook.go @@ -14,7 +14,7 @@ import ( ) // AddWebhook to add webhook -func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) { +func (p *provider) AddWebhook(ctx context.Context, webhook *models.Webhook) (*model.Webhook, error) { if webhook.ID == "" { webhook.ID = uuid.New().String() webhook.Key = webhook.ID @@ -33,7 +33,7 @@ func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*mod } // UpdateWebhook to update webhook -func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) { +func (p *provider) UpdateWebhook(ctx context.Context, webhook *models.Webhook) (*model.Webhook, error) { webhook.UpdatedAt = time.Now().Unix() // Event is changed if !strings.Contains(webhook.EventName, "-") { @@ -50,11 +50,9 @@ func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (* } // ListWebhooks to list webhook -func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) (*model.Webhooks, error) { +func (p *provider) ListWebhook(ctx context.Context, pagination *model.Pagination) (*model.Webhooks, error) { webhooks := []*model.Webhook{} - query := fmt.Sprintf("FOR d in %s SORT d.created_at DESC LIMIT %d, %d RETURN d", models.Collections.Webhook, pagination.Offset, pagination.Limit) - sctx := arangoDriver.WithQueryFullCount(ctx) cursor, err := p.db.Query(sctx, query, nil) if err != nil { @@ -64,9 +62,8 @@ func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) paginationClone := pagination paginationClone.Total = cursor.Statistics().FullCount() for { - var webhook models.Webhook + var webhook *models.Webhook meta, err := cursor.ReadDocument(ctx, &webhook) - if arangoDriver.IsNoMoreDocuments(err) { break } else if err != nil { @@ -79,14 +76,14 @@ func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) } return &model.Webhooks{ - Pagination: &paginationClone, + Pagination: paginationClone, Webhooks: webhooks, }, nil } // GetWebhookByID to get webhook by id func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model.Webhook, error) { - var webhook models.Webhook + var webhook *models.Webhook query := fmt.Sprintf("FOR d in %s FILTER d._key == @webhook_id RETURN d", models.Collections.Webhook) bindVars := map[string]interface{}{ "webhook_id": webhookID, @@ -98,7 +95,7 @@ func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model defer cursor.Close() for { if !cursor.HasMore() { - if webhook.Key == "" { + if webhook == nil { return nil, fmt.Errorf("webhook not found") } break @@ -124,7 +121,7 @@ func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) defer cursor.Close() webhooks := []*model.Webhook{} for { - var webhook models.Webhook + var webhook *models.Webhook if _, err := cursor.ReadDocument(ctx, &webhook); driver.IsNoMoreDocuments(err) { // We're done break diff --git a/server/db/providers/arangodb/webhook_log.go b/server/db/providers/arangodb/webhook_log.go index 42de751..64db2cb 100644 --- a/server/db/providers/arangodb/webhook_log.go +++ b/server/db/providers/arangodb/webhook_log.go @@ -12,12 +12,11 @@ import ( ) // AddWebhookLog to add webhook log -func (p *provider) AddWebhookLog(ctx context.Context, webhookLog models.WebhookLog) (*model.WebhookLog, error) { +func (p *provider) AddWebhookLog(ctx context.Context, webhookLog *models.WebhookLog) (*model.WebhookLog, error) { if webhookLog.ID == "" { webhookLog.ID = uuid.New().String() webhookLog.Key = webhookLog.ID } - webhookLog.Key = webhookLog.ID webhookLog.CreatedAt = time.Now().Unix() webhookLog.UpdatedAt = time.Now().Unix() @@ -30,46 +29,38 @@ func (p *provider) AddWebhookLog(ctx context.Context, webhookLog models.WebhookL } // ListWebhookLogs to list webhook logs -func (p *provider) ListWebhookLogs(ctx context.Context, pagination model.Pagination, webhookID string) (*model.WebhookLogs, error) { +func (p *provider) ListWebhookLogs(ctx context.Context, pagination *model.Pagination, webhookID string) (*model.WebhookLogs, error) { webhookLogs := []*model.WebhookLog{} bindVariables := map[string]interface{}{} - query := fmt.Sprintf("FOR d in %s SORT d.created_at DESC LIMIT %d, %d RETURN d", models.Collections.WebhookLog, pagination.Offset, pagination.Limit) - if webhookID != "" { query = fmt.Sprintf("FOR d in %s FILTER d.webhook_id == @webhook_id SORT d.created_at DESC LIMIT %d, %d RETURN d", models.Collections.WebhookLog, pagination.Offset, pagination.Limit) bindVariables = map[string]interface{}{ "webhook_id": webhookID, } } - sctx := arangoDriver.WithQueryFullCount(ctx) cursor, err := p.db.Query(sctx, query, bindVariables) if err != nil { return nil, err } defer cursor.Close() - paginationClone := pagination paginationClone.Total = cursor.Statistics().FullCount() - for { - var webhookLog models.WebhookLog + var webhookLog *models.WebhookLog meta, err := cursor.ReadDocument(ctx, &webhookLog) - if arangoDriver.IsNoMoreDocuments(err) { break } else if err != nil { return nil, err } - if meta.Key != "" { webhookLogs = append(webhookLogs, webhookLog.AsAPIWebhookLog()) } } - return &model.WebhookLogs{ - Pagination: &paginationClone, + Pagination: paginationClone, WebhookLogs: webhookLogs, }, nil } diff --git a/server/db/providers/cassandradb/email_template.go b/server/db/providers/cassandradb/email_template.go index 7cb64cb..9adb768 100644 --- a/server/db/providers/cassandradb/email_template.go +++ b/server/db/providers/cassandradb/email_template.go @@ -15,33 +15,28 @@ import ( ) // AddEmailTemplate to add EmailTemplate -func (p *provider) AddEmailTemplate(ctx context.Context, emailTemplate models.EmailTemplate) (*model.EmailTemplate, error) { +func (p *provider) AddEmailTemplate(ctx context.Context, emailTemplate *models.EmailTemplate) (*model.EmailTemplate, error) { if emailTemplate.ID == "" { emailTemplate.ID = uuid.New().String() } - emailTemplate.Key = emailTemplate.ID emailTemplate.CreatedAt = time.Now().Unix() emailTemplate.UpdatedAt = time.Now().Unix() - existingEmailTemplate, _ := p.GetEmailTemplateByEventName(ctx, emailTemplate.EventName) if existingEmailTemplate != nil { return nil, fmt.Errorf("Email template with %s event_name already exists", emailTemplate.EventName) } - insertQuery := fmt.Sprintf("INSERT INTO %s (id, event_name, subject, design, template, created_at, updated_at) VALUES ('%s', '%s', '%s','%s','%s', %d, %d)", KeySpace+"."+models.Collections.EmailTemplate, emailTemplate.ID, emailTemplate.EventName, emailTemplate.Subject, emailTemplate.Design, emailTemplate.Template, emailTemplate.CreatedAt, emailTemplate.UpdatedAt) err := p.db.Query(insertQuery).Exec() if err != nil { return nil, err } - return emailTemplate.AsAPIEmailTemplate(), nil } // UpdateEmailTemplate to update EmailTemplate -func (p *provider) UpdateEmailTemplate(ctx context.Context, emailTemplate models.EmailTemplate) (*model.EmailTemplate, error) { +func (p *provider) UpdateEmailTemplate(ctx context.Context, emailTemplate *models.EmailTemplate) (*model.EmailTemplate, error) { emailTemplate.UpdatedAt = time.Now().Unix() - bytes, err := json.Marshal(emailTemplate) if err != nil { return nil, err @@ -54,22 +49,18 @@ func (p *provider) UpdateEmailTemplate(ctx context.Context, emailTemplate models if err != nil { return nil, err } - updateFields := "" for key, value := range emailTemplateMap { if key == "_id" { continue } - if key == "_key" { continue } - if value == nil { updateFields += fmt.Sprintf("%s = null,", key) continue } - valueType := reflect.TypeOf(value) if valueType.Name() == "string" { updateFields += fmt.Sprintf("%s = '%s', ", key, value.(string)) @@ -90,7 +81,7 @@ func (p *provider) UpdateEmailTemplate(ctx context.Context, emailTemplate models } // ListEmailTemplates to list EmailTemplate -func (p *provider) ListEmailTemplate(ctx context.Context, pagination model.Pagination) (*model.EmailTemplates, error) { +func (p *provider) ListEmailTemplate(ctx context.Context, pagination *model.Pagination) (*model.EmailTemplates, error) { emailTemplates := []*model.EmailTemplate{} paginationClone := pagination @@ -120,7 +111,7 @@ func (p *provider) ListEmailTemplate(ctx context.Context, pagination model.Pagin } return &model.EmailTemplates{ - Pagination: &paginationClone, + Pagination: paginationClone, EmailTemplates: emailTemplates, }, nil } diff --git a/server/db/providers/cassandradb/env.go b/server/db/providers/cassandradb/env.go index 384b539..636f9f4 100644 --- a/server/db/providers/cassandradb/env.go +++ b/server/db/providers/cassandradb/env.go @@ -11,11 +11,10 @@ import ( ) // AddEnv to save environment information in database -func (p *provider) AddEnv(ctx context.Context, env models.Env) (models.Env, error) { +func (p *provider) AddEnv(ctx context.Context, env *models.Env) (*models.Env, error) { if env.ID == "" { env.ID = uuid.New().String() } - env.CreatedAt = time.Now().Unix() env.UpdatedAt = time.Now().Unix() insertEnvQuery := fmt.Sprintf("INSERT INTO %s (id, env, hash, created_at, updated_at) VALUES ('%s', '%s', '%s', %d, %d)", KeySpace+"."+models.Collections.Env, env.ID, env.EnvData, env.Hash, env.CreatedAt, env.UpdatedAt) @@ -28,9 +27,8 @@ func (p *provider) AddEnv(ctx context.Context, env models.Env) (models.Env, erro } // UpdateEnv to update environment information in database -func (p *provider) UpdateEnv(ctx context.Context, env models.Env) (models.Env, error) { +func (p *provider) UpdateEnv(ctx context.Context, env *models.Env) (*models.Env, error) { env.UpdatedAt = time.Now().Unix() - updateEnvQuery := fmt.Sprintf("UPDATE %s SET env = '%s', updated_at = %d WHERE id = '%s'", KeySpace+"."+models.Collections.Env, env.EnvData, env.UpdatedAt, env.ID) err := p.db.Query(updateEnvQuery).Exec() if err != nil { @@ -40,14 +38,12 @@ func (p *provider) UpdateEnv(ctx context.Context, env models.Env) (models.Env, e } // GetEnv to get environment information from database -func (p *provider) GetEnv(ctx context.Context) (models.Env, error) { +func (p *provider) GetEnv(ctx context.Context) (*models.Env, error) { var env models.Env - query := fmt.Sprintf("SELECT id, env, hash, created_at, updated_at FROM %s LIMIT 1", KeySpace+"."+models.Collections.Env) err := p.db.Query(query).Consistency(gocql.One).Scan(&env.ID, &env.EnvData, &env.Hash, &env.CreatedAt, &env.UpdatedAt) if err != nil { - return env, err + return nil, err } - - return env, nil + return &env, nil } diff --git a/server/db/providers/cassandradb/otp.go b/server/db/providers/cassandradb/otp.go index bfe481d..e453242 100644 --- a/server/db/providers/cassandradb/otp.go +++ b/server/db/providers/cassandradb/otp.go @@ -2,6 +2,7 @@ package cassandradb import ( "context" + "errors" "fmt" "time" @@ -12,17 +13,31 @@ import ( // UpsertOTP to add or update otp func (p *provider) UpsertOTP(ctx context.Context, otpParam *models.OTP) (*models.OTP, error) { - otp, _ := p.GetOTPByEmail(ctx, otpParam.Email) + // check if email or phone number is present + if otpParam.Email == "" && otpParam.PhoneNumber == "" { + return nil, errors.New("email or phone_number is required") + } + uniqueField := models.FieldNameEmail + if otpParam.Email == "" && otpParam.PhoneNumber != "" { + uniqueField = models.FieldNamePhoneNumber + } + var otp *models.OTP + if uniqueField == models.FieldNameEmail { + otp, _ = p.GetOTPByEmail(ctx, otpParam.Email) + } else { + otp, _ = p.GetOTPByPhoneNumber(ctx, otpParam.PhoneNumber) + } shouldCreate := false if otp == nil { shouldCreate = true otp = &models.OTP{ - ID: uuid.NewString(), - Otp: otpParam.Otp, - Email: otpParam.Email, - ExpiresAt: otpParam.ExpiresAt, - CreatedAt: time.Now().Unix(), - UpdatedAt: time.Now().Unix(), + ID: uuid.NewString(), + Otp: otpParam.Otp, + Email: otpParam.Email, + PhoneNumber: otpParam.PhoneNumber, + ExpiresAt: otpParam.ExpiresAt, + CreatedAt: time.Now().Unix(), + UpdatedAt: time.Now().Unix(), } } else { otp.Otp = otpParam.Otp @@ -32,7 +47,7 @@ func (p *provider) UpsertOTP(ctx context.Context, otpParam *models.OTP) (*models otp.UpdatedAt = time.Now().Unix() query := "" if shouldCreate { - query = fmt.Sprintf(`INSERT INTO %s (id, email, otp, expires_at, created_at, updated_at) VALUES ('%s', '%s', '%s', %d, %d, %d)`, KeySpace+"."+models.Collections.OTP, otp.ID, otp.Email, otp.Otp, otp.ExpiresAt, otp.CreatedAt, otp.UpdatedAt) + query = fmt.Sprintf(`INSERT INTO %s (id, email, phone_number, otp, expires_at, created_at, updated_at) VALUES ('%s', '%s', '%s', '%s', %d, %d, %d)`, KeySpace+"."+models.Collections.OTP, otp.ID, otp.Email, otp.PhoneNumber, otp.Otp, otp.ExpiresAt, otp.CreatedAt, otp.UpdatedAt) } else { query = fmt.Sprintf(`UPDATE %s SET otp = '%s', expires_at = %d, updated_at = %d WHERE id = '%s'`, KeySpace+"."+models.Collections.OTP, otp.Otp, otp.ExpiresAt, otp.UpdatedAt, otp.ID) } @@ -48,8 +63,19 @@ func (p *provider) UpsertOTP(ctx context.Context, otpParam *models.OTP) (*models // GetOTPByEmail to get otp for a given email address func (p *provider) GetOTPByEmail(ctx context.Context, emailAddress string) (*models.OTP, error) { var otp models.OTP - query := fmt.Sprintf(`SELECT id, email, otp, expires_at, created_at, updated_at FROM %s WHERE email = '%s' LIMIT 1 ALLOW FILTERING`, KeySpace+"."+models.Collections.OTP, emailAddress) - err := p.db.Query(query).Consistency(gocql.One).Scan(&otp.ID, &otp.Email, &otp.Otp, &otp.ExpiresAt, &otp.CreatedAt, &otp.UpdatedAt) + query := fmt.Sprintf(`SELECT id, email, phone_number, otp, expires_at, created_at, updated_at FROM %s WHERE email = '%s' LIMIT 1 ALLOW FILTERING`, KeySpace+"."+models.Collections.OTP, emailAddress) + err := p.db.Query(query).Consistency(gocql.One).Scan(&otp.ID, &otp.Email, &otp.PhoneNumber, &otp.Otp, &otp.ExpiresAt, &otp.CreatedAt, &otp.UpdatedAt) + if err != nil { + return nil, err + } + return &otp, nil +} + +// GetOTPByPhoneNumber to get otp for a given phone number +func (p *provider) GetOTPByPhoneNumber(ctx context.Context, phoneNumber string) (*models.OTP, error) { + var otp models.OTP + query := fmt.Sprintf(`SELECT id, email, phone_number, otp, expires_at, created_at, updated_at FROM %s WHERE phone_number = '%s' LIMIT 1 ALLOW FILTERING`, KeySpace+"."+models.Collections.OTP, phoneNumber) + err := p.db.Query(query).Consistency(gocql.One).Scan(&otp.ID, &otp.Email, &otp.PhoneNumber, &otp.Otp, &otp.ExpiresAt, &otp.CreatedAt, &otp.UpdatedAt) if err != nil { return nil, err } diff --git a/server/db/providers/cassandradb/provider.go b/server/db/providers/cassandradb/provider.go index 1d8fa49..6f1fe6b 100644 --- a/server/db/providers/cassandradb/provider.go +++ b/server/db/providers/cassandradb/provider.go @@ -254,7 +254,19 @@ func NewProvider() (*provider, error) { if err != nil { return nil, err } - + // Add phone_number column to otp table + otpAlterQuery := fmt.Sprintf(`ALTER TABLE %s.%s ADD (phone_number text);`, KeySpace, models.Collections.OTP) + err = session.Query(otpAlterQuery).Exec() + if err != nil { + log.Debug("Failed to alter table as column exists: ", err) + // continue + } + // Add phone number index + otpIndexQueryPhoneNumber := fmt.Sprintf("CREATE INDEX IF NOT EXISTS authorizer_otp_phone_number ON %s.%s (phone_number)", KeySpace, models.Collections.OTP) + err = session.Query(otpIndexQueryPhoneNumber).Exec() + if err != nil { + return nil, err + } return &provider{ db: session, }, err diff --git a/server/db/providers/cassandradb/session.go b/server/db/providers/cassandradb/session.go index e6042ea..bdf205c 100644 --- a/server/db/providers/cassandradb/session.go +++ b/server/db/providers/cassandradb/session.go @@ -10,14 +10,12 @@ import ( ) // AddSession to save session information in database -func (p *provider) AddSession(ctx context.Context, session models.Session) error { +func (p *provider) AddSession(ctx context.Context, session *models.Session) error { if session.ID == "" { session.ID = uuid.New().String() } - session.CreatedAt = time.Now().Unix() session.UpdatedAt = time.Now().Unix() - insertSessionQuery := fmt.Sprintf("INSERT INTO %s (id, user_id, user_agent, ip, created_at, updated_at) VALUES ('%s', '%s', '%s', '%s', %d, %d)", KeySpace+"."+models.Collections.Session, session.ID, session.UserID, session.UserAgent, session.IP, session.CreatedAt, session.UpdatedAt) err := p.db.Query(insertSessionQuery).Exec() if err != nil { @@ -25,3 +23,8 @@ func (p *provider) AddSession(ctx context.Context, session models.Session) error } return nil } + +// DeleteSession to delete session information from database +func (p *provider) DeleteSession(ctx context.Context, userId string) error { + return nil +} diff --git a/server/db/providers/cassandradb/sms_verification_requests.go b/server/db/providers/cassandradb/sms_verification_requests.go deleted file mode 100644 index 3c67c1b..0000000 --- a/server/db/providers/cassandradb/sms_verification_requests.go +++ /dev/null @@ -1,23 +0,0 @@ -package cassandradb - -import ( - "context" - - "github.com/authorizerdev/authorizer/server/db/models" - -) - -// SMS verification Request -func (p *provider) UpsertSMSRequest(ctx context.Context, sms_code *models.SMSVerificationRequest) (*models.SMSVerificationRequest, error) { - return sms_code, nil -} - -func (p *provider) GetCodeByPhone(ctx context.Context, phoneNumber string) (*models.SMSVerificationRequest, error) { - var sms_verification_request models.SMSVerificationRequest - - return &sms_verification_request, nil -} - -func(p *provider) DeleteSMSRequest(ctx context.Context, smsRequest *models.SMSVerificationRequest) error { - return nil -} diff --git a/server/db/providers/cassandradb/user.go b/server/db/providers/cassandradb/user.go index f8be709..376a6db 100644 --- a/server/db/providers/cassandradb/user.go +++ b/server/db/providers/cassandradb/user.go @@ -18,7 +18,7 @@ import ( ) // AddUser to save user information in database -func (p *provider) AddUser(ctx context.Context, user models.User) (models.User, error) { +func (p *provider) AddUser(ctx context.Context, user *models.User) (*models.User, error) { if user.ID == "" { user.ID = uuid.New().String() } @@ -77,7 +77,6 @@ func (p *provider) AddUser(ctx context.Context, user models.User) (models.User, values = values[:len(values)-1] + ")" query := fmt.Sprintf("INSERT INTO %s %s VALUES %s IF NOT EXISTS", KeySpace+"."+models.Collections.User, fields, values) - err = p.db.Query(query).Exec() if err != nil { return user, err @@ -87,7 +86,7 @@ func (p *provider) AddUser(ctx context.Context, user models.User) (models.User, } // UpdateUser to update user information in database -func (p *provider) UpdateUser(ctx context.Context, user models.User) (models.User, error) { +func (p *provider) UpdateUser(ctx context.Context, user *models.User) (*models.User, error) { user.UpdatedAt = time.Now().Unix() bytes, err := json.Marshal(user) @@ -138,13 +137,12 @@ func (p *provider) UpdateUser(ctx context.Context, user models.User) (models.Use } // DeleteUser to delete user information from database -func (p *provider) DeleteUser(ctx context.Context, user models.User) error { +func (p *provider) DeleteUser(ctx context.Context, user *models.User) error { query := fmt.Sprintf("DELETE FROM %s WHERE id = '%s'", KeySpace+"."+models.Collections.User, user.ID) err := p.db.Query(query).Exec() if err != nil { return err } - getSessionsQuery := fmt.Sprintf("SELECT id FROM %s WHERE user_id = '%s' ALLOW FILTERING", KeySpace+"."+models.Collections.Session, user.ID) scanner := p.db.Query(getSessionsQuery).Iter().Scanner() sessionIDs := "" @@ -167,7 +165,7 @@ func (p *provider) DeleteUser(ctx context.Context, user models.User) error { } // ListUsers to get list of users from database -func (p *provider) ListUsers(ctx context.Context, pagination model.Pagination) (*model.Users, error) { +func (p *provider) ListUsers(ctx context.Context, pagination *model.Pagination) (*model.Users, error) { responseUsers := []*model.User{} paginationClone := pagination totalCountQuery := fmt.Sprintf(`SELECT COUNT(*) FROM %s`, KeySpace+"."+models.Collections.User) @@ -180,7 +178,6 @@ func (p *provider) ListUsers(ctx context.Context, pagination model.Pagination) ( // so we fetch till limit + offset // and return the results from offset to limit query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, is_multi_factor_auth_enabled, created_at, updated_at FROM %s LIMIT %d", KeySpace+"."+models.Collections.User, pagination.Limit+pagination.Offset) - scanner := p.db.Query(query).Iter().Scanner() counter := int64(0) for scanner.Next() { @@ -195,31 +192,31 @@ func (p *provider) ListUsers(ctx context.Context, pagination model.Pagination) ( counter++ } return &model.Users{ + Pagination: paginationClone, Users: responseUsers, - Pagination: &paginationClone, }, nil } // GetUserByEmail to get user information from database using email address -func (p *provider) GetUserByEmail(ctx context.Context, email string) (models.User, error) { +func (p *provider) GetUserByEmail(ctx context.Context, email string) (*models.User, error) { var user models.User query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, is_multi_factor_auth_enabled, created_at, updated_at FROM %s WHERE email = '%s' LIMIT 1 ALLOW FILTERING", KeySpace+"."+models.Collections.User, email) err := p.db.Query(query).Consistency(gocql.One).Scan(&user.ID, &user.Email, &user.EmailVerifiedAt, &user.Password, &user.SignupMethods, &user.GivenName, &user.FamilyName, &user.MiddleName, &user.Nickname, &user.Birthdate, &user.PhoneNumber, &user.PhoneNumberVerifiedAt, &user.Picture, &user.Roles, &user.RevokedTimestamp, &user.IsMultiFactorAuthEnabled, &user.CreatedAt, &user.UpdatedAt) if err != nil { - return user, err + return nil, err } - return user, nil + return &user, nil } // GetUserByID to get user information from database using user ID -func (p *provider) GetUserByID(ctx context.Context, id string) (models.User, error) { +func (p *provider) GetUserByID(ctx context.Context, id string) (*models.User, error) { var user models.User query := fmt.Sprintf("SELECT id, email, email_verified_at, password, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, is_multi_factor_auth_enabled, created_at, updated_at FROM %s WHERE id = '%s' LIMIT 1", KeySpace+"."+models.Collections.User, id) err := p.db.Query(query).Consistency(gocql.One).Scan(&user.ID, &user.Email, &user.EmailVerifiedAt, &user.Password, &user.SignupMethods, &user.GivenName, &user.FamilyName, &user.MiddleName, &user.Nickname, &user.Birthdate, &user.PhoneNumber, &user.PhoneNumberVerifiedAt, &user.Picture, &user.Roles, &user.RevokedTimestamp, &user.IsMultiFactorAuthEnabled, &user.CreatedAt, &user.UpdatedAt) if err != nil { - return user, err + return nil, err } - return user, nil + return &user, nil } // UpdateUsers to update multiple users, with parameters of user IDs slice @@ -252,9 +249,8 @@ func (p *provider) UpdateUsers(ctx context.Context, data map[string]interface{}, } updateFields = strings.Trim(updateFields, " ") updateFields = strings.TrimSuffix(updateFields, ",") - query := "" - if ids != nil && len(ids) > 0 { + if len(ids) > 0 { idsString := "" for _, id := range ids { idsString += fmt.Sprintf("'%s', ", id) diff --git a/server/db/providers/cassandradb/verification_requests.go b/server/db/providers/cassandradb/verification_requests.go index 3786a2b..aa8e66d 100644 --- a/server/db/providers/cassandradb/verification_requests.go +++ b/server/db/providers/cassandradb/verification_requests.go @@ -12,7 +12,7 @@ import ( ) // AddVerification to save verification request in database -func (p *provider) AddVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) (models.VerificationRequest, error) { +func (p *provider) AddVerificationRequest(ctx context.Context, verificationRequest *models.VerificationRequest) (*models.VerificationRequest, error) { if verificationRequest.ID == "" { verificationRequest.ID = uuid.New().String() } @@ -29,41 +29,39 @@ func (p *provider) AddVerificationRequest(ctx context.Context, verificationReque } // GetVerificationRequestByToken to get verification request from database using token -func (p *provider) GetVerificationRequestByToken(ctx context.Context, token string) (models.VerificationRequest, error) { +func (p *provider) GetVerificationRequestByToken(ctx context.Context, token string) (*models.VerificationRequest, error) { var verificationRequest models.VerificationRequest query := fmt.Sprintf(`SELECT id, jwt_token, identifier, expires_at, email, nonce, redirect_uri, created_at, updated_at FROM %s WHERE jwt_token = '%s' LIMIT 1`, KeySpace+"."+models.Collections.VerificationRequest, token) err := p.db.Query(query).Consistency(gocql.One).Scan(&verificationRequest.ID, &verificationRequest.Token, &verificationRequest.Identifier, &verificationRequest.ExpiresAt, &verificationRequest.Email, &verificationRequest.Nonce, &verificationRequest.RedirectURI, &verificationRequest.CreatedAt, &verificationRequest.UpdatedAt) if err != nil { - return verificationRequest, err + return nil, err } - return verificationRequest, nil + return &verificationRequest, nil } // GetVerificationRequestByEmail to get verification request by email from database -func (p *provider) GetVerificationRequestByEmail(ctx context.Context, email string, identifier string) (models.VerificationRequest, error) { +func (p *provider) GetVerificationRequestByEmail(ctx context.Context, email string, identifier string) (*models.VerificationRequest, error) { var verificationRequest models.VerificationRequest query := fmt.Sprintf(`SELECT id, jwt_token, identifier, expires_at, email, nonce, redirect_uri, created_at, updated_at FROM %s WHERE email = '%s' AND identifier = '%s' LIMIT 1 ALLOW FILTERING`, KeySpace+"."+models.Collections.VerificationRequest, email, identifier) err := p.db.Query(query).Consistency(gocql.One).Scan(&verificationRequest.ID, &verificationRequest.Token, &verificationRequest.Identifier, &verificationRequest.ExpiresAt, &verificationRequest.Email, &verificationRequest.Nonce, &verificationRequest.RedirectURI, &verificationRequest.CreatedAt, &verificationRequest.UpdatedAt) if err != nil { - return verificationRequest, err + return nil, err } - return verificationRequest, nil + return &verificationRequest, nil } // ListVerificationRequests to get list of verification requests from database -func (p *provider) ListVerificationRequests(ctx context.Context, pagination model.Pagination) (*model.VerificationRequests, error) { +func (p *provider) ListVerificationRequests(ctx context.Context, pagination *model.Pagination) (*model.VerificationRequests, error) { var verificationRequests []*model.VerificationRequest - paginationClone := pagination totalCountQuery := fmt.Sprintf(`SELECT COUNT(*) FROM %s`, KeySpace+"."+models.Collections.VerificationRequest) err := p.db.Query(totalCountQuery).Consistency(gocql.One).Scan(&paginationClone.Total) if err != nil { return nil, err } - // there is no offset in cassandra // so we fetch till limit + offset // and return the results from offset to limit @@ -85,12 +83,12 @@ func (p *provider) ListVerificationRequests(ctx context.Context, pagination mode return &model.VerificationRequests{ VerificationRequests: verificationRequests, - Pagination: &paginationClone, + Pagination: paginationClone, }, nil } // DeleteVerificationRequest to delete verification request from database -func (p *provider) DeleteVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) error { +func (p *provider) DeleteVerificationRequest(ctx context.Context, verificationRequest *models.VerificationRequest) error { query := fmt.Sprintf("DELETE FROM %s WHERE id = '%s'", KeySpace+"."+models.Collections.VerificationRequest, verificationRequest.ID) err := p.db.Query(query).Exec() if err != nil { diff --git a/server/db/providers/cassandradb/webhook.go b/server/db/providers/cassandradb/webhook.go index cb50f08..e80dfdd 100644 --- a/server/db/providers/cassandradb/webhook.go +++ b/server/db/providers/cassandradb/webhook.go @@ -15,7 +15,7 @@ import ( ) // AddWebhook to add webhook -func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) { +func (p *provider) AddWebhook(ctx context.Context, webhook *models.Webhook) (*model.Webhook, error) { if webhook.ID == "" { webhook.ID = uuid.New().String() } @@ -33,7 +33,7 @@ func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*mod } // UpdateWebhook to update webhook -func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) { +func (p *provider) UpdateWebhook(ctx context.Context, webhook *models.Webhook) (*model.Webhook, error) { webhook.UpdatedAt = time.Now().Unix() // Event is changed if !strings.Contains(webhook.EventName, "-") { @@ -81,7 +81,7 @@ func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (* } // ListWebhooks to list webhook -func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) (*model.Webhooks, error) { +func (p *provider) ListWebhook(ctx context.Context, pagination *model.Pagination) (*model.Webhooks, error) { webhooks := []*model.Webhook{} paginationClone := pagination totalCountQuery := fmt.Sprintf(`SELECT COUNT(*) FROM %s`, KeySpace+"."+models.Collections.Webhook) @@ -108,7 +108,7 @@ func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) } return &model.Webhooks{ - Pagination: &paginationClone, + Pagination: paginationClone, Webhooks: webhooks, }, nil } diff --git a/server/db/providers/cassandradb/webhook_log.go b/server/db/providers/cassandradb/webhook_log.go index 9ecf939..d587e02 100644 --- a/server/db/providers/cassandradb/webhook_log.go +++ b/server/db/providers/cassandradb/webhook_log.go @@ -12,7 +12,7 @@ import ( ) // AddWebhookLog to add webhook log -func (p *provider) AddWebhookLog(ctx context.Context, webhookLog models.WebhookLog) (*model.WebhookLog, error) { +func (p *provider) AddWebhookLog(ctx context.Context, webhookLog *models.WebhookLog) (*model.WebhookLog, error) { if webhookLog.ID == "" { webhookLog.ID = uuid.New().String() } @@ -30,7 +30,7 @@ func (p *provider) AddWebhookLog(ctx context.Context, webhookLog models.WebhookL } // ListWebhookLogs to list webhook logs -func (p *provider) ListWebhookLogs(ctx context.Context, pagination model.Pagination, webhookID string) (*model.WebhookLogs, error) { +func (p *provider) ListWebhookLogs(ctx context.Context, pagination *model.Pagination, webhookID string) (*model.WebhookLogs, error) { webhookLogs := []*model.WebhookLog{} paginationClone := pagination totalCountQuery := fmt.Sprintf(`SELECT COUNT(*) FROM %s`, KeySpace+"."+models.Collections.WebhookLog) @@ -38,7 +38,6 @@ func (p *provider) ListWebhookLogs(ctx context.Context, pagination model.Paginat // so we fetch till limit + offset // and return the results from offset to limit query := fmt.Sprintf("SELECT id, http_status, response, request, webhook_id, created_at, updated_at FROM %s LIMIT %d", KeySpace+"."+models.Collections.WebhookLog, pagination.Limit+pagination.Offset) - if webhookID != "" { totalCountQuery = fmt.Sprintf(`SELECT COUNT(*) FROM %s WHERE webhook_id='%s' ALLOW FILTERING`, KeySpace+"."+models.Collections.WebhookLog, webhookID) query = fmt.Sprintf("SELECT id, http_status, response, request, webhook_id, created_at, updated_at FROM %s WHERE webhook_id = '%s' LIMIT %d ALLOW FILTERING", KeySpace+"."+models.Collections.WebhookLog, webhookID, pagination.Limit+pagination.Offset) @@ -64,7 +63,7 @@ func (p *provider) ListWebhookLogs(ctx context.Context, pagination model.Paginat } return &model.WebhookLogs{ - Pagination: &paginationClone, + Pagination: paginationClone, WebhookLogs: webhookLogs, }, nil } diff --git a/server/db/providers/couchbase/email_template.go b/server/db/providers/couchbase/email_template.go index bd37482..14f5ba9 100644 --- a/server/db/providers/couchbase/email_template.go +++ b/server/db/providers/couchbase/email_template.go @@ -15,7 +15,7 @@ import ( ) // AddEmailTemplate to add EmailTemplate -func (p *provider) AddEmailTemplate(ctx context.Context, emailTemplate models.EmailTemplate) (*model.EmailTemplate, error) { +func (p *provider) AddEmailTemplate(ctx context.Context, emailTemplate *models.EmailTemplate) (*model.EmailTemplate, error) { if emailTemplate.ID == "" { emailTemplate.ID = uuid.New().String() @@ -37,7 +37,7 @@ func (p *provider) AddEmailTemplate(ctx context.Context, emailTemplate models.Em } // UpdateEmailTemplate to update EmailTemplate -func (p *provider) UpdateEmailTemplate(ctx context.Context, emailTemplate models.EmailTemplate) (*model.EmailTemplate, error) { +func (p *provider) UpdateEmailTemplate(ctx context.Context, emailTemplate *models.EmailTemplate) (*model.EmailTemplate, error) { bytes, err := json.Marshal(emailTemplate) if err != nil { return nil, err @@ -67,7 +67,7 @@ func (p *provider) UpdateEmailTemplate(ctx context.Context, emailTemplate models } // ListEmailTemplates to list EmailTemplate -func (p *provider) ListEmailTemplate(ctx context.Context, pagination model.Pagination) (*model.EmailTemplates, error) { +func (p *provider) ListEmailTemplate(ctx context.Context, pagination *model.Pagination) (*model.EmailTemplates, error) { emailTemplates := []*model.EmailTemplate{} paginationClone := pagination total, err := p.GetTotalDocs(ctx, models.Collections.EmailTemplate) @@ -88,7 +88,7 @@ func (p *provider) ListEmailTemplate(ctx context.Context, pagination model.Pagin } for queryResult.Next() { - emailTemplate := models.EmailTemplate{} + var emailTemplate *models.EmailTemplate err := queryResult.Row(&emailTemplate) if err != nil { log.Fatal(err) @@ -102,54 +102,46 @@ func (p *provider) ListEmailTemplate(ctx context.Context, pagination model.Pagin } return &model.EmailTemplates{ - Pagination: &paginationClone, + Pagination: paginationClone, EmailTemplates: emailTemplates, }, nil } // GetEmailTemplateByID to get EmailTemplate by id func (p *provider) GetEmailTemplateByID(ctx context.Context, emailTemplateID string) (*model.EmailTemplate, error) { - emailTemplate := models.EmailTemplate{} - + var emailTemplate *models.EmailTemplate query := fmt.Sprintf(`SELECT _id, event_name, subject, design, template, created_at, updated_at FROM %s.%s WHERE _id = $1 LIMIT 1`, p.scopeName, models.Collections.EmailTemplate) q, err := p.db.Query(query, &gocb.QueryOptions{ Context: ctx, ScanConsistency: gocb.QueryScanConsistencyRequestPlus, PositionalParameters: []interface{}{emailTemplateID}, }) - if err != nil { return nil, err } err = q.One(&emailTemplate) - if err != nil { return nil, err } - return emailTemplate.AsAPIEmailTemplate(), nil } // GetEmailTemplateByEventName to get EmailTemplate by event_name func (p *provider) GetEmailTemplateByEventName(ctx context.Context, eventName string) (*model.EmailTemplate, error) { - emailTemplate := models.EmailTemplate{} - + var emailTemplate models.EmailTemplate query := fmt.Sprintf("SELECT _id, event_name, subject, design, template, created_at, updated_at FROM %s.%s WHERE event_name=$1 LIMIT 1", p.scopeName, models.Collections.EmailTemplate) q, err := p.db.Query(query, &gocb.QueryOptions{ Context: ctx, ScanConsistency: gocb.QueryScanConsistencyRequestPlus, PositionalParameters: []interface{}{eventName}, }) - if err != nil { return nil, err } err = q.One(&emailTemplate) - if err != nil { return nil, err } - return emailTemplate.AsAPIEmailTemplate(), nil } diff --git a/server/db/providers/couchbase/env.go b/server/db/providers/couchbase/env.go index 3addb9f..3f24937 100644 --- a/server/db/providers/couchbase/env.go +++ b/server/db/providers/couchbase/env.go @@ -11,7 +11,7 @@ import ( ) // AddEnv to save environment information in database -func (p *provider) AddEnv(ctx context.Context, env models.Env) (models.Env, error) { +func (p *provider) AddEnv(ctx context.Context, env *models.Env) (*models.Env, error) { if env.ID == "" { env.ID = uuid.New().String() } @@ -19,7 +19,6 @@ func (p *provider) AddEnv(ctx context.Context, env models.Env) (models.Env, erro env.UpdatedAt = time.Now().Unix() env.Key = env.ID env.EncryptionKey = env.Hash - insertOpt := gocb.InsertOptions{ Context: ctx, } @@ -31,7 +30,7 @@ func (p *provider) AddEnv(ctx context.Context, env models.Env) (models.Env, erro } // UpdateEnv to update environment information in database -func (p *provider) UpdateEnv(ctx context.Context, env models.Env) (models.Env, error) { +func (p *provider) UpdateEnv(ctx context.Context, env *models.Env) (*models.Env, error) { env.UpdatedAt = time.Now().Unix() env.EncryptionKey = env.Hash @@ -40,17 +39,15 @@ func (p *provider) UpdateEnv(ctx context.Context, env models.Env) (models.Env, e Context: ctx, PositionalParameters: []interface{}{env.EnvData, env.UpdatedAt, env.UpdatedAt, env.ID}, }) - if err != nil { return env, err } - return env, nil } // GetEnv to get environment information from database -func (p *provider) GetEnv(ctx context.Context) (models.Env, error) { - var env models.Env +func (p *provider) GetEnv(ctx context.Context) (*models.Env, error) { + var env *models.Env query := fmt.Sprintf("SELECT _id, env, encryption_key, created_at, updated_at FROM %s.%s LIMIT 1", p.scopeName, models.Collections.Env) q, err := p.db.Query(query, &gocb.QueryOptions{ @@ -61,7 +58,6 @@ func (p *provider) GetEnv(ctx context.Context) (models.Env, error) { return env, err } err = q.One(&env) - if err != nil { return env, err } diff --git a/server/db/providers/couchbase/otp.go b/server/db/providers/couchbase/otp.go index cdcfde9..1fe6532 100644 --- a/server/db/providers/couchbase/otp.go +++ b/server/db/providers/couchbase/otp.go @@ -2,6 +2,7 @@ package couchbase import ( "context" + "errors" "fmt" "time" @@ -12,24 +13,36 @@ import ( // UpsertOTP to add or update otp func (p *provider) UpsertOTP(ctx context.Context, otpParam *models.OTP) (*models.OTP, error) { - otp, _ := p.GetOTPByEmail(ctx, otpParam.Email) - + // check if email or phone number is present + if otpParam.Email == "" && otpParam.PhoneNumber == "" { + return nil, errors.New("email or phone_number is required") + } + uniqueField := models.FieldNameEmail + if otpParam.Email == "" && otpParam.PhoneNumber != "" { + uniqueField = models.FieldNamePhoneNumber + } + var otp *models.OTP + if uniqueField == models.FieldNameEmail { + otp, _ = p.GetOTPByEmail(ctx, otpParam.Email) + } else { + otp, _ = p.GetOTPByPhoneNumber(ctx, otpParam.PhoneNumber) + } shouldCreate := false if otp == nil { shouldCreate = true otp = &models.OTP{ - ID: uuid.NewString(), - Otp: otpParam.Otp, - Email: otpParam.Email, - ExpiresAt: otpParam.ExpiresAt, - CreatedAt: time.Now().Unix(), - UpdatedAt: time.Now().Unix(), + ID: uuid.NewString(), + Otp: otpParam.Otp, + Email: otpParam.Email, + PhoneNumber: otpParam.PhoneNumber, + ExpiresAt: otpParam.ExpiresAt, + CreatedAt: time.Now().Unix(), + UpdatedAt: time.Now().Unix(), } } else { otp.Otp = otpParam.Otp otp.ExpiresAt = otpParam.ExpiresAt } - otp.UpdatedAt = time.Now().Unix() if shouldCreate { insertOpt := gocb.InsertOptions{ @@ -54,7 +67,7 @@ func (p *provider) UpsertOTP(ctx context.Context, otpParam *models.OTP) (*models // GetOTPByEmail to get otp for a given email address func (p *provider) GetOTPByEmail(ctx context.Context, emailAddress string) (*models.OTP, error) { otp := models.OTP{} - query := fmt.Sprintf(`SELECT _id, email, otp, expires_at, created_at, updated_at FROM %s.%s WHERE email = $1 LIMIT 1`, p.scopeName, models.Collections.OTP) + query := fmt.Sprintf(`SELECT _id, email, phone_number, otp, expires_at, created_at, updated_at FROM %s.%s WHERE email = $1 LIMIT 1`, p.scopeName, models.Collections.OTP) q, err := p.db.Query(query, &gocb.QueryOptions{ ScanConsistency: gocb.QueryScanConsistencyRequestPlus, PositionalParameters: []interface{}{emailAddress}, @@ -63,11 +76,27 @@ func (p *provider) GetOTPByEmail(ctx context.Context, emailAddress string) (*mod return nil, err } err = q.One(&otp) - if err != nil { return nil, err } + return &otp, nil +} +// GetOTPByPhoneNumber to get otp for a given phone number +func (p *provider) GetOTPByPhoneNumber(ctx context.Context, phoneNumber string) (*models.OTP, error) { + otp := models.OTP{} + query := fmt.Sprintf(`SELECT _id, email, phone_number, otp, expires_at, created_at, updated_at FROM %s.%s WHERE phone_number = $1 LIMIT 1`, p.scopeName, models.Collections.OTP) + q, err := p.db.Query(query, &gocb.QueryOptions{ + ScanConsistency: gocb.QueryScanConsistencyRequestPlus, + PositionalParameters: []interface{}{phoneNumber}, + }) + if err != nil { + return nil, err + } + err = q.One(&otp) + if err != nil { + return nil, err + } return &otp, nil } diff --git a/server/db/providers/couchbase/provider.go b/server/db/providers/couchbase/provider.go index c5d5404..723e47a 100644 --- a/server/db/providers/couchbase/provider.go +++ b/server/db/providers/couchbase/provider.go @@ -166,5 +166,9 @@ func GetIndex(scopeName string) map[string][]string { otpIndex1 := fmt.Sprintf("CREATE INDEX OTPEmailIndex ON %s.%s(email)", scopeName, models.Collections.OTP) indices[models.Collections.OTP] = []string{otpIndex1} + // OTP index + otpIndex2 := fmt.Sprintf("CREATE INDEX OTPPhoneNumberIndex ON %s.%s(phone_number)", scopeName, models.Collections.OTP) + indices[models.Collections.OTP] = []string{otpIndex2} + return indices } diff --git a/server/db/providers/couchbase/session.go b/server/db/providers/couchbase/session.go index 6f0d84f..a3b9915 100644 --- a/server/db/providers/couchbase/session.go +++ b/server/db/providers/couchbase/session.go @@ -10,11 +10,10 @@ import ( ) // AddSession to save session information in database -func (p *provider) AddSession(ctx context.Context, session models.Session) error { +func (p *provider) AddSession(ctx context.Context, session *models.Session) error { if session.ID == "" { session.ID = uuid.New().String() } - session.CreatedAt = time.Now().Unix() session.UpdatedAt = time.Now().Unix() insertOpt := gocb.InsertOptions{ @@ -24,7 +23,6 @@ func (p *provider) AddSession(ctx context.Context, session models.Session) error if err != nil { return err } - return nil } diff --git a/server/db/providers/couchbase/shared.go b/server/db/providers/couchbase/shared.go index 104ad02..00a8cfa 100644 --- a/server/db/providers/couchbase/shared.go +++ b/server/db/providers/couchbase/shared.go @@ -11,24 +11,19 @@ import ( func GetSetFields(webhookMap map[string]interface{}) (string, map[string]interface{}) { params := make(map[string]interface{}, 1) - updateFields := "" - for key, value := range webhookMap { if key == "_id" { continue } - if key == "_key" { continue } - if value == nil { updateFields += fmt.Sprintf("%s=$%s,", key, key) params[key] = "null" continue } - valueType := reflect.TypeOf(value) if valueType.Name() == "string" { updateFields += fmt.Sprintf("%s = $%s, ", key, key) @@ -46,14 +41,11 @@ func GetSetFields(webhookMap map[string]interface{}) (string, map[string]interfa func (p *provider) GetTotalDocs(ctx context.Context, collection string) (int64, error) { totalDocs := TotalDocs{} - countQuery := fmt.Sprintf("SELECT COUNT(*) as Total FROM %s.%s", p.scopeName, collection) queryRes, err := p.db.Query(countQuery, &gocb.QueryOptions{ Context: ctx, }) - queryRes.One(&totalDocs) - if err != nil { return totalDocs.Total, err } diff --git a/server/db/providers/couchbase/sms_verification_requests.go b/server/db/providers/couchbase/sms_verification_requests.go deleted file mode 100644 index 9201d73..0000000 --- a/server/db/providers/couchbase/sms_verification_requests.go +++ /dev/null @@ -1,23 +0,0 @@ -package couchbase - -import ( - "context" - - "github.com/authorizerdev/authorizer/server/db/models" - -) - -// SMS verification Request -func (p *provider) UpsertSMSRequest(ctx context.Context, sms_code *models.SMSVerificationRequest) (*models.SMSVerificationRequest, error) { - return sms_code, nil -} - -func (p *provider) GetCodeByPhone(ctx context.Context, phoneNumber string) (*models.SMSVerificationRequest, error) { - var sms_verification_request models.SMSVerificationRequest - - return &sms_verification_request, nil -} - -func(p *provider) DeleteSMSRequest(ctx context.Context, smsRequest *models.SMSVerificationRequest) error { - return nil -} diff --git a/server/db/providers/couchbase/user.go b/server/db/providers/couchbase/user.go index 2dc813c..f5d2195 100644 --- a/server/db/providers/couchbase/user.go +++ b/server/db/providers/couchbase/user.go @@ -15,7 +15,7 @@ import ( ) // AddUser to save user information in database -func (p *provider) AddUser(ctx context.Context, user models.User) (models.User, error) { +func (p *provider) AddUser(ctx context.Context, user *models.User) (*models.User, error) { if user.ID == "" { user.ID = uuid.New().String() } @@ -41,7 +41,7 @@ func (p *provider) AddUser(ctx context.Context, user models.User) (models.User, } // UpdateUser to update user information in database -func (p *provider) UpdateUser(ctx context.Context, user models.User) (models.User, error) { +func (p *provider) UpdateUser(ctx context.Context, user *models.User) (*models.User, error) { user.UpdatedAt = time.Now().Unix() unsertOpt := gocb.UpsertOptions{ Context: ctx, @@ -54,7 +54,7 @@ func (p *provider) UpdateUser(ctx context.Context, user models.User) (models.Use } // DeleteUser to delete user information from database -func (p *provider) DeleteUser(ctx context.Context, user models.User) error { +func (p *provider) DeleteUser(ctx context.Context, user *models.User) error { removeOpt := gocb.RemoveOptions{ Context: ctx, } @@ -66,12 +66,10 @@ func (p *provider) DeleteUser(ctx context.Context, user models.User) error { } // ListUsers to get list of users from database -func (p *provider) ListUsers(ctx context.Context, pagination model.Pagination) (*model.Users, error) { +func (p *provider) ListUsers(ctx context.Context, pagination *model.Pagination) (*model.Users, error) { users := []*model.User{} paginationClone := pagination - userQuery := fmt.Sprintf("SELECT _id, email, email_verified_at, `password`, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, is_multi_factor_auth_enabled, created_at, updated_at FROM %s.%s ORDER BY id OFFSET $1 LIMIT $2", p.scopeName, models.Collections.User) - queryResult, err := p.db.Query(userQuery, &gocb.QueryOptions{ ScanConsistency: gocb.QueryScanConsistencyRequestPlus, Context: ctx, @@ -97,21 +95,20 @@ func (p *provider) ListUsers(ctx context.Context, pagination model.Pagination) ( return nil, err } return &model.Users{ - Pagination: &paginationClone, + Pagination: paginationClone, Users: users, }, nil } // GetUserByEmail to get user information from database using email address -func (p *provider) GetUserByEmail(ctx context.Context, email string) (models.User, error) { - user := models.User{} +func (p *provider) GetUserByEmail(ctx context.Context, email string) (*models.User, error) { + var user *models.User query := fmt.Sprintf("SELECT _id, email, email_verified_at, `password`, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, is_multi_factor_auth_enabled, created_at, updated_at FROM %s.%s WHERE email = $1 LIMIT 1", p.scopeName, models.Collections.User) q, err := p.db.Query(query, &gocb.QueryOptions{ ScanConsistency: gocb.QueryScanConsistencyRequestPlus, Context: ctx, PositionalParameters: []interface{}{email}, }) - if err != nil { return user, err } @@ -119,13 +116,12 @@ func (p *provider) GetUserByEmail(ctx context.Context, email string) (models.Use if err != nil { return user, err } - return user, nil } // GetUserByID to get user information from database using user ID -func (p *provider) GetUserByID(ctx context.Context, id string) (models.User, error) { - user := models.User{} +func (p *provider) GetUserByID(ctx context.Context, id string) (*models.User, error) { + var user *models.User query := fmt.Sprintf("SELECT _id, email, email_verified_at, `password`, signup_methods, given_name, family_name, middle_name, nickname, birthdate, phone_number, phone_number_verified_at, picture, roles, revoked_timestamp, is_multi_factor_auth_enabled, created_at, updated_at FROM %s.%s WHERE _id = $1 LIMIT 1", p.scopeName, models.Collections.User) q, err := p.db.Query(query, &gocb.QueryOptions{ ScanConsistency: gocb.QueryScanConsistencyRequestPlus, @@ -139,7 +135,6 @@ func (p *provider) GetUserByID(ctx context.Context, id string) (models.User, err if err != nil { return user, err } - return user, nil } @@ -174,7 +169,6 @@ func (p *provider) UpdateUsers(ctx context.Context, data map[string]interface{}, return err } } - return nil } @@ -194,6 +188,5 @@ func (p *provider) GetUserByPhoneNumber(ctx context.Context, phoneNumber string) if err != nil { return user, err } - return user, nil } diff --git a/server/db/providers/couchbase/verification_requests.go b/server/db/providers/couchbase/verification_requests.go index 6971065..314f69a 100644 --- a/server/db/providers/couchbase/verification_requests.go +++ b/server/db/providers/couchbase/verification_requests.go @@ -13,11 +13,10 @@ import ( ) // AddVerification to save verification request in database -func (p *provider) AddVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) (models.VerificationRequest, error) { +func (p *provider) AddVerificationRequest(ctx context.Context, verificationRequest *models.VerificationRequest) (*models.VerificationRequest, error) { if verificationRequest.ID == "" { verificationRequest.ID = uuid.New().String() } - verificationRequest.Key = verificationRequest.ID verificationRequest.CreatedAt = time.Now().Unix() verificationRequest.UpdatedAt = time.Now().Unix() @@ -28,13 +27,12 @@ func (p *provider) AddVerificationRequest(ctx context.Context, verificationReque if err != nil { return verificationRequest, err } - return verificationRequest, nil } // GetVerificationRequestByToken to get verification request from database using token -func (p *provider) GetVerificationRequestByToken(ctx context.Context, token string) (models.VerificationRequest, error) { - verificationRequest := models.VerificationRequest{} +func (p *provider) GetVerificationRequestByToken(ctx context.Context, token string) (*models.VerificationRequest, error) { + var verificationRequest *models.VerificationRequest params := make(map[string]interface{}, 1) params["token"] = token query := fmt.Sprintf("SELECT _id, token, identifier, expires_at, email, nonce, redirect_uri, created_at, updated_at FROM %s.%s WHERE token=$1 LIMIT 1", p.scopeName, models.Collections.VerificationRequest) @@ -57,7 +55,7 @@ func (p *provider) GetVerificationRequestByToken(ctx context.Context, token stri } // GetVerificationRequestByEmail to get verification request by email from database -func (p *provider) GetVerificationRequestByEmail(ctx context.Context, email string, identifier string) (models.VerificationRequest, error) { +func (p *provider) GetVerificationRequestByEmail(ctx context.Context, email string, identifier string) (*models.VerificationRequest, error) { query := fmt.Sprintf("SELECT _id, identifier, token, expires_at, email, nonce, redirect_uri, created_at, updated_at FROM %s.%s WHERE email=$1 AND identifier=$2 LIMIT 1", p.scopeName, models.Collections.VerificationRequest) queryResult, err := p.db.Query(query, &gocb.QueryOptions{ @@ -65,14 +63,11 @@ func (p *provider) GetVerificationRequestByEmail(ctx context.Context, email stri PositionalParameters: []interface{}{email, identifier}, ScanConsistency: gocb.QueryScanConsistencyRequestPlus, }) - verificationRequest := models.VerificationRequest{} - if err != nil { - return verificationRequest, err + return nil, err } - + var verificationRequest *models.VerificationRequest err = queryResult.One(&verificationRequest) - if err != nil { return verificationRequest, err } @@ -80,7 +75,7 @@ func (p *provider) GetVerificationRequestByEmail(ctx context.Context, email stri } // ListVerificationRequests to get list of verification requests from database -func (p *provider) ListVerificationRequests(ctx context.Context, pagination model.Pagination) (*model.VerificationRequests, error) { +func (p *provider) ListVerificationRequests(ctx context.Context, pagination *model.Pagination) (*model.VerificationRequests, error) { var verificationRequests []*model.VerificationRequest paginationClone := pagination total, err := p.GetTotalDocs(ctx, models.Collections.VerificationRequest) @@ -111,12 +106,12 @@ func (p *provider) ListVerificationRequests(ctx context.Context, pagination mode } return &model.VerificationRequests{ VerificationRequests: verificationRequests, - Pagination: &paginationClone, + Pagination: paginationClone, }, nil } // DeleteVerificationRequest to delete verification request from database -func (p *provider) DeleteVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) error { +func (p *provider) DeleteVerificationRequest(ctx context.Context, verificationRequest *models.VerificationRequest) error { removeOpt := gocb.RemoveOptions{ Context: ctx, } diff --git a/server/db/providers/couchbase/webhook.go b/server/db/providers/couchbase/webhook.go index 2f51acd..92b0111 100644 --- a/server/db/providers/couchbase/webhook.go +++ b/server/db/providers/couchbase/webhook.go @@ -15,7 +15,7 @@ import ( ) // AddWebhook to add webhook -func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) { +func (p *provider) AddWebhook(ctx context.Context, webhook *models.Webhook) (*model.Webhook, error) { if webhook.ID == "" { webhook.ID = uuid.New().String() } @@ -35,7 +35,7 @@ func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*mod } // UpdateWebhook to update webhook -func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) { +func (p *provider) UpdateWebhook(ctx context.Context, webhook *models.Webhook) (*model.Webhook, error) { webhook.UpdatedAt = time.Now().Unix() // Event is changed if !strings.Contains(webhook.EventName, "-") { @@ -68,7 +68,7 @@ func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (* } // ListWebhooks to list webhook -func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) (*model.Webhooks, error) { +func (p *provider) ListWebhook(ctx context.Context, pagination *model.Pagination) (*model.Webhooks, error) { webhooks := []*model.Webhook{} paginationClone := pagination params := make(map[string]interface{}, 1) @@ -100,14 +100,14 @@ func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) return nil, err } return &model.Webhooks{ - Pagination: &paginationClone, + Pagination: paginationClone, Webhooks: webhooks, }, nil } // GetWebhookByID to get webhook by id func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model.Webhook, error) { - var webhook models.Webhook + var webhook *models.Webhook params := make(map[string]interface{}, 1) params["_id"] = webhookID query := fmt.Sprintf(`SELECT _id, event_description, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s.%s WHERE _id=$_id LIMIT 1`, p.scopeName, models.Collections.Webhook) @@ -141,7 +141,7 @@ func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) } webhooks := []*model.Webhook{} for queryResult.Next() { - var webhook models.Webhook + var webhook *models.Webhook err := queryResult.Row(&webhook) if err != nil { log.Fatal(err) @@ -162,11 +162,9 @@ func (p *provider) DeleteWebhook(ctx context.Context, webhook *model.Webhook) er Context: ctx, } _, err := p.db.Collection(models.Collections.Webhook).Remove(webhook.ID, &removeOpt) - if err != nil { return err } - query := fmt.Sprintf(`DELETE FROM %s.%s WHERE webhook_id=$webhook_id`, p.scopeName, models.Collections.WebhookLog) _, err = p.db.Query(query, &gocb.QueryOptions{ Context: ctx, @@ -176,6 +174,5 @@ func (p *provider) DeleteWebhook(ctx context.Context, webhook *model.Webhook) er if err != nil { return err } - return nil } diff --git a/server/db/providers/couchbase/webhook_log.go b/server/db/providers/couchbase/webhook_log.go index 7c4fd15..0482394 100644 --- a/server/db/providers/couchbase/webhook_log.go +++ b/server/db/providers/couchbase/webhook_log.go @@ -13,15 +13,13 @@ import ( ) // AddWebhookLog to add webhook log -func (p *provider) AddWebhookLog(ctx context.Context, webhookLog models.WebhookLog) (*model.WebhookLog, error) { +func (p *provider) AddWebhookLog(ctx context.Context, webhookLog *models.WebhookLog) (*model.WebhookLog, error) { if webhookLog.ID == "" { webhookLog.ID = uuid.New().String() } - webhookLog.Key = webhookLog.ID webhookLog.CreatedAt = time.Now().Unix() webhookLog.UpdatedAt = time.Now().Unix() - insertOpt := gocb.InsertOptions{ Context: ctx, } @@ -29,19 +27,16 @@ func (p *provider) AddWebhookLog(ctx context.Context, webhookLog models.WebhookL if err != nil { return webhookLog.AsAPIWebhookLog(), err } - return webhookLog.AsAPIWebhookLog(), nil } // ListWebhookLogs to list webhook logs -func (p *provider) ListWebhookLogs(ctx context.Context, pagination model.Pagination, webhookID string) (*model.WebhookLogs, error) { +func (p *provider) ListWebhookLogs(ctx context.Context, pagination *model.Pagination, webhookID string) (*model.WebhookLogs, error) { var query string var err error - webhookLogs := []*model.WebhookLog{} params := make(map[string]interface{}, 1) paginationClone := pagination - params["webhookID"] = webhookID params["offset"] = paginationClone.Offset params["limit"] = paginationClone.Limit @@ -55,13 +50,11 @@ func (p *provider) ListWebhookLogs(ctx context.Context, pagination model.Paginat } else { query = fmt.Sprintf("SELECT _id, http_status, response, request, webhook_id, created_at, updated_at FROM %s.%s OFFSET $offset LIMIT $limit", p.scopeName, models.Collections.WebhookLog) } - queryResult, err := p.db.Query(query, &gocb.QueryOptions{ Context: ctx, ScanConsistency: gocb.QueryScanConsistencyRequestPlus, NamedParameters: params, }) - if err != nil { return nil, err } @@ -73,13 +66,12 @@ func (p *provider) ListWebhookLogs(ctx context.Context, pagination model.Paginat } webhookLogs = append(webhookLogs, webhookLog.AsAPIWebhookLog()) } - if err := queryResult.Err(); err != nil { return nil, err } return &model.WebhookLogs{ - Pagination: &paginationClone, + Pagination: paginationClone, WebhookLogs: webhookLogs, }, nil } diff --git a/server/db/providers/dynamodb/email_template.go b/server/db/providers/dynamodb/email_template.go index 08745cf..7355bbb 100644 --- a/server/db/providers/dynamodb/email_template.go +++ b/server/db/providers/dynamodb/email_template.go @@ -12,7 +12,7 @@ import ( ) // AddEmailTemplate to add EmailTemplate -func (p *provider) AddEmailTemplate(ctx context.Context, emailTemplate models.EmailTemplate) (*model.EmailTemplate, error) { +func (p *provider) AddEmailTemplate(ctx context.Context, emailTemplate *models.EmailTemplate) (*model.EmailTemplate, error) { collection := p.db.Table(models.Collections.EmailTemplate) if emailTemplate.ID == "" { emailTemplate.ID = uuid.New().String() @@ -31,7 +31,7 @@ func (p *provider) AddEmailTemplate(ctx context.Context, emailTemplate models.Em } // UpdateEmailTemplate to update EmailTemplate -func (p *provider) UpdateEmailTemplate(ctx context.Context, emailTemplate models.EmailTemplate) (*model.EmailTemplate, error) { +func (p *provider) UpdateEmailTemplate(ctx context.Context, emailTemplate *models.EmailTemplate) (*model.EmailTemplate, error) { collection := p.db.Table(models.Collections.EmailTemplate) emailTemplate.UpdatedAt = time.Now().Unix() err := UpdateByHashKey(collection, "id", emailTemplate.ID, emailTemplate) @@ -42,23 +42,19 @@ func (p *provider) UpdateEmailTemplate(ctx context.Context, emailTemplate models } // ListEmailTemplates to list EmailTemplate -func (p *provider) ListEmailTemplate(ctx context.Context, pagination model.Pagination) (*model.EmailTemplates, error) { - - var emailTemplate models.EmailTemplate +func (p *provider) ListEmailTemplate(ctx context.Context, pagination *model.Pagination) (*model.EmailTemplates, error) { + var emailTemplate *models.EmailTemplate var iter dynamo.PagingIter var lastEval dynamo.PagingKey var iteration int64 = 0 - collection := p.db.Table(models.Collections.EmailTemplate) emailTemplates := []*model.EmailTemplate{} paginationClone := pagination scanner := collection.Scan() count, err := scanner.Count() - if err != nil { return nil, err } - for (paginationClone.Offset + paginationClone.Limit) > iteration { iter = scanner.StartFrom(lastEval).Limit(paginationClone.Limit).Iter() for iter.NextWithContext(ctx, &emailTemplate) { @@ -69,11 +65,9 @@ func (p *provider) ListEmailTemplate(ctx context.Context, pagination model.Pagin lastEval = iter.LastEvaluatedKey() iteration += paginationClone.Limit } - paginationClone.Total = count - return &model.EmailTemplates{ - Pagination: &paginationClone, + Pagination: paginationClone, EmailTemplates: emailTemplates, }, nil } @@ -81,7 +75,7 @@ func (p *provider) ListEmailTemplate(ctx context.Context, pagination model.Pagin // GetEmailTemplateByID to get EmailTemplate by id func (p *provider) GetEmailTemplateByID(ctx context.Context, emailTemplateID string) (*model.EmailTemplate, error) { collection := p.db.Table(models.Collections.EmailTemplate) - var emailTemplate models.EmailTemplate + var emailTemplate *models.EmailTemplate err := collection.Get("id", emailTemplateID).OneWithContext(ctx, &emailTemplate) if err != nil { return nil, err @@ -92,9 +86,8 @@ func (p *provider) GetEmailTemplateByID(ctx context.Context, emailTemplateID str // GetEmailTemplateByEventName to get EmailTemplate by event_name func (p *provider) GetEmailTemplateByEventName(ctx context.Context, eventName string) (*model.EmailTemplate, error) { collection := p.db.Table(models.Collections.EmailTemplate) - var emailTemplates []models.EmailTemplate - var emailTemplate models.EmailTemplate - + var emailTemplates []*models.EmailTemplate + var emailTemplate *models.EmailTemplate err := collection.Scan().Index("event_name").Filter("'event_name' = ?", eventName).Limit(1).AllWithContext(ctx, &emailTemplates) if err != nil { return nil, err @@ -112,7 +105,6 @@ func (p *provider) GetEmailTemplateByEventName(ctx context.Context, eventName st func (p *provider) DeleteEmailTemplate(ctx context.Context, emailTemplate *model.EmailTemplate) error { collection := p.db.Table(models.Collections.EmailTemplate) err := collection.Delete("id", emailTemplate.ID).RunWithContext(ctx) - if err != nil { return err } diff --git a/server/db/providers/dynamodb/env.go b/server/db/providers/dynamodb/env.go index d491e19..0b356f7 100644 --- a/server/db/providers/dynamodb/env.go +++ b/server/db/providers/dynamodb/env.go @@ -11,34 +11,26 @@ import ( ) // AddEnv to save environment information in database -func (p *provider) AddEnv(ctx context.Context, env models.Env) (models.Env, error) { +func (p *provider) AddEnv(ctx context.Context, env *models.Env) (*models.Env, error) { collection := p.db.Table(models.Collections.Env) - if env.ID == "" { env.ID = uuid.New().String() } - env.Key = env.ID - env.CreatedAt = time.Now().Unix() env.UpdatedAt = time.Now().Unix() - err := collection.Put(env).RunWithContext(ctx) - if err != nil { return env, err } - return env, nil } // UpdateEnv to update environment information in database -func (p *provider) UpdateEnv(ctx context.Context, env models.Env) (models.Env, error) { +func (p *provider) UpdateEnv(ctx context.Context, env *models.Env) (*models.Env, error) { collection := p.db.Table(models.Collections.Env) env.UpdatedAt = time.Now().Unix() - err := UpdateByHashKey(collection, "id", env.ID, env) - if err != nil { return env, err } @@ -46,26 +38,21 @@ func (p *provider) UpdateEnv(ctx context.Context, env models.Env) (models.Env, e } // GetEnv to get environment information from database -func (p *provider) GetEnv(ctx context.Context) (models.Env, error) { - var env models.Env - +func (p *provider) GetEnv(ctx context.Context) (*models.Env, error) { + var env *models.Env collection := p.db.Table(models.Collections.Env) // As there is no Findone supported. iter := collection.Scan().Limit(1).Iter() - for iter.NextWithContext(ctx, &env) { - if env.ID == "" { + if env == nil { return env, errors.New("no documets found") } else { return env, nil } } - err := iter.Err() - if err != nil { return env, fmt.Errorf("config not found") } - return env, nil } diff --git a/server/db/providers/dynamodb/otp.go b/server/db/providers/dynamodb/otp.go index 063f634..23273e2 100644 --- a/server/db/providers/dynamodb/otp.go +++ b/server/db/providers/dynamodb/otp.go @@ -11,27 +11,39 @@ import ( // UpsertOTP to add or update otp func (p *provider) UpsertOTP(ctx context.Context, otpParam *models.OTP) (*models.OTP, error) { - otp, _ := p.GetOTPByEmail(ctx, otpParam.Email) + // check if email or phone number is present + if otpParam.Email == "" && otpParam.PhoneNumber == "" { + return nil, errors.New("email or phone_number is required") + } + uniqueField := models.FieldNameEmail + if otpParam.Email == "" && otpParam.PhoneNumber != "" { + uniqueField = models.FieldNamePhoneNumber + } + var otp *models.OTP + if uniqueField == models.FieldNameEmail { + otp, _ = p.GetOTPByEmail(ctx, otpParam.Email) + } else { + otp, _ = p.GetOTPByPhoneNumber(ctx, otpParam.PhoneNumber) + } shouldCreate := false if otp == nil { id := uuid.NewString() otp = &models.OTP{ - ID: id, - Key: id, - Otp: otpParam.Otp, - Email: otpParam.Email, - ExpiresAt: otpParam.ExpiresAt, - CreatedAt: time.Now().Unix(), + ID: id, + Key: id, + Otp: otpParam.Otp, + Email: otpParam.Email, + PhoneNumber: otpParam.PhoneNumber, + ExpiresAt: otpParam.ExpiresAt, + CreatedAt: time.Now().Unix(), } shouldCreate = true } else { otp.Otp = otpParam.Otp otp.ExpiresAt = otpParam.ExpiresAt } - collection := p.db.Table(models.Collections.OTP) otp.UpdatedAt = time.Now().Unix() - var err error if shouldCreate { err = collection.Put(otp).RunWithContext(ctx) @@ -41,7 +53,6 @@ func (p *provider) UpsertOTP(ctx context.Context, otpParam *models.OTP) (*models if err != nil { return nil, err } - return otp, nil } @@ -49,32 +60,42 @@ func (p *provider) UpsertOTP(ctx context.Context, otpParam *models.OTP) (*models func (p *provider) GetOTPByEmail(ctx context.Context, emailAddress string) (*models.OTP, error) { var otps []models.OTP var otp models.OTP - collection := p.db.Table(models.Collections.OTP) - err := collection.Scan().Index("email").Filter("'email' = ?", emailAddress).Limit(1).AllWithContext(ctx, &otps) - if err != nil { return nil, err } if len(otps) > 0 { otp = otps[0] return &otp, nil - } else { - return nil, errors.New("no docuemnt found") } + return nil, errors.New("no docuemnt found") +} + +// GetOTPByPhoneNumber to get otp for a given phone number +func (p *provider) GetOTPByPhoneNumber(ctx context.Context, phoneNumber string) (*models.OTP, error) { + var otps []models.OTP + var otp models.OTP + collection := p.db.Table(models.Collections.OTP) + err := collection.Scan().Filter("'phone_number' = ?", phoneNumber).Limit(1).AllWithContext(ctx, &otps) + if err != nil { + return nil, err + } + if len(otps) > 0 { + otp = otps[0] + return &otp, nil + } + return nil, errors.New("no docuemnt found") } // DeleteOTP to delete otp func (p *provider) DeleteOTP(ctx context.Context, otp *models.OTP) error { collection := p.db.Table(models.Collections.OTP) - if otp.ID != "" { err := collection.Delete("id", otp.ID).RunWithContext(ctx) if err != nil { return err } } - return nil } diff --git a/server/db/providers/dynamodb/provider.go b/server/db/providers/dynamodb/provider.go index 4bf3345..4cc4084 100644 --- a/server/db/providers/dynamodb/provider.go +++ b/server/db/providers/dynamodb/provider.go @@ -31,21 +31,19 @@ func NewProvider() (*provider, error) { if awsRegion != "" { config.Region = aws.String(awsRegion) } - // custom awsAccessKeyID, awsSecretAccessKey took first priority, if not then fetch config from aws credentials if awsAccessKeyID != "" && awsSecretAccessKey != "" { config.Credentials = credentials.NewStaticCredentials(awsAccessKeyID, awsSecretAccessKey, "") } else if dbURL != "" { + log.Debug("Tring to use database url for dynamodb") // static config in case of testing or local-setup config.Credentials = credentials.NewStaticCredentials("key", "key", "") config.Endpoint = aws.String(dbURL) } else { log.Debugf("%s or %s or %s not found. Trying to load default credentials from aws config", constants.EnvAwsRegion, constants.EnvAwsAccessKeyID, constants.EnvAwsSecretAccessKey) } - session := session.Must(session.NewSession(&config)) db := dynamo.New(session) - db.CreateTable(models.Collections.User, models.User{}).Wait() db.CreateTable(models.Collections.Session, models.Session{}).Wait() db.CreateTable(models.Collections.EmailTemplate, models.EmailTemplate{}).Wait() @@ -54,7 +52,6 @@ func NewProvider() (*provider, error) { db.CreateTable(models.Collections.VerificationRequest, models.VerificationRequest{}).Wait() db.CreateTable(models.Collections.Webhook, models.Webhook{}).Wait() db.CreateTable(models.Collections.WebhookLog, models.WebhookLog{}).Wait() - return &provider{ db: db, }, nil diff --git a/server/db/providers/dynamodb/session.go b/server/db/providers/dynamodb/session.go index 68457e5..d65da9a 100644 --- a/server/db/providers/dynamodb/session.go +++ b/server/db/providers/dynamodb/session.go @@ -9,13 +9,11 @@ import ( ) // AddSession to save session information in database -func (p *provider) AddSession(ctx context.Context, session models.Session) error { +func (p *provider) AddSession(ctx context.Context, session *models.Session) error { collection := p.db.Table(models.Collections.Session) - if session.ID == "" { session.ID = uuid.New().String() } - session.CreatedAt = time.Now().Unix() session.UpdatedAt = time.Now().Unix() err := collection.Put(session).RunWithContext(ctx) diff --git a/server/db/providers/dynamodb/shared.go b/server/db/providers/dynamodb/shared.go index ed91eca..5597c0a 100644 --- a/server/db/providers/dynamodb/shared.go +++ b/server/db/providers/dynamodb/shared.go @@ -9,16 +9,13 @@ import ( func UpdateByHashKey(table dynamo.Table, hashKey string, hashValue string, item interface{}) error { existingValue, err := dynamo.MarshalItem(item) var i interface{} - if err != nil { return err } - nullableValue, err := dynamodbattribute.MarshalMap(item) if err != nil { return err } - u := table.Update(hashKey, hashValue) for k, v := range existingValue { if k == hashKey { @@ -26,7 +23,6 @@ func UpdateByHashKey(table dynamo.Table, hashKey string, hashValue string, item } u = u.Set(k, v) } - for k, v := range nullableValue { if k == hashKey { continue @@ -36,11 +32,9 @@ func UpdateByHashKey(table dynamo.Table, hashKey string, hashValue string, item u = u.SetNullable(k, v) } } - err = u.Run() if err != nil { return err } - return nil } diff --git a/server/db/providers/dynamodb/sms_verification_requests.go b/server/db/providers/dynamodb/sms_verification_requests.go deleted file mode 100644 index bd47bce..0000000 --- a/server/db/providers/dynamodb/sms_verification_requests.go +++ /dev/null @@ -1,23 +0,0 @@ -package dynamodb - -import ( - "context" - - "github.com/authorizerdev/authorizer/server/db/models" - -) - -// SMS verification Request -func (p *provider) UpsertSMSRequest(ctx context.Context, sms_code *models.SMSVerificationRequest) (*models.SMSVerificationRequest, error) { - return sms_code, nil -} - -func (p *provider) GetCodeByPhone(ctx context.Context, phoneNumber string) (*models.SMSVerificationRequest, error) { - var sms_verification_request models.SMSVerificationRequest - - return &sms_verification_request, nil -} - -func(p *provider) DeleteSMSRequest(ctx context.Context, smsRequest *models.SMSVerificationRequest) error { - return nil -} diff --git a/server/db/providers/dynamodb/user.go b/server/db/providers/dynamodb/user.go index d7c47e3..f93956b 100644 --- a/server/db/providers/dynamodb/user.go +++ b/server/db/providers/dynamodb/user.go @@ -18,13 +18,11 @@ import ( ) // AddUser to save user information in database -func (p *provider) AddUser(ctx context.Context, user models.User) (models.User, error) { +func (p *provider) AddUser(ctx context.Context, user *models.User) (*models.User, error) { collection := p.db.Table(models.Collections.User) - if user.ID == "" { user.ID = uuid.New().String() } - if user.Roles == "" { defaultRoles, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyDefaultRoles) if err != nil { @@ -32,18 +30,14 @@ func (p *provider) AddUser(ctx context.Context, user models.User) (models.User, } user.Roles = defaultRoles } - if user.PhoneNumber != nil && strings.TrimSpace(refs.StringValue(user.PhoneNumber)) != "" { if u, _ := p.GetUserByPhoneNumber(ctx, refs.StringValue(user.PhoneNumber)); u != nil { return user, fmt.Errorf("user with given phone number already exists") } } - user.CreatedAt = time.Now().Unix() user.UpdatedAt = time.Now().Unix() - err := collection.Put(user).RunWithContext(ctx) - if err != nil { return user, err } @@ -51,18 +45,14 @@ func (p *provider) AddUser(ctx context.Context, user models.User) (models.User, } // UpdateUser to update user information in database -func (p *provider) UpdateUser(ctx context.Context, user models.User) (models.User, error) { +func (p *provider) UpdateUser(ctx context.Context, user *models.User) (*models.User, error) { collection := p.db.Table(models.Collections.User) - if user.ID != "" { - user.UpdatedAt = time.Now().Unix() - err := UpdateByHashKey(collection, "id", user.ID, user) if err != nil { return user, err } - if err != nil { return user, err } @@ -72,18 +62,15 @@ func (p *provider) UpdateUser(ctx context.Context, user models.User) (models.Use } // DeleteUser to delete user information from database -func (p *provider) DeleteUser(ctx context.Context, user models.User) error { +func (p *provider) DeleteUser(ctx context.Context, user *models.User) error { collection := p.db.Table(models.Collections.User) sessionCollection := p.db.Table(models.Collections.Session) - if user.ID != "" { err := collection.Delete("id", user.ID).Run() if err != nil { return err } - _, err = sessionCollection.Batch("id").Write().Delete(dynamo.Keys{"user_id", user.ID}).RunWithContext(ctx) - if err != nil { return err } @@ -92,23 +79,19 @@ func (p *provider) DeleteUser(ctx context.Context, user models.User) error { } // ListUsers to get list of users from database -func (p *provider) ListUsers(ctx context.Context, pagination model.Pagination) (*model.Users, error) { - var user models.User +func (p *provider) ListUsers(ctx context.Context, pagination *model.Pagination) (*model.Users, error) { + var user *models.User var lastEval dynamo.PagingKey var iter dynamo.PagingIter var iteration int64 = 0 - collection := p.db.Table(models.Collections.User) users := []*model.User{} - paginationClone := pagination scanner := collection.Scan() count, err := scanner.Count() - if err != nil { return nil, err } - for (paginationClone.Offset + paginationClone.Limit) > iteration { iter = scanner.StartFrom(lastEval).Limit(paginationClone.Limit).Iter() for iter.NextWithContext(ctx, &user) { @@ -119,48 +102,39 @@ func (p *provider) ListUsers(ctx context.Context, pagination model.Pagination) ( lastEval = iter.LastEvaluatedKey() iteration += paginationClone.Limit } - err = iter.Err() - if err != nil { return nil, err } - paginationClone.Total = count - return &model.Users{ - Pagination: &paginationClone, + Pagination: paginationClone, Users: users, }, nil } // GetUserByEmail to get user information from database using email address -func (p *provider) GetUserByEmail(ctx context.Context, email string) (models.User, error) { - var users []models.User - var user models.User - +func (p *provider) GetUserByEmail(ctx context.Context, email string) (*models.User, error) { + var users []*models.User + var user *models.User collection := p.db.Table(models.Collections.User) err := collection.Scan().Index("email").Filter("'email' = ?", email).AllWithContext(ctx, &users) - if err != nil { return user, nil } - if len(users) > 0 { user = users[0] return user, nil } else { return user, errors.New("no record found") } - } // GetUserByID to get user information from database using user ID -func (p *provider) GetUserByID(ctx context.Context, id string) (models.User, error) { +func (p *provider) GetUserByID(ctx context.Context, id string) (*models.User, error) { collection := p.db.Table(models.Collections.User) - var user models.User + var user *models.User err := collection.Get("id", id).OneWithContext(ctx, &user) - if err != nil { if user.Email == "" { return user, errors.New("no documets found") @@ -186,7 +160,6 @@ func (p *provider) UpdateUsers(ctx context.Context, data map[string]interface{}, } else { // as there is no facility to update all doc - https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/SQLtoNoSQL.UpdateData.html userCollection.Scan().All(&allUsers) - for _, user := range allUsers { err = UpdateByHashKey(userCollection, "id", user.ID, data) if err == nil { @@ -194,7 +167,6 @@ func (p *provider) UpdateUsers(ctx context.Context, data map[string]interface{}, } } } - if err != nil { return err } else { @@ -205,19 +177,16 @@ func (p *provider) UpdateUsers(ctx context.Context, data map[string]interface{}, // GetUserByPhoneNumber to get user information from database using phone number func (p *provider) GetUserByPhoneNumber(ctx context.Context, phoneNumber string) (*models.User, error) { - var users []models.User - var user models.User - + var users []*models.User + var user *models.User collection := p.db.Table(models.Collections.User) err := collection.Scan().Filter("'phone_number' = ?", phoneNumber).AllWithContext(ctx, &users) - if err != nil { return nil, err } - if len(users) > 0 { user = users[0] - return &user, nil + return user, nil } else { return nil, errors.New("no record found") } diff --git a/server/db/providers/dynamodb/verification_requests.go b/server/db/providers/dynamodb/verification_requests.go index 990c288..5fdf078 100644 --- a/server/db/providers/dynamodb/verification_requests.go +++ b/server/db/providers/dynamodb/verification_requests.go @@ -11,9 +11,8 @@ import ( ) // AddVerification to save verification request in database -func (p *provider) AddVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) (models.VerificationRequest, error) { +func (p *provider) AddVerificationRequest(ctx context.Context, verificationRequest *models.VerificationRequest) (*models.VerificationRequest, error) { collection := p.db.Table(models.Collections.VerificationRequest) - if verificationRequest.ID == "" { verificationRequest.ID = uuid.New().String() verificationRequest.CreatedAt = time.Now().Unix() @@ -23,20 +22,17 @@ func (p *provider) AddVerificationRequest(ctx context.Context, verificationReque return verificationRequest, err } } - return verificationRequest, nil } // GetVerificationRequestByToken to get verification request from database using token -func (p *provider) GetVerificationRequestByToken(ctx context.Context, token string) (models.VerificationRequest, error) { +func (p *provider) GetVerificationRequestByToken(ctx context.Context, token string) (*models.VerificationRequest, error) { collection := p.db.Table(models.Collections.VerificationRequest) - var verificationRequest models.VerificationRequest - + var verificationRequest *models.VerificationRequest iter := collection.Scan().Filter("'token' = ?", token).Iter() for iter.NextWithContext(ctx, &verificationRequest) { return verificationRequest, nil } - err := iter.Err() if err != nil { return verificationRequest, err @@ -45,14 +41,13 @@ func (p *provider) GetVerificationRequestByToken(ctx context.Context, token stri } // GetVerificationRequestByEmail to get verification request by email from database -func (p *provider) GetVerificationRequestByEmail(ctx context.Context, email string, identifier string) (models.VerificationRequest, error) { - var verificationRequest models.VerificationRequest +func (p *provider) GetVerificationRequestByEmail(ctx context.Context, email string, identifier string) (*models.VerificationRequest, error) { + var verificationRequest *models.VerificationRequest collection := p.db.Table(models.Collections.VerificationRequest) iter := collection.Scan().Filter("'email' = ?", email).Filter("'identifier' = ?", identifier).Iter() for iter.NextWithContext(ctx, &verificationRequest) { return verificationRequest, nil } - err := iter.Err() if err != nil { return verificationRequest, err @@ -61,23 +56,19 @@ func (p *provider) GetVerificationRequestByEmail(ctx context.Context, email stri } // ListVerificationRequests to get list of verification requests from database -func (p *provider) ListVerificationRequests(ctx context.Context, pagination model.Pagination) (*model.VerificationRequests, error) { +func (p *provider) ListVerificationRequests(ctx context.Context, pagination *model.Pagination) (*model.VerificationRequests, error) { verificationRequests := []*model.VerificationRequest{} - var verificationRequest models.VerificationRequest + var verificationRequest *models.VerificationRequest var lastEval dynamo.PagingKey var iter dynamo.PagingIter var iteration int64 = 0 - collection := p.db.Table(models.Collections.VerificationRequest) paginationClone := pagination - scanner := collection.Scan() count, err := scanner.Count() - if err != nil { return nil, err } - for (paginationClone.Offset + paginationClone.Limit) > iteration { iter = scanner.StartFrom(lastEval).Limit(paginationClone.Limit).Iter() for iter.NextWithContext(ctx, &verificationRequest) { @@ -92,20 +83,17 @@ func (p *provider) ListVerificationRequests(ctx context.Context, pagination mode lastEval = iter.LastEvaluatedKey() iteration += paginationClone.Limit } - paginationClone.Total = count - return &model.VerificationRequests{ VerificationRequests: verificationRequests, - Pagination: &paginationClone, + Pagination: paginationClone, }, nil } // DeleteVerificationRequest to delete verification request from database -func (p *provider) DeleteVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) error { +func (p *provider) DeleteVerificationRequest(ctx context.Context, verificationRequest *models.VerificationRequest) error { collection := p.db.Table(models.Collections.VerificationRequest) - - if verificationRequest.ID != "" { + if verificationRequest != nil { err := collection.Delete("id", verificationRequest.ID).RunWithContext(ctx) if err != nil { diff --git a/server/db/providers/dynamodb/webhook.go b/server/db/providers/dynamodb/webhook.go index 8f1ffb7..c50e1fb 100644 --- a/server/db/providers/dynamodb/webhook.go +++ b/server/db/providers/dynamodb/webhook.go @@ -15,7 +15,7 @@ import ( ) // AddWebhook to add webhook -func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) { +func (p *provider) AddWebhook(ctx context.Context, webhook *models.Webhook) (*model.Webhook, error) { collection := p.db.Table(models.Collections.Webhook) if webhook.ID == "" { webhook.ID = uuid.New().String() @@ -33,7 +33,7 @@ func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*mod } // UpdateWebhook to update webhook -func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) { +func (p *provider) UpdateWebhook(ctx context.Context, webhook *models.Webhook) (*model.Webhook, error) { webhook.UpdatedAt = time.Now().Unix() // Event is changed if !strings.Contains(webhook.EventName, "-") { @@ -48,9 +48,9 @@ func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (* } // ListWebhooks to list webhook -func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) (*model.Webhooks, error) { +func (p *provider) ListWebhook(ctx context.Context, pagination *model.Pagination) (*model.Webhooks, error) { webhooks := []*model.Webhook{} - var webhook models.Webhook + var webhook *models.Webhook var lastEval dynamo.PagingKey var iter dynamo.PagingIter var iteration int64 = 0 @@ -77,7 +77,7 @@ func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) } paginationClone.Total = count return &model.Webhooks{ - Pagination: &paginationClone, + Pagination: paginationClone, Webhooks: webhooks, }, nil } @@ -85,7 +85,7 @@ func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) // GetWebhookByID to get webhook by id func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model.Webhook, error) { collection := p.db.Table(models.Collections.Webhook) - var webhook models.Webhook + var webhook *models.Webhook err := collection.Get("id", webhookID).OneWithContext(ctx, &webhook) if err != nil { return nil, err @@ -114,14 +114,14 @@ func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) // DeleteWebhook to delete webhook func (p *provider) DeleteWebhook(ctx context.Context, webhook *model.Webhook) error { // Also delete webhook logs for given webhook id - if webhook.ID != "" { + if webhook != nil { webhookCollection := p.db.Table(models.Collections.Webhook) - pagination := model.Pagination{} webhookLogCollection := p.db.Table(models.Collections.WebhookLog) err := webhookCollection.Delete("id", webhook.ID).RunWithContext(ctx) if err != nil { return err } + pagination := &model.Pagination{} webhookLogs, errIs := p.ListWebhookLogs(ctx, pagination, webhook.ID) for _, webhookLog := range webhookLogs.WebhookLogs { err = webhookLogCollection.Delete("id", webhookLog.ID).RunWithContext(ctx) diff --git a/server/db/providers/dynamodb/webhook_log.go b/server/db/providers/dynamodb/webhook_log.go index e9d1dcd..18ba261 100644 --- a/server/db/providers/dynamodb/webhook_log.go +++ b/server/db/providers/dynamodb/webhook_log.go @@ -11,18 +11,15 @@ import ( ) // AddWebhookLog to add webhook log -func (p *provider) AddWebhookLog(ctx context.Context, webhookLog models.WebhookLog) (*model.WebhookLog, error) { +func (p *provider) AddWebhookLog(ctx context.Context, webhookLog *models.WebhookLog) (*model.WebhookLog, error) { collection := p.db.Table(models.Collections.WebhookLog) - if webhookLog.ID == "" { webhookLog.ID = uuid.New().String() } - webhookLog.Key = webhookLog.ID webhookLog.CreatedAt = time.Now().Unix() webhookLog.UpdatedAt = time.Now().Unix() err := collection.Put(webhookLog).RunWithContext(ctx) - if err != nil { return nil, err } @@ -30,9 +27,9 @@ func (p *provider) AddWebhookLog(ctx context.Context, webhookLog models.WebhookL } // ListWebhookLogs to list webhook logs -func (p *provider) ListWebhookLogs(ctx context.Context, pagination model.Pagination, webhookID string) (*model.WebhookLogs, error) { +func (p *provider) ListWebhookLogs(ctx context.Context, pagination *model.Pagination, webhookID string) (*model.WebhookLogs, error) { webhookLogs := []*model.WebhookLog{} - var webhookLog models.WebhookLog + var webhookLog *models.WebhookLog var lastEval dynamo.PagingKey var iter dynamo.PagingIter var iteration int64 = 0 @@ -42,7 +39,6 @@ func (p *provider) ListWebhookLogs(ctx context.Context, pagination model.Paginat collection := p.db.Table(models.Collections.WebhookLog) paginationClone := pagination scanner := collection.Scan() - if webhookID != "" { iter = scanner.Index("webhook_id").Filter("'webhook_id' = ?", webhookID).Iter() for iter.NextWithContext(ctx, &webhookLog) { @@ -68,11 +64,10 @@ func (p *provider) ListWebhookLogs(ctx context.Context, pagination model.Paginat iteration += paginationClone.Limit } } - paginationClone.Total = count // paginationClone.Cursor = iter.LastEvaluatedKey() return &model.WebhookLogs{ - Pagination: &paginationClone, + Pagination: paginationClone, WebhookLogs: webhookLogs, }, nil } diff --git a/server/db/providers/mongodb/email_template.go b/server/db/providers/mongodb/email_template.go index 0a0d1d9..c3fa31b 100644 --- a/server/db/providers/mongodb/email_template.go +++ b/server/db/providers/mongodb/email_template.go @@ -12,15 +12,13 @@ import ( ) // AddEmailTemplate to add EmailTemplate -func (p *provider) AddEmailTemplate(ctx context.Context, emailTemplate models.EmailTemplate) (*model.EmailTemplate, error) { +func (p *provider) AddEmailTemplate(ctx context.Context, emailTemplate *models.EmailTemplate) (*model.EmailTemplate, error) { if emailTemplate.ID == "" { emailTemplate.ID = uuid.New().String() } - emailTemplate.Key = emailTemplate.ID emailTemplate.CreatedAt = time.Now().Unix() emailTemplate.UpdatedAt = time.Now().Unix() - emailTemplateCollection := p.db.Collection(models.Collections.EmailTemplate, options.Collection()) _, err := emailTemplateCollection.InsertOne(ctx, emailTemplate) if err != nil { @@ -30,60 +28,52 @@ func (p *provider) AddEmailTemplate(ctx context.Context, emailTemplate models.Em } // UpdateEmailTemplate to update EmailTemplate -func (p *provider) UpdateEmailTemplate(ctx context.Context, emailTemplate models.EmailTemplate) (*model.EmailTemplate, error) { +func (p *provider) UpdateEmailTemplate(ctx context.Context, emailTemplate *models.EmailTemplate) (*model.EmailTemplate, error) { emailTemplate.UpdatedAt = time.Now().Unix() - emailTemplateCollection := p.db.Collection(models.Collections.EmailTemplate, options.Collection()) _, err := emailTemplateCollection.UpdateOne(ctx, bson.M{"_id": bson.M{"$eq": emailTemplate.ID}}, bson.M{"$set": emailTemplate}, options.MergeUpdateOptions()) if err != nil { return nil, err } - return emailTemplate.AsAPIEmailTemplate(), nil } // ListEmailTemplates to list EmailTemplate -func (p *provider) ListEmailTemplate(ctx context.Context, pagination model.Pagination) (*model.EmailTemplates, error) { +func (p *provider) ListEmailTemplate(ctx context.Context, pagination *model.Pagination) (*model.EmailTemplates, error) { var emailTemplates []*model.EmailTemplate opts := options.Find() opts.SetLimit(pagination.Limit) opts.SetSkip(pagination.Offset) opts.SetSort(bson.M{"created_at": -1}) - paginationClone := pagination - emailTemplateCollection := p.db.Collection(models.Collections.EmailTemplate, options.Collection()) count, err := emailTemplateCollection.CountDocuments(ctx, bson.M{}, options.Count()) if err != nil { return nil, err } - paginationClone.Total = count - cursor, err := emailTemplateCollection.Find(ctx, bson.M{}, opts) if err != nil { return nil, err } defer cursor.Close(ctx) - for cursor.Next(ctx) { - var emailTemplate models.EmailTemplate + var emailTemplate *models.EmailTemplate err := cursor.Decode(&emailTemplate) if err != nil { return nil, err } emailTemplates = append(emailTemplates, emailTemplate.AsAPIEmailTemplate()) } - return &model.EmailTemplates{ - Pagination: &paginationClone, + Pagination: paginationClone, EmailTemplates: emailTemplates, }, nil } // GetEmailTemplateByID to get EmailTemplate by id func (p *provider) GetEmailTemplateByID(ctx context.Context, emailTemplateID string) (*model.EmailTemplate, error) { - var emailTemplate models.EmailTemplate + var emailTemplate *models.EmailTemplate emailTemplateCollection := p.db.Collection(models.Collections.EmailTemplate, options.Collection()) err := emailTemplateCollection.FindOne(ctx, bson.M{"_id": emailTemplateID}).Decode(&emailTemplate) if err != nil { @@ -94,7 +84,7 @@ func (p *provider) GetEmailTemplateByID(ctx context.Context, emailTemplateID str // GetEmailTemplateByEventName to get EmailTemplate by event_name func (p *provider) GetEmailTemplateByEventName(ctx context.Context, eventName string) (*model.EmailTemplate, error) { - var emailTemplate models.EmailTemplate + var emailTemplate *models.EmailTemplate emailTemplateCollection := p.db.Collection(models.Collections.EmailTemplate, options.Collection()) err := emailTemplateCollection.FindOne(ctx, bson.M{"event_name": eventName}).Decode(&emailTemplate) if err != nil { @@ -110,6 +100,5 @@ func (p *provider) DeleteEmailTemplate(ctx context.Context, emailTemplate *model if err != nil { return err } - return nil } diff --git a/server/db/providers/mongodb/env.go b/server/db/providers/mongodb/env.go index a4b114c..b725612 100644 --- a/server/db/providers/mongodb/env.go +++ b/server/db/providers/mongodb/env.go @@ -12,11 +12,10 @@ import ( ) // AddEnv to save environment information in database -func (p *provider) AddEnv(ctx context.Context, env models.Env) (models.Env, error) { +func (p *provider) AddEnv(ctx context.Context, env *models.Env) (*models.Env, error) { if env.ID == "" { env.ID = uuid.New().String() } - env.CreatedAt = time.Now().Unix() env.UpdatedAt = time.Now().Unix() env.Key = env.ID @@ -29,7 +28,7 @@ func (p *provider) AddEnv(ctx context.Context, env models.Env) (models.Env, erro } // UpdateEnv to update environment information in database -func (p *provider) UpdateEnv(ctx context.Context, env models.Env) (models.Env, error) { +func (p *provider) UpdateEnv(ctx context.Context, env *models.Env) (*models.Env, error) { env.UpdatedAt = time.Now().Unix() configCollection := p.db.Collection(models.Collections.Env, options.Collection()) _, err := configCollection.UpdateOne(ctx, bson.M{"_id": bson.M{"$eq": env.ID}}, bson.M{"$set": env}, options.MergeUpdateOptions()) @@ -40,25 +39,22 @@ func (p *provider) UpdateEnv(ctx context.Context, env models.Env) (models.Env, e } // GetEnv to get environment information from database -func (p *provider) GetEnv(ctx context.Context) (models.Env, error) { - var env models.Env +func (p *provider) GetEnv(ctx context.Context) (*models.Env, error) { + var env *models.Env configCollection := p.db.Collection(models.Collections.Env, options.Collection()) cursor, err := configCollection.Find(ctx, bson.M{}, options.Find()) if err != nil { return env, err } defer cursor.Close(ctx) - for cursor.Next(nil) { err := cursor.Decode(&env) if err != nil { return env, err } } - - if env.ID == "" { + if env == nil { return env, fmt.Errorf("config not found") } - return env, nil } diff --git a/server/db/providers/mongodb/otp.go b/server/db/providers/mongodb/otp.go index d6ff2df..d70818d 100644 --- a/server/db/providers/mongodb/otp.go +++ b/server/db/providers/mongodb/otp.go @@ -2,6 +2,7 @@ package mongodb import ( "context" + "errors" "time" "github.com/authorizerdev/authorizer/server/db/models" @@ -12,17 +13,31 @@ import ( // UpsertOTP to add or update otp func (p *provider) UpsertOTP(ctx context.Context, otpParam *models.OTP) (*models.OTP, error) { - otp, _ := p.GetOTPByEmail(ctx, otpParam.Email) + // check if email or phone number is present + if otpParam.Email == "" && otpParam.PhoneNumber == "" { + return nil, errors.New("email or phone_number is required") + } + uniqueField := models.FieldNameEmail + if otpParam.Email == "" && otpParam.PhoneNumber != "" { + uniqueField = models.FieldNamePhoneNumber + } + var otp *models.OTP + if uniqueField == models.FieldNameEmail { + otp, _ = p.GetOTPByEmail(ctx, otpParam.Email) + } else { + otp, _ = p.GetOTPByPhoneNumber(ctx, otpParam.PhoneNumber) + } shouldCreate := false if otp == nil { id := uuid.NewString() otp = &models.OTP{ - ID: id, - Key: id, - Otp: otpParam.Otp, - Email: otpParam.Email, - ExpiresAt: otpParam.ExpiresAt, - CreatedAt: time.Now().Unix(), + ID: id, + Key: id, + Otp: otpParam.Otp, + Email: otpParam.Email, + PhoneNumber: otpParam.PhoneNumber, + ExpiresAt: otpParam.ExpiresAt, + CreatedAt: time.Now().Unix(), } shouldCreate = true } else { @@ -41,20 +56,28 @@ func (p *provider) UpsertOTP(ctx context.Context, otpParam *models.OTP) (*models if err != nil { return nil, err } - return otp, nil } // GetOTPByEmail to get otp for a given email address func (p *provider) GetOTPByEmail(ctx context.Context, emailAddress string) (*models.OTP, error) { var otp models.OTP - otpCollection := p.db.Collection(models.Collections.OTP, options.Collection()) err := otpCollection.FindOne(ctx, bson.M{"email": emailAddress}).Decode(&otp) if err != nil { return nil, err } + return &otp, nil +} +// GetOTPByPhoneNumber to get otp for a given phone number +func (p *provider) GetOTPByPhoneNumber(ctx context.Context, phoneNumber string) (*models.OTP, error) { + var otp models.OTP + otpCollection := p.db.Collection(models.Collections.OTP, options.Collection()) + err := otpCollection.FindOne(ctx, bson.M{"phone_number": phoneNumber}).Decode(&otp) + if err != nil { + return nil, err + } return &otp, nil } diff --git a/server/db/providers/mongodb/provider.go b/server/db/providers/mongodb/provider.go index fd4a0b2..30af342 100644 --- a/server/db/providers/mongodb/provider.go +++ b/server/db/providers/mongodb/provider.go @@ -118,10 +118,7 @@ func NewProvider() (*provider, error) { Options: options.Index().SetUnique(true).SetSparse(true), }, }, options.CreateIndexes()) - - mongodb.CreateCollection(ctx, models.Collections.SMSVerificationRequest, options.CreateCollection()) - smsCollection := mongodb.Collection(models.Collections.SMSVerificationRequest, options.Collection()) - smsCollection.Indexes().CreateMany(ctx, []mongo.IndexModel{ + otpCollection.Indexes().CreateMany(ctx, []mongo.IndexModel{ { Keys: bson.M{"phone_number": 1}, Options: options.Index().SetUnique(true).SetSparse(true), diff --git a/server/db/providers/mongodb/session.go b/server/db/providers/mongodb/session.go index 4030130..860eeef 100644 --- a/server/db/providers/mongodb/session.go +++ b/server/db/providers/mongodb/session.go @@ -10,7 +10,7 @@ import ( ) // AddSession to save session information in database -func (p *provider) AddSession(ctx context.Context, session models.Session) error { +func (p *provider) AddSession(ctx context.Context, session *models.Session) error { if session.ID == "" { session.ID = uuid.New().String() } @@ -25,3 +25,8 @@ func (p *provider) AddSession(ctx context.Context, session models.Session) error } return nil } + +// DeleteSession to delete session information from database +func (p *provider) DeleteSession(ctx context.Context, userId string) error { + return nil +} diff --git a/server/db/providers/mongodb/sms_verification_requests.go b/server/db/providers/mongodb/sms_verification_requests.go deleted file mode 100644 index b2d3a13..0000000 --- a/server/db/providers/mongodb/sms_verification_requests.go +++ /dev/null @@ -1,69 +0,0 @@ -package mongodb - -import ( - "context" - "time" - - "github.com/authorizerdev/authorizer/server/db/models" - "github.com/google/uuid" - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/mongo/options" -) - -// SMS verification Request -func (p *provider) UpsertSMSRequest(ctx context.Context, smsRequest *models.SMSVerificationRequest) (*models.SMSVerificationRequest, error) { - smsVerificationRequest, _ := p.GetCodeByPhone(ctx, smsRequest.PhoneNumber) - shouldCreate := false - - if smsVerificationRequest == nil { - id := uuid.NewString() - - smsVerificationRequest = &models.SMSVerificationRequest{ - ID: id, - CreatedAt: time.Now().Unix(), - Code: smsRequest.Code, - PhoneNumber: smsRequest.PhoneNumber, - CodeExpiresAt: smsRequest.CodeExpiresAt, - } - shouldCreate = true - } - - smsVerificationRequest.UpdatedAt = time.Now().Unix() - smsRequestCollection := p.db.Collection(models.Collections.SMSVerificationRequest, options.Collection()) - - var err error - if shouldCreate { - _, err = smsRequestCollection.InsertOne(ctx, smsVerificationRequest) - } else { - _, err = smsRequestCollection.UpdateOne(ctx, bson.M{"phone_number": bson.M{"$eq": smsRequest.PhoneNumber}}, bson.M{"$set": smsVerificationRequest}, options.MergeUpdateOptions()) - } - - if err != nil { - return nil, err - } - - return smsVerificationRequest, nil -} - -func (p *provider) GetCodeByPhone(ctx context.Context, phoneNumber string) (*models.SMSVerificationRequest, error) { - var smsVerificationRequest models.SMSVerificationRequest - - smsRequestCollection := p.db.Collection(models.Collections.SMSVerificationRequest, options.Collection()) - err := smsRequestCollection.FindOne(ctx, bson.M{"phone_number": phoneNumber}).Decode(&smsVerificationRequest) - - if err != nil { - return nil, err - } - - return &smsVerificationRequest, nil -} - -func (p *provider) DeleteSMSRequest(ctx context.Context, smsRequest *models.SMSVerificationRequest) error { - smsVerificationRequests := p.db.Collection(models.Collections.SMSVerificationRequest, options.Collection()) - _, err := smsVerificationRequests.DeleteOne(nil, bson.M{"_id": smsRequest.ID}, options.Delete()) - if err != nil { - return err - } - - return nil -} diff --git a/server/db/providers/mongodb/user.go b/server/db/providers/mongodb/user.go index 32b6a17..078322e 100644 --- a/server/db/providers/mongodb/user.go +++ b/server/db/providers/mongodb/user.go @@ -16,11 +16,10 @@ import ( ) // AddUser to save user information in database -func (p *provider) AddUser(ctx context.Context, user models.User) (models.User, error) { +func (p *provider) AddUser(ctx context.Context, user *models.User) (*models.User, error) { if user.ID == "" { user.ID = uuid.New().String() } - if user.Roles == "" { defaultRoles, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyDefaultRoles) if err != nil { @@ -36,12 +35,11 @@ func (p *provider) AddUser(ctx context.Context, user models.User) (models.User, if err != nil { return user, err } - return user, nil } // UpdateUser to update user information in database -func (p *provider) UpdateUser(ctx context.Context, user models.User) (models.User, error) { +func (p *provider) UpdateUser(ctx context.Context, user *models.User) (*models.User, error) { user.UpdatedAt = time.Now().Unix() userCollection := p.db.Collection(models.Collections.User, options.Collection()) _, err := userCollection.UpdateOne(ctx, bson.M{"_id": bson.M{"$eq": user.ID}}, bson.M{"$set": user}, options.MergeUpdateOptions()) @@ -52,83 +50,72 @@ func (p *provider) UpdateUser(ctx context.Context, user models.User) (models.Use } // DeleteUser to delete user information from database -func (p *provider) DeleteUser(ctx context.Context, user models.User) error { +func (p *provider) DeleteUser(ctx context.Context, user *models.User) error { userCollection := p.db.Collection(models.Collections.User, options.Collection()) _, err := userCollection.DeleteOne(ctx, bson.M{"_id": user.ID}, options.Delete()) if err != nil { return err } - sessionCollection := p.db.Collection(models.Collections.Session, options.Collection()) _, err = sessionCollection.DeleteMany(ctx, bson.M{"user_id": user.ID}, options.Delete()) if err != nil { return err } - return nil } // ListUsers to get list of users from database -func (p *provider) ListUsers(ctx context.Context, pagination model.Pagination) (*model.Users, error) { +func (p *provider) ListUsers(ctx context.Context, pagination *model.Pagination) (*model.Users, error) { var users []*model.User opts := options.Find() opts.SetLimit(pagination.Limit) opts.SetSkip(pagination.Offset) opts.SetSort(bson.M{"created_at": -1}) - paginationClone := pagination - userCollection := p.db.Collection(models.Collections.User, options.Collection()) count, err := userCollection.CountDocuments(ctx, bson.M{}, options.Count()) if err != nil { return nil, err } - paginationClone.Total = count - cursor, err := userCollection.Find(ctx, bson.M{}, opts) if err != nil { return nil, err } defer cursor.Close(ctx) - for cursor.Next(ctx) { - var user models.User + var user *models.User err := cursor.Decode(&user) if err != nil { return nil, err } users = append(users, user.AsAPIUser()) } - return &model.Users{ - Pagination: &paginationClone, + Pagination: paginationClone, Users: users, }, nil } // GetUserByEmail to get user information from database using email address -func (p *provider) GetUserByEmail(ctx context.Context, email string) (models.User, error) { - var user models.User +func (p *provider) GetUserByEmail(ctx context.Context, email string) (*models.User, error) { + var user *models.User userCollection := p.db.Collection(models.Collections.User, options.Collection()) err := userCollection.FindOne(ctx, bson.M{"email": email}).Decode(&user) if err != nil { return user, err } - return user, nil } // GetUserByID to get user information from database using user ID -func (p *provider) GetUserByID(ctx context.Context, id string) (models.User, error) { - var user models.User - +func (p *provider) GetUserByID(ctx context.Context, id string) (*models.User, error) { + var user *models.User userCollection := p.db.Collection(models.Collections.User, options.Collection()) err := userCollection.FindOne(ctx, bson.M{"_id": id}).Decode(&user) if err != nil { return user, err } - return user, nil } @@ -137,17 +124,14 @@ func (p *provider) GetUserByID(ctx context.Context, id string) (models.User, err func (p *provider) UpdateUsers(ctx context.Context, data map[string]interface{}, ids []string) error { // set updated_at time for all users data["updated_at"] = time.Now().Unix() - userCollection := p.db.Collection(models.Collections.User, options.Collection()) - var res *mongo.UpdateResult var err error - if ids != nil && len(ids) > 0 { + if len(ids) > 0 { res, err = userCollection.UpdateMany(ctx, bson.M{"_id": bson.M{"$in": ids}}, bson.M{"$set": data}) } else { res, err = userCollection.UpdateMany(ctx, bson.M{}, bson.M{"$set": data}) } - if err != nil { return err } else { @@ -158,13 +142,11 @@ func (p *provider) UpdateUsers(ctx context.Context, data map[string]interface{}, // GetUserByPhoneNumber to get user information from database using phone number func (p *provider) GetUserByPhoneNumber(ctx context.Context, phoneNumber string) (*models.User, error) { - var user models.User - + var user *models.User userCollection := p.db.Collection(models.Collections.User, options.Collection()) err := userCollection.FindOne(ctx, bson.M{"phone_number": phoneNumber}).Decode(&user) if err != nil { return nil, err } - - return &user, nil + return user, nil } diff --git a/server/db/providers/mongodb/verification_requests.go b/server/db/providers/mongodb/verification_requests.go index ff6f908..532d8c8 100644 --- a/server/db/providers/mongodb/verification_requests.go +++ b/server/db/providers/mongodb/verification_requests.go @@ -12,7 +12,7 @@ import ( ) // AddVerification to save verification request in database -func (p *provider) AddVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) (models.VerificationRequest, error) { +func (p *provider) AddVerificationRequest(ctx context.Context, verificationRequest *models.VerificationRequest) (*models.VerificationRequest, error) { if verificationRequest.ID == "" { verificationRequest.ID = uuid.New().String() @@ -30,8 +30,8 @@ func (p *provider) AddVerificationRequest(ctx context.Context, verificationReque } // GetVerificationRequestByToken to get verification request from database using token -func (p *provider) GetVerificationRequestByToken(ctx context.Context, token string) (models.VerificationRequest, error) { - var verificationRequest models.VerificationRequest +func (p *provider) GetVerificationRequestByToken(ctx context.Context, token string) (*models.VerificationRequest, error) { + var verificationRequest *models.VerificationRequest verificationRequestCollection := p.db.Collection(models.Collections.VerificationRequest, options.Collection()) err := verificationRequestCollection.FindOne(ctx, bson.M{"token": token}).Decode(&verificationRequest) @@ -43,8 +43,8 @@ func (p *provider) GetVerificationRequestByToken(ctx context.Context, token stri } // GetVerificationRequestByEmail to get verification request by email from database -func (p *provider) GetVerificationRequestByEmail(ctx context.Context, email string, identifier string) (models.VerificationRequest, error) { - var verificationRequest models.VerificationRequest +func (p *provider) GetVerificationRequestByEmail(ctx context.Context, email string, identifier string) (*models.VerificationRequest, error) { + var verificationRequest *models.VerificationRequest verificationRequestCollection := p.db.Collection(models.Collections.VerificationRequest, options.Collection()) err := verificationRequestCollection.FindOne(ctx, bson.M{"email": email, "identifier": identifier}).Decode(&verificationRequest) @@ -56,7 +56,7 @@ func (p *provider) GetVerificationRequestByEmail(ctx context.Context, email stri } // ListVerificationRequests to get list of verification requests from database -func (p *provider) ListVerificationRequests(ctx context.Context, pagination model.Pagination) (*model.VerificationRequests, error) { +func (p *provider) ListVerificationRequests(ctx context.Context, pagination *model.Pagination) (*model.VerificationRequests, error) { var verificationRequests []*model.VerificationRequest opts := options.Find() @@ -77,7 +77,7 @@ func (p *provider) ListVerificationRequests(ctx context.Context, pagination mode defer cursor.Close(ctx) for cursor.Next(ctx) { - var verificationRequest models.VerificationRequest + var verificationRequest *models.VerificationRequest err := cursor.Decode(&verificationRequest) if err != nil { return nil, err @@ -87,12 +87,12 @@ func (p *provider) ListVerificationRequests(ctx context.Context, pagination mode return &model.VerificationRequests{ VerificationRequests: verificationRequests, - Pagination: &paginationClone, + Pagination: paginationClone, }, nil } // DeleteVerificationRequest to delete verification request from database -func (p *provider) DeleteVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) error { +func (p *provider) DeleteVerificationRequest(ctx context.Context, verificationRequest *models.VerificationRequest) error { verificationRequestCollection := p.db.Collection(models.Collections.VerificationRequest, options.Collection()) _, err := verificationRequestCollection.DeleteOne(ctx, bson.M{"_id": verificationRequest.ID}, options.Delete()) if err != nil { diff --git a/server/db/providers/mongodb/webhook.go b/server/db/providers/mongodb/webhook.go index 843aec9..ef6b382 100644 --- a/server/db/providers/mongodb/webhook.go +++ b/server/db/providers/mongodb/webhook.go @@ -14,7 +14,7 @@ import ( ) // AddWebhook to add webhook -func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) { +func (p *provider) AddWebhook(ctx context.Context, webhook *models.Webhook) (*model.Webhook, error) { if webhook.ID == "" { webhook.ID = uuid.New().String() } @@ -32,7 +32,7 @@ func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*mod } // UpdateWebhook to update webhook -func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) { +func (p *provider) UpdateWebhook(ctx context.Context, webhook *models.Webhook) (*model.Webhook, error) { webhook.UpdatedAt = time.Now().Unix() // Event is changed if !strings.Contains(webhook.EventName, "-") { @@ -47,7 +47,7 @@ func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (* } // ListWebhooks to list webhook -func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) (*model.Webhooks, error) { +func (p *provider) ListWebhook(ctx context.Context, pagination *model.Pagination) (*model.Webhooks, error) { webhooks := []*model.Webhook{} opts := options.Find() opts.SetLimit(pagination.Limit) @@ -66,7 +66,7 @@ func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) } defer cursor.Close(ctx) for cursor.Next(ctx) { - var webhook models.Webhook + var webhook *models.Webhook err := cursor.Decode(&webhook) if err != nil { return nil, err @@ -74,14 +74,14 @@ func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) webhooks = append(webhooks, webhook.AsAPIWebhook()) } return &model.Webhooks{ - Pagination: &paginationClone, + Pagination: paginationClone, Webhooks: webhooks, }, nil } // GetWebhookByID to get webhook by id func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model.Webhook, error) { - var webhook models.Webhook + var webhook *models.Webhook webhookCollection := p.db.Collection(models.Collections.Webhook, options.Collection()) err := webhookCollection.FindOne(ctx, bson.M{"_id": webhookID}).Decode(&webhook) if err != nil { @@ -104,7 +104,7 @@ func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) } defer cursor.Close(ctx) for cursor.Next(ctx) { - var webhook models.Webhook + var webhook *models.Webhook err := cursor.Decode(&webhook) if err != nil { return nil, err diff --git a/server/db/providers/mongodb/webhook_log.go b/server/db/providers/mongodb/webhook_log.go index 6e1081c..0c464d8 100644 --- a/server/db/providers/mongodb/webhook_log.go +++ b/server/db/providers/mongodb/webhook_log.go @@ -12,7 +12,7 @@ import ( ) // AddWebhookLog to add webhook log -func (p *provider) AddWebhookLog(ctx context.Context, webhookLog models.WebhookLog) (*model.WebhookLog, error) { +func (p *provider) AddWebhookLog(ctx context.Context, webhookLog *models.WebhookLog) (*model.WebhookLog, error) { if webhookLog.ID == "" { webhookLog.ID = uuid.New().String() } @@ -30,7 +30,7 @@ func (p *provider) AddWebhookLog(ctx context.Context, webhookLog models.WebhookL } // ListWebhookLogs to list webhook logs -func (p *provider) ListWebhookLogs(ctx context.Context, pagination model.Pagination, webhookID string) (*model.WebhookLogs, error) { +func (p *provider) ListWebhookLogs(ctx context.Context, pagination *model.Pagination, webhookID string) (*model.WebhookLogs, error) { webhookLogs := []*model.WebhookLog{} opts := options.Find() opts.SetLimit(pagination.Limit) @@ -59,7 +59,7 @@ func (p *provider) ListWebhookLogs(ctx context.Context, pagination model.Paginat defer cursor.Close(ctx) for cursor.Next(ctx) { - var webhookLog models.WebhookLog + var webhookLog *models.WebhookLog err := cursor.Decode(&webhookLog) if err != nil { return nil, err @@ -68,7 +68,7 @@ func (p *provider) ListWebhookLogs(ctx context.Context, pagination model.Paginat } return &model.WebhookLogs{ - Pagination: &paginationClone, + Pagination: paginationClone, WebhookLogs: webhookLogs, }, nil } diff --git a/server/db/providers/provider_template/email_template.go b/server/db/providers/provider_template/email_template.go index e6a1f50..a306479 100644 --- a/server/db/providers/provider_template/email_template.go +++ b/server/db/providers/provider_template/email_template.go @@ -10,7 +10,7 @@ import ( ) // AddEmailTemplate to add EmailTemplate -func (p *provider) AddEmailTemplate(ctx context.Context, emailTemplate models.EmailTemplate) (*model.EmailTemplate, error) { +func (p *provider) AddEmailTemplate(ctx context.Context, emailTemplate *models.EmailTemplate) (*model.EmailTemplate, error) { if emailTemplate.ID == "" { emailTemplate.ID = uuid.New().String() } @@ -22,13 +22,13 @@ func (p *provider) AddEmailTemplate(ctx context.Context, emailTemplate models.Em } // UpdateEmailTemplate to update EmailTemplate -func (p *provider) UpdateEmailTemplate(ctx context.Context, emailTemplate models.EmailTemplate) (*model.EmailTemplate, error) { +func (p *provider) UpdateEmailTemplate(ctx context.Context, emailTemplate *models.EmailTemplate) (*model.EmailTemplate, error) { emailTemplate.UpdatedAt = time.Now().Unix() return emailTemplate.AsAPIEmailTemplate(), nil } // ListEmailTemplates to list EmailTemplate -func (p *provider) ListEmailTemplate(ctx context.Context, pagination model.Pagination) (*model.EmailTemplates, error) { +func (p *provider) ListEmailTemplate(ctx context.Context, pagination *model.Pagination) (*model.EmailTemplates, error) { return nil, nil } diff --git a/server/db/providers/provider_template/env.go b/server/db/providers/provider_template/env.go index af232e8..823d4e3 100644 --- a/server/db/providers/provider_template/env.go +++ b/server/db/providers/provider_template/env.go @@ -9,7 +9,7 @@ import ( ) // AddEnv to save environment information in database -func (p *provider) AddEnv(ctx context.Context, env models.Env) (models.Env, error) { +func (p *provider) AddEnv(ctx context.Context, env *models.Env) (*models.Env, error) { if env.ID == "" { env.ID = uuid.New().String() } @@ -20,14 +20,14 @@ func (p *provider) AddEnv(ctx context.Context, env models.Env) (models.Env, erro } // UpdateEnv to update environment information in database -func (p *provider) UpdateEnv(ctx context.Context, env models.Env) (models.Env, error) { +func (p *provider) UpdateEnv(ctx context.Context, env *models.Env) (*models.Env, error) { env.UpdatedAt = time.Now().Unix() return env, nil } // GetEnv to get environment information from database -func (p *provider) GetEnv(ctx context.Context) (models.Env, error) { - var env models.Env +func (p *provider) GetEnv(ctx context.Context) (*models.Env, error) { + var env *models.Env return env, nil } diff --git a/server/db/providers/provider_template/otp.go b/server/db/providers/provider_template/otp.go index d8685e7..0716711 100644 --- a/server/db/providers/provider_template/otp.go +++ b/server/db/providers/provider_template/otp.go @@ -16,6 +16,11 @@ func (p *provider) GetOTPByEmail(ctx context.Context, emailAddress string) (*mod return nil, nil } +// GetOTPByPhoneNumber to get otp for a given phone number +func (p *provider) GetOTPByPhoneNumber(ctx context.Context, phoneNumber string) (*models.OTP, error) { + return nil, nil +} + // DeleteOTP to delete otp func (p *provider) DeleteOTP(ctx context.Context, otp *models.OTP) error { return nil diff --git a/server/db/providers/provider_template/session.go b/server/db/providers/provider_template/session.go index c6f45ec..e398e8c 100644 --- a/server/db/providers/provider_template/session.go +++ b/server/db/providers/provider_template/session.go @@ -9,11 +9,10 @@ import ( ) // AddSession to save session information in database -func (p *provider) AddSession(ctx context.Context, session models.Session) error { +func (p *provider) AddSession(ctx context.Context, session *models.Session) error { if session.ID == "" { session.ID = uuid.New().String() } - session.CreatedAt = time.Now().Unix() session.UpdatedAt = time.Now().Unix() return nil diff --git a/server/db/providers/provider_template/user.go b/server/db/providers/provider_template/user.go index 286e74d..9f4c7f8 100644 --- a/server/db/providers/provider_template/user.go +++ b/server/db/providers/provider_template/user.go @@ -12,11 +12,10 @@ import ( ) // AddUser to save user information in database -func (p *provider) AddUser(ctx context.Context, user models.User) (models.User, error) { +func (p *provider) AddUser(ctx context.Context, user *models.User) (*models.User, error) { if user.ID == "" { user.ID = uuid.New().String() } - if user.Roles == "" { defaultRoles, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyDefaultRoles) if err != nil { @@ -24,40 +23,36 @@ func (p *provider) AddUser(ctx context.Context, user models.User) (models.User, } user.Roles = defaultRoles } - user.CreatedAt = time.Now().Unix() user.UpdatedAt = time.Now().Unix() - return user, nil } // UpdateUser to update user information in database -func (p *provider) UpdateUser(ctx context.Context, user models.User) (models.User, error) { +func (p *provider) UpdateUser(ctx context.Context, user *models.User) (*models.User, error) { user.UpdatedAt = time.Now().Unix() return user, nil } // DeleteUser to delete user information from database -func (p *provider) DeleteUser(ctx context.Context, user models.User) error { +func (p *provider) DeleteUser(ctx context.Context, user *models.User) error { return nil } // ListUsers to get list of users from database -func (p *provider) ListUsers(ctx context.Context, pagination model.Pagination) (*model.Users, error) { +func (p *provider) ListUsers(ctx context.Context, pagination *model.Pagination) (*model.Users, error) { return nil, nil } // GetUserByEmail to get user information from database using email address -func (p *provider) GetUserByEmail(ctx context.Context, email string) (models.User, error) { - var user models.User - +func (p *provider) GetUserByEmail(ctx context.Context, email string) (*models.User, error) { + var user *models.User return user, nil } // GetUserByID to get user information from database using user ID -func (p *provider) GetUserByID(ctx context.Context, id string) (models.User, error) { - var user models.User - +func (p *provider) GetUserByID(ctx context.Context, id string) (*models.User, error) { + var user *models.User return user, nil } @@ -66,13 +61,11 @@ func (p *provider) GetUserByID(ctx context.Context, id string) (models.User, err func (p *provider) UpdateUsers(ctx context.Context, data map[string]interface{}, ids []string) error { // set updated_at time for all users data["updated_at"] = time.Now().Unix() - return nil } // GetUserByPhoneNumber to get user information from database using phone number func (p *provider) GetUserByPhoneNumber(ctx context.Context, phoneNumber string) (*models.User, error) { var user *models.User - return user, nil } diff --git a/server/db/providers/provider_template/verification_requests.go b/server/db/providers/provider_template/verification_requests.go index 577d2f6..c3a7f18 100644 --- a/server/db/providers/provider_template/verification_requests.go +++ b/server/db/providers/provider_template/verification_requests.go @@ -10,7 +10,7 @@ import ( ) // AddVerification to save verification request in database -func (p *provider) AddVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) (models.VerificationRequest, error) { +func (p *provider) AddVerificationRequest(ctx context.Context, verificationRequest *models.VerificationRequest) (*models.VerificationRequest, error) { if verificationRequest.ID == "" { verificationRequest.ID = uuid.New().String() } @@ -22,25 +22,25 @@ func (p *provider) AddVerificationRequest(ctx context.Context, verificationReque } // GetVerificationRequestByToken to get verification request from database using token -func (p *provider) GetVerificationRequestByToken(ctx context.Context, token string) (models.VerificationRequest, error) { - var verificationRequest models.VerificationRequest +func (p *provider) GetVerificationRequestByToken(ctx context.Context, token string) (*models.VerificationRequest, error) { + var verificationRequest *models.VerificationRequest return verificationRequest, nil } // GetVerificationRequestByEmail to get verification request by email from database -func (p *provider) GetVerificationRequestByEmail(ctx context.Context, email string, identifier string) (models.VerificationRequest, error) { - var verificationRequest models.VerificationRequest +func (p *provider) GetVerificationRequestByEmail(ctx context.Context, email string, identifier string) (*models.VerificationRequest, error) { + var verificationRequest *models.VerificationRequest return verificationRequest, nil } // ListVerificationRequests to get list of verification requests from database -func (p *provider) ListVerificationRequests(ctx context.Context, pagination model.Pagination) (*model.VerificationRequests, error) { +func (p *provider) ListVerificationRequests(ctx context.Context, pagination *model.Pagination) (*model.VerificationRequests, error) { return nil, nil } // DeleteVerificationRequest to delete verification request from database -func (p *provider) DeleteVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) error { +func (p *provider) DeleteVerificationRequest(ctx context.Context, verificationRequest *models.VerificationRequest) error { return nil } diff --git a/server/db/providers/provider_template/webhook.go b/server/db/providers/provider_template/webhook.go index faf18fa..cf0edbe 100644 --- a/server/db/providers/provider_template/webhook.go +++ b/server/db/providers/provider_template/webhook.go @@ -12,7 +12,7 @@ import ( ) // AddWebhook to add webhook -func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) { +func (p *provider) AddWebhook(ctx context.Context, webhook *models.Webhook) (*model.Webhook, error) { if webhook.ID == "" { webhook.ID = uuid.New().String() } @@ -25,7 +25,7 @@ func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*mod } // UpdateWebhook to update webhook -func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) { +func (p *provider) UpdateWebhook(ctx context.Context, webhook *models.Webhook) (*model.Webhook, error) { webhook.UpdatedAt = time.Now().Unix() // Event is changed if !strings.Contains(webhook.EventName, "-") { @@ -35,7 +35,7 @@ func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (* } // ListWebhooks to list webhook -func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) (*model.Webhooks, error) { +func (p *provider) ListWebhook(ctx context.Context, pagination *model.Pagination) (*model.Webhooks, error) { return nil, nil } diff --git a/server/db/providers/provider_template/webhook_log.go b/server/db/providers/provider_template/webhook_log.go index 9814bc3..9ad81d2 100644 --- a/server/db/providers/provider_template/webhook_log.go +++ b/server/db/providers/provider_template/webhook_log.go @@ -10,7 +10,7 @@ import ( ) // AddWebhookLog to add webhook log -func (p *provider) AddWebhookLog(ctx context.Context, webhookLog models.WebhookLog) (*model.WebhookLog, error) { +func (p *provider) AddWebhookLog(ctx context.Context, webhookLog *models.WebhookLog) (*model.WebhookLog, error) { if webhookLog.ID == "" { webhookLog.ID = uuid.New().String() } @@ -22,6 +22,6 @@ func (p *provider) AddWebhookLog(ctx context.Context, webhookLog models.WebhookL } // ListWebhookLogs to list webhook logs -func (p *provider) ListWebhookLogs(ctx context.Context, pagination model.Pagination, webhookID string) (*model.WebhookLogs, error) { +func (p *provider) ListWebhookLogs(ctx context.Context, pagination *model.Pagination, webhookID string) (*model.WebhookLogs, error) { return nil, nil } diff --git a/server/db/providers/providers.go b/server/db/providers/providers.go index 28fbf78..65b9010 100644 --- a/server/db/providers/providers.go +++ b/server/db/providers/providers.go @@ -9,50 +9,52 @@ import ( type Provider interface { // AddUser to save user information in database - AddUser(ctx context.Context, user models.User) (models.User, error) + AddUser(ctx context.Context, user *models.User) (*models.User, error) // UpdateUser to update user information in database - UpdateUser(ctx context.Context, user models.User) (models.User, error) + UpdateUser(ctx context.Context, user *models.User) (*models.User, error) // DeleteUser to delete user information from database - DeleteUser(ctx context.Context, user models.User) error + DeleteUser(ctx context.Context, user *models.User) error // ListUsers to get list of users from database - ListUsers(ctx context.Context, pagination model.Pagination) (*model.Users, error) + ListUsers(ctx context.Context, pagination *model.Pagination) (*model.Users, error) // GetUserByEmail to get user information from database using email address - GetUserByEmail(ctx context.Context, email string) (models.User, error) + GetUserByEmail(ctx context.Context, email string) (*models.User, error) // GetUserByPhoneNumber to get user information from database using phone number GetUserByPhoneNumber(ctx context.Context, phoneNumber string) (*models.User, error) // GetUserByID to get user information from database using user ID - GetUserByID(ctx context.Context, id string) (models.User, error) + GetUserByID(ctx context.Context, id string) (*models.User, error) // UpdateUsers to update multiple users, with parameters of user IDs slice // If ids set to nil / empty all the users will be updated UpdateUsers(ctx context.Context, data map[string]interface{}, ids []string) error // AddVerification to save verification request in database - AddVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) (models.VerificationRequest, error) + AddVerificationRequest(ctx context.Context, verificationRequest *models.VerificationRequest) (*models.VerificationRequest, error) // GetVerificationRequestByToken to get verification request from database using token - GetVerificationRequestByToken(ctx context.Context, token string) (models.VerificationRequest, error) + GetVerificationRequestByToken(ctx context.Context, token string) (*models.VerificationRequest, error) // GetVerificationRequestByEmail to get verification request by email from database - GetVerificationRequestByEmail(ctx context.Context, email string, identifier string) (models.VerificationRequest, error) + GetVerificationRequestByEmail(ctx context.Context, email string, identifier string) (*models.VerificationRequest, error) // ListVerificationRequests to get list of verification requests from database - ListVerificationRequests(ctx context.Context, pagination model.Pagination) (*model.VerificationRequests, error) + ListVerificationRequests(ctx context.Context, pagination *model.Pagination) (*model.VerificationRequests, error) // DeleteVerificationRequest to delete verification request from database - DeleteVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) error + DeleteVerificationRequest(ctx context.Context, verificationRequest *models.VerificationRequest) error // AddSession to save session information in database - AddSession(ctx context.Context, session models.Session) error + AddSession(ctx context.Context, session *models.Session) error + // DeleteSession to delete session information from database + DeleteSession(ctx context.Context, userId string) error // AddEnv to save environment information in database - AddEnv(ctx context.Context, env models.Env) (models.Env, error) + AddEnv(ctx context.Context, env *models.Env) (*models.Env, error) // UpdateEnv to update environment information in database - UpdateEnv(ctx context.Context, env models.Env) (models.Env, error) + UpdateEnv(ctx context.Context, env *models.Env) (*models.Env, error) // GetEnv to get environment information from database - GetEnv(ctx context.Context) (models.Env, error) + GetEnv(ctx context.Context) (*models.Env, error) // AddWebhook to add webhook - AddWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) + AddWebhook(ctx context.Context, webhook *models.Webhook) (*model.Webhook, error) // UpdateWebhook to update webhook - UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) + UpdateWebhook(ctx context.Context, webhook *models.Webhook) (*model.Webhook, error) // ListWebhooks to list webhook - ListWebhook(ctx context.Context, pagination model.Pagination) (*model.Webhooks, error) + ListWebhook(ctx context.Context, pagination *model.Pagination) (*model.Webhooks, error) // GetWebhookByID to get webhook by id GetWebhookByID(ctx context.Context, webhookID string) (*model.Webhook, error) // GetWebhookByEventName to get webhook by event_name @@ -61,16 +63,16 @@ type Provider interface { DeleteWebhook(ctx context.Context, webhook *model.Webhook) error // AddWebhookLog to add webhook log - AddWebhookLog(ctx context.Context, webhookLog models.WebhookLog) (*model.WebhookLog, error) + AddWebhookLog(ctx context.Context, webhookLog *models.WebhookLog) (*model.WebhookLog, error) // ListWebhookLogs to list webhook logs - ListWebhookLogs(ctx context.Context, pagination model.Pagination, webhookID string) (*model.WebhookLogs, error) + ListWebhookLogs(ctx context.Context, pagination *model.Pagination, webhookID string) (*model.WebhookLogs, error) // AddEmailTemplate to add EmailTemplate - AddEmailTemplate(ctx context.Context, emailTemplate models.EmailTemplate) (*model.EmailTemplate, error) + AddEmailTemplate(ctx context.Context, emailTemplate *models.EmailTemplate) (*model.EmailTemplate, error) // UpdateEmailTemplate to update EmailTemplate - UpdateEmailTemplate(ctx context.Context, emailTemplate models.EmailTemplate) (*model.EmailTemplate, error) + UpdateEmailTemplate(ctx context.Context, emailTemplate *models.EmailTemplate) (*model.EmailTemplate, error) // ListEmailTemplates to list EmailTemplate - ListEmailTemplate(ctx context.Context, pagination model.Pagination) (*model.EmailTemplates, error) + ListEmailTemplate(ctx context.Context, pagination *model.Pagination) (*model.EmailTemplates, error) // GetEmailTemplateByID to get EmailTemplate by id GetEmailTemplateByID(ctx context.Context, emailTemplateID string) (*model.EmailTemplate, error) // GetEmailTemplateByEventName to get EmailTemplate by event_name @@ -82,13 +84,8 @@ type Provider interface { UpsertOTP(ctx context.Context, otp *models.OTP) (*models.OTP, error) // GetOTPByEmail to get otp for a given email address GetOTPByEmail(ctx context.Context, emailAddress string) (*models.OTP, error) + // GetOTPByPhoneNumber to get otp for a given phone number + GetOTPByPhoneNumber(ctx context.Context, phoneNumber string) (*models.OTP, error) // DeleteOTP to delete otp DeleteOTP(ctx context.Context, otp *models.OTP) error - - // Upsert SMS code request - UpsertSMSRequest(ctx context.Context, smsRequest *models.SMSVerificationRequest) (*models.SMSVerificationRequest, error) - // Get sms code by phone number - GetCodeByPhone(ctx context.Context, phoneNumber string) (*models.SMSVerificationRequest, error) - // Delete sms - DeleteSMSRequest(ctx context.Context, smsRequest *models.SMSVerificationRequest) error } diff --git a/server/db/providers/sql/email_template.go b/server/db/providers/sql/email_template.go index 1a8e0d2..8928b6f 100644 --- a/server/db/providers/sql/email_template.go +++ b/server/db/providers/sql/email_template.go @@ -10,7 +10,7 @@ import ( ) // AddEmailTemplate to add EmailTemplate -func (p *provider) AddEmailTemplate(ctx context.Context, emailTemplate models.EmailTemplate) (*model.EmailTemplate, error) { +func (p *provider) AddEmailTemplate(ctx context.Context, emailTemplate *models.EmailTemplate) (*model.EmailTemplate, error) { if emailTemplate.ID == "" { emailTemplate.ID = uuid.New().String() } @@ -27,7 +27,7 @@ func (p *provider) AddEmailTemplate(ctx context.Context, emailTemplate models.Em } // UpdateEmailTemplate to update EmailTemplate -func (p *provider) UpdateEmailTemplate(ctx context.Context, emailTemplate models.EmailTemplate) (*model.EmailTemplate, error) { +func (p *provider) UpdateEmailTemplate(ctx context.Context, emailTemplate *models.EmailTemplate) (*model.EmailTemplate, error) { emailTemplate.UpdatedAt = time.Now().Unix() res := p.db.Save(&emailTemplate) @@ -38,9 +38,8 @@ func (p *provider) UpdateEmailTemplate(ctx context.Context, emailTemplate models } // ListEmailTemplates to list EmailTemplate -func (p *provider) ListEmailTemplate(ctx context.Context, pagination model.Pagination) (*model.EmailTemplates, error) { - var emailTemplates []models.EmailTemplate - +func (p *provider) ListEmailTemplate(ctx context.Context, pagination *model.Pagination) (*model.EmailTemplates, error) { + var emailTemplates []*models.EmailTemplate result := p.db.Limit(int(pagination.Limit)).Offset(int(pagination.Offset)).Order("created_at DESC").Find(&emailTemplates) if result.Error != nil { return nil, result.Error @@ -60,14 +59,14 @@ func (p *provider) ListEmailTemplate(ctx context.Context, pagination model.Pagin responseEmailTemplates = append(responseEmailTemplates, w.AsAPIEmailTemplate()) } return &model.EmailTemplates{ - Pagination: &paginationClone, + Pagination: paginationClone, EmailTemplates: responseEmailTemplates, }, nil } // GetEmailTemplateByID to get EmailTemplate by id func (p *provider) GetEmailTemplateByID(ctx context.Context, emailTemplateID string) (*model.EmailTemplate, error) { - var emailTemplate models.EmailTemplate + var emailTemplate *models.EmailTemplate result := p.db.Where("id = ?", emailTemplateID).First(&emailTemplate) if result.Error != nil { @@ -78,7 +77,7 @@ func (p *provider) GetEmailTemplateByID(ctx context.Context, emailTemplateID str // GetEmailTemplateByEventName to get EmailTemplate by event_name func (p *provider) GetEmailTemplateByEventName(ctx context.Context, eventName string) (*model.EmailTemplate, error) { - var emailTemplate models.EmailTemplate + var emailTemplate *models.EmailTemplate result := p.db.Where("event_name = ?", eventName).First(&emailTemplate) if result.Error != nil { @@ -95,6 +94,5 @@ func (p *provider) DeleteEmailTemplate(ctx context.Context, emailTemplate *model if result.Error != nil { return result.Error } - return nil } diff --git a/server/db/providers/sql/env.go b/server/db/providers/sql/env.go index 1f34c38..11584a0 100644 --- a/server/db/providers/sql/env.go +++ b/server/db/providers/sql/env.go @@ -9,7 +9,7 @@ import ( ) // AddEnv to save environment information in database -func (p *provider) AddEnv(ctx context.Context, env models.Env) (models.Env, error) { +func (p *provider) AddEnv(ctx context.Context, env *models.Env) (*models.Env, error) { if env.ID == "" { env.ID = uuid.New().String() } @@ -26,10 +26,9 @@ func (p *provider) AddEnv(ctx context.Context, env models.Env) (models.Env, erro } // UpdateEnv to update environment information in database -func (p *provider) UpdateEnv(ctx context.Context, env models.Env) (models.Env, error) { +func (p *provider) UpdateEnv(ctx context.Context, env *models.Env) (*models.Env, error) { env.UpdatedAt = time.Now().Unix() result := p.db.Save(&env) - if result.Error != nil { return env, result.Error } @@ -37,13 +36,11 @@ func (p *provider) UpdateEnv(ctx context.Context, env models.Env) (models.Env, e } // GetEnv to get environment information from database -func (p *provider) GetEnv(ctx context.Context) (models.Env, error) { - var env models.Env +func (p *provider) GetEnv(ctx context.Context) (*models.Env, error) { + var env *models.Env result := p.db.First(&env) - if result.Error != nil { return env, result.Error } - return env, nil } diff --git a/server/db/providers/sql/otp.go b/server/db/providers/sql/otp.go index 9aabcab..5503a7d 100644 --- a/server/db/providers/sql/otp.go +++ b/server/db/providers/sql/otp.go @@ -2,6 +2,7 @@ package sql import ( "context" + "errors" "time" "github.com/authorizerdev/authorizer/server/db/models" @@ -14,13 +15,19 @@ func (p *provider) UpsertOTP(ctx context.Context, otp *models.OTP) (*models.OTP, if otp.ID == "" { otp.ID = uuid.New().String() } - + // check if email or phone number is present + if otp.Email == "" && otp.PhoneNumber == "" { + return nil, errors.New("email or phone_number is required") + } + uniqueField := models.FieldNameEmail + if otp.Email == "" && otp.PhoneNumber != "" { + uniqueField = models.FieldNamePhoneNumber + } otp.Key = otp.ID otp.CreatedAt = time.Now().Unix() otp.UpdatedAt = time.Now().Unix() - res := p.db.Clauses(clause.OnConflict{ - Columns: []clause.Column{{Name: "email"}}, + Columns: []clause.Column{{Name: uniqueField}}, DoUpdates: clause.AssignmentColumns([]string{"otp", "expires_at", "updated_at"}), }).Create(&otp) if res.Error != nil { @@ -33,7 +40,6 @@ func (p *provider) UpsertOTP(ctx context.Context, otp *models.OTP) (*models.OTP, // GetOTPByEmail to get otp for a given email address func (p *provider) GetOTPByEmail(ctx context.Context, emailAddress string) (*models.OTP, error) { var otp models.OTP - result := p.db.Where("email = ?", emailAddress).First(&otp) if result.Error != nil { return nil, result.Error @@ -41,6 +47,16 @@ func (p *provider) GetOTPByEmail(ctx context.Context, emailAddress string) (*mod return &otp, nil } +// GetOTPByPhoneNumber to get otp for a given phone number +func (p *provider) GetOTPByPhoneNumber(ctx context.Context, phoneNumber string) (*models.OTP, error) { + var otp models.OTP + result := p.db.Where("phone_number = ?", phoneNumber).First(&otp) + if result.Error != nil { + return nil, result.Error + } + return &otp, nil +} + // DeleteOTP to delete otp func (p *provider) DeleteOTP(ctx context.Context, otp *models.OTP) error { result := p.db.Delete(&models.OTP{ diff --git a/server/db/providers/sql/provider.go b/server/db/providers/sql/provider.go index 89ea31b..0512ecf 100644 --- a/server/db/providers/sql/provider.go +++ b/server/db/providers/sql/provider.go @@ -77,7 +77,7 @@ func NewProvider() (*provider, error) { logrus.Debug("Failed to drop phone number constraint:", err) } - err = sqlDB.AutoMigrate(&models.User{}, &models.VerificationRequest{}, &models.Session{}, &models.Env{}, &models.Webhook{}, models.WebhookLog{}, models.EmailTemplate{}, &models.OTP{}, &models.SMSVerificationRequest{}) + err = sqlDB.AutoMigrate(&models.User{}, &models.VerificationRequest{}, &models.Session{}, &models.Env{}, &models.Webhook{}, &models.WebhookLog{}, &models.EmailTemplate{}, &models.OTP{}) if err != nil { return nil, err } diff --git a/server/db/providers/sql/session.go b/server/db/providers/sql/session.go index 0ed7317..a7e3e13 100644 --- a/server/db/providers/sql/session.go +++ b/server/db/providers/sql/session.go @@ -10,7 +10,7 @@ import ( ) // AddSession to save session information in database -func (p *provider) AddSession(ctx context.Context, session models.Session) error { +func (p *provider) AddSession(ctx context.Context, session *models.Session) error { if session.ID == "" { session.ID = uuid.New().String() } @@ -27,3 +27,8 @@ func (p *provider) AddSession(ctx context.Context, session models.Session) error } return nil } + +// DeleteSession to delete session information from database +func (p *provider) DeleteSession(ctx context.Context, userId string) error { + return nil +} diff --git a/server/db/providers/sql/sms_verification_requests.go b/server/db/providers/sql/sms_verification_requests.go deleted file mode 100644 index 5035c54..0000000 --- a/server/db/providers/sql/sms_verification_requests.go +++ /dev/null @@ -1,51 +0,0 @@ -package sql - -import ( - "context" - "time" - - "github.com/authorizerdev/authorizer/server/db/models" - "github.com/google/uuid" - "gorm.io/gorm/clause" -) - -// SMS verification Request -func (p *provider) UpsertSMSRequest(ctx context.Context, smsRequest *models.SMSVerificationRequest) (*models.SMSVerificationRequest, error) { - if smsRequest.ID == "" { - smsRequest.ID = uuid.New().String() - } - - smsRequest.CreatedAt = time.Now().Unix() - smsRequest.UpdatedAt = time.Now().Unix() - - res := p.db.Clauses(clause.OnConflict{ - Columns: []clause.Column{{Name: "phone_number"}}, - DoUpdates: clause.AssignmentColumns([]string{"code", "code_expires_at"}), - }).Create(smsRequest) - if res.Error != nil { - return nil, res.Error - } - - return smsRequest, nil -} - -// GetOTPByEmail to get otp for a given email address -func (p *provider) GetCodeByPhone(ctx context.Context, phoneNumber string) (*models.SMSVerificationRequest, error) { - var sms_verification_request models.SMSVerificationRequest - - result := p.db.Where("phone_number = ?", phoneNumber).First(&sms_verification_request) - if result.Error != nil { - return nil, result.Error - } - return &sms_verification_request, nil -} - -func(p *provider) DeleteSMSRequest(ctx context.Context, smsRequest *models.SMSVerificationRequest) error { - result := p.db.Delete(&models.SMSVerificationRequest{ - ID: smsRequest.ID, - }) - if result.Error != nil { - return result.Error - } - return nil -} diff --git a/server/db/providers/sql/user.go b/server/db/providers/sql/user.go index a4b40c0..5243ad6 100644 --- a/server/db/providers/sql/user.go +++ b/server/db/providers/sql/user.go @@ -17,7 +17,7 @@ import ( ) // AddUser to save user information in database -func (p *provider) AddUser(ctx context.Context, user models.User) (models.User, error) { +func (p *provider) AddUser(ctx context.Context, user *models.User) (*models.User, error) { if user.ID == "" { user.ID = uuid.New().String() } @@ -53,7 +53,7 @@ func (p *provider) AddUser(ctx context.Context, user models.User) (models.User, } // UpdateUser to update user information in database -func (p *provider) UpdateUser(ctx context.Context, user models.User) (models.User, error) { +func (p *provider) UpdateUser(ctx context.Context, user *models.User) (*models.User, error) { user.UpdatedAt = time.Now().Unix() result := p.db.Save(&user) @@ -66,7 +66,7 @@ func (p *provider) UpdateUser(ctx context.Context, user models.User) (models.Use } // DeleteUser to delete user information from database -func (p *provider) DeleteUser(ctx context.Context, user models.User) error { +func (p *provider) DeleteUser(ctx context.Context, user *models.User) error { result := p.db.Where("user_id = ?", user.ID).Delete(&models.Session{}) if result.Error != nil { return result.Error @@ -81,7 +81,7 @@ func (p *provider) DeleteUser(ctx context.Context, user models.User) error { } // ListUsers to get list of users from database -func (p *provider) ListUsers(ctx context.Context, pagination model.Pagination) (*model.Users, error) { +func (p *provider) ListUsers(ctx context.Context, pagination *model.Pagination) (*model.Users, error) { var users []models.User result := p.db.Limit(int(pagination.Limit)).Offset(int(pagination.Offset)).Order("created_at DESC").Find(&users) if result.Error != nil { @@ -103,31 +103,28 @@ func (p *provider) ListUsers(ctx context.Context, pagination model.Pagination) ( paginationClone.Total = total return &model.Users{ - Pagination: &paginationClone, + Pagination: paginationClone, Users: responseUsers, }, nil } // GetUserByEmail to get user information from database using email address -func (p *provider) GetUserByEmail(ctx context.Context, email string) (models.User, error) { - var user models.User +func (p *provider) GetUserByEmail(ctx context.Context, email string) (*models.User, error) { + var user *models.User result := p.db.Where("email = ?", email).First(&user) if result.Error != nil { return user, result.Error } - return user, nil } // GetUserByID to get user information from database using user ID -func (p *provider) GetUserByID(ctx context.Context, id string) (models.User, error) { - var user models.User - +func (p *provider) GetUserByID(ctx context.Context, id string) (*models.User, error) { + var user *models.User result := p.db.Where("id = ?", id).First(&user) if result.Error != nil { return user, result.Error } - return user, nil } @@ -136,14 +133,12 @@ func (p *provider) GetUserByID(ctx context.Context, id string) (models.User, err func (p *provider) UpdateUsers(ctx context.Context, data map[string]interface{}, ids []string) error { // set updated_at time for all users data["updated_at"] = time.Now().Unix() - var res *gorm.DB - if ids != nil && len(ids) > 0 { + if len(ids) > 0 { res = p.db.Model(&models.User{}).Where("id in ?", ids).Updates(data) } else { res = p.db.Model(&models.User{}).Updates(data) } - if res.Error != nil { return res.Error } @@ -154,10 +149,8 @@ func (p *provider) UpdateUsers(ctx context.Context, data map[string]interface{}, func (p *provider) GetUserByPhoneNumber(ctx context.Context, phoneNumber string) (*models.User, error) { var user *models.User result := p.db.Where("phone_number = ?", phoneNumber).First(&user) - if result.Error != nil { return nil, result.Error } - return user, nil } diff --git a/server/db/providers/sql/verification_requests.go b/server/db/providers/sql/verification_requests.go index 5b413b0..ac91bec 100644 --- a/server/db/providers/sql/verification_requests.go +++ b/server/db/providers/sql/verification_requests.go @@ -11,11 +11,10 @@ import ( ) // AddVerification to save verification request in database -func (p *provider) AddVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) (models.VerificationRequest, error) { +func (p *provider) AddVerificationRequest(ctx context.Context, verificationRequest *models.VerificationRequest) (*models.VerificationRequest, error) { if verificationRequest.ID == "" { verificationRequest.ID = uuid.New().String() } - verificationRequest.Key = verificationRequest.ID verificationRequest.CreatedAt = time.Now().Unix() verificationRequest.UpdatedAt = time.Now().Unix() @@ -23,75 +22,61 @@ func (p *provider) AddVerificationRequest(ctx context.Context, verificationReque Columns: []clause.Column{{Name: "email"}, {Name: "identifier"}}, DoUpdates: clause.AssignmentColumns([]string{"token", "expires_at", "nonce", "redirect_uri"}), }).Create(&verificationRequest) - if result.Error != nil { return verificationRequest, result.Error } - return verificationRequest, nil } // GetVerificationRequestByToken to get verification request from database using token -func (p *provider) GetVerificationRequestByToken(ctx context.Context, token string) (models.VerificationRequest, error) { - var verificationRequest models.VerificationRequest +func (p *provider) GetVerificationRequestByToken(ctx context.Context, token string) (*models.VerificationRequest, error) { + var verificationRequest *models.VerificationRequest result := p.db.Where("token = ?", token).First(&verificationRequest) - if result.Error != nil { return verificationRequest, result.Error } - return verificationRequest, nil } // GetVerificationRequestByEmail to get verification request by email from database -func (p *provider) GetVerificationRequestByEmail(ctx context.Context, email string, identifier string) (models.VerificationRequest, error) { - var verificationRequest models.VerificationRequest - +func (p *provider) GetVerificationRequestByEmail(ctx context.Context, email string, identifier string) (*models.VerificationRequest, error) { + var verificationRequest *models.VerificationRequest result := p.db.Where("email = ? AND identifier = ?", email, identifier).First(&verificationRequest) - if result.Error != nil { return verificationRequest, result.Error } - return verificationRequest, nil } // ListVerificationRequests to get list of verification requests from database -func (p *provider) ListVerificationRequests(ctx context.Context, pagination model.Pagination) (*model.VerificationRequests, error) { +func (p *provider) ListVerificationRequests(ctx context.Context, pagination *model.Pagination) (*model.VerificationRequests, error) { var verificationRequests []models.VerificationRequest - result := p.db.Limit(int(pagination.Limit)).Offset(int(pagination.Offset)).Order("created_at DESC").Find(&verificationRequests) if result.Error != nil { return nil, result.Error } - responseVerificationRequests := []*model.VerificationRequest{} for _, v := range verificationRequests { responseVerificationRequests = append(responseVerificationRequests, v.AsAPIVerificationRequest()) } - var total int64 totalRes := p.db.Model(&models.VerificationRequest{}).Count(&total) if totalRes.Error != nil { return nil, totalRes.Error } - paginationClone := pagination paginationClone.Total = total - return &model.VerificationRequests{ VerificationRequests: responseVerificationRequests, - Pagination: &paginationClone, + Pagination: paginationClone, }, nil } // DeleteVerificationRequest to delete verification request from database -func (p *provider) DeleteVerificationRequest(ctx context.Context, verificationRequest models.VerificationRequest) error { +func (p *provider) DeleteVerificationRequest(ctx context.Context, verificationRequest *models.VerificationRequest) error { result := p.db.Delete(&verificationRequest) - if result.Error != nil { return result.Error } - return nil } diff --git a/server/db/providers/sql/webhook.go b/server/db/providers/sql/webhook.go index 72f3cb4..54e2d13 100644 --- a/server/db/providers/sql/webhook.go +++ b/server/db/providers/sql/webhook.go @@ -12,7 +12,7 @@ import ( ) // AddWebhook to add webhook -func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) { +func (p *provider) AddWebhook(ctx context.Context, webhook *models.Webhook) (*model.Webhook, error) { if webhook.ID == "" { webhook.ID = uuid.New().String() } @@ -29,7 +29,7 @@ func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*mod } // UpdateWebhook to update webhook -func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) { +func (p *provider) UpdateWebhook(ctx context.Context, webhook *models.Webhook) (*model.Webhook, error) { webhook.UpdatedAt = time.Now().Unix() // Event is changed if !strings.Contains(webhook.EventName, "-") { @@ -43,7 +43,7 @@ func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (* } // ListWebhooks to list webhook -func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) (*model.Webhooks, error) { +func (p *provider) ListWebhook(ctx context.Context, pagination *model.Pagination) (*model.Webhooks, error) { var webhooks []models.Webhook result := p.db.Limit(int(pagination.Limit)).Offset(int(pagination.Offset)).Order("created_at DESC").Find(&webhooks) if result.Error != nil { @@ -61,14 +61,14 @@ func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) responseWebhooks = append(responseWebhooks, w.AsAPIWebhook()) } return &model.Webhooks{ - Pagination: &paginationClone, + Pagination: paginationClone, Webhooks: responseWebhooks, }, nil } // GetWebhookByID to get webhook by id func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model.Webhook, error) { - var webhook models.Webhook + var webhook *models.Webhook result := p.db.Where("id = ?", webhookID).First(&webhook) if result.Error != nil { diff --git a/server/db/providers/sql/webhook_log.go b/server/db/providers/sql/webhook_log.go index 0ccbca2..cf50be2 100644 --- a/server/db/providers/sql/webhook_log.go +++ b/server/db/providers/sql/webhook_log.go @@ -12,7 +12,7 @@ import ( ) // AddWebhookLog to add webhook log -func (p *provider) AddWebhookLog(ctx context.Context, webhookLog models.WebhookLog) (*model.WebhookLog, error) { +func (p *provider) AddWebhookLog(ctx context.Context, webhookLog *models.WebhookLog) (*model.WebhookLog, error) { if webhookLog.ID == "" { webhookLog.ID = uuid.New().String() } @@ -32,7 +32,7 @@ func (p *provider) AddWebhookLog(ctx context.Context, webhookLog models.WebhookL } // ListWebhookLogs to list webhook logs -func (p *provider) ListWebhookLogs(ctx context.Context, pagination model.Pagination, webhookID string) (*model.WebhookLogs, error) { +func (p *provider) ListWebhookLogs(ctx context.Context, pagination *model.Pagination, webhookID string) (*model.WebhookLogs, error) { var webhookLogs []models.WebhookLog var result *gorm.DB var totalRes *gorm.DB @@ -63,6 +63,6 @@ func (p *provider) ListWebhookLogs(ctx context.Context, pagination model.Paginat } return &model.WebhookLogs{ WebhookLogs: responseWebhookLogs, - Pagination: &paginationClone, + Pagination: paginationClone, }, nil } diff --git a/server/email/email.go b/server/email/email.go index 1b7d84d..e7222ac 100644 --- a/server/email/email.go +++ b/server/email/email.go @@ -72,7 +72,6 @@ func getEmailTemplate(event string, data map[string]interface{}) (*model.EmailTe return nil, err } subjectString := buf.String() - return &model.EmailTemplate{ Template: templateString, Subject: subjectString, diff --git a/server/env/env.go b/server/env/env.go index f61b093..a4bcfcc 100644 --- a/server/env/env.go +++ b/server/env/env.go @@ -19,7 +19,7 @@ import ( // InitEnv to initialize EnvData and through error if required env are not present func InitAllEnv() error { envData, err := GetEnvData() - if err != nil { + if err != nil || envData == nil { log.Info("No env data found in db, using local clone of env data") // get clone of current store envData, err = memorystore.Provider.GetEnvStore() @@ -104,6 +104,13 @@ func InitAllEnv() error { osDisableStrongPassword := os.Getenv(constants.EnvKeyDisableStrongPassword) osEnforceMultiFactorAuthentication := os.Getenv(constants.EnvKeyEnforceMultiFactorAuthentication) osDisableMultiFactorAuthentication := os.Getenv(constants.EnvKeyDisableMultiFactorAuthentication) + // phone verification var + osDisablePhoneVerification := os.Getenv(constants.EnvKeyDisablePhoneVerification) + // twilio vars + osTwilioApiKey := os.Getenv(constants.EnvKeyTwilioAPIKey) + osTwilioApiSecret := os.Getenv(constants.EnvKeyTwilioAPISecret) + osTwilioAccountSid := os.Getenv(constants.EnvKeyTwilioAccountSID) + osTwilioSender := os.Getenv(constants.EnvKeyTwilioSender) // os slice vars osAllowedOrigins := os.Getenv(constants.EnvKeyAllowedOrigins) @@ -111,15 +118,6 @@ func InitAllEnv() error { osDefaultRoles := os.Getenv(constants.EnvKeyDefaultRoles) osProtectedRoles := os.Getenv(constants.EnvKeyProtectedRoles) - // phone verification var - osDisablePhoneVerification := os.Getenv(constants.EnvKeyDisablePhoneVerification) - - // twilio vars - osTwilioApiKey := os.Getenv(constants.EnvKeyTwilioAPIKey) - osTwilioApiSecret := os.Getenv(constants.EnvKeyTwilioAPISecret) - osTwilioAccountSid := os.Getenv(constants.EnvKeyTwilioAccountSID) - osTwilioSenderFrom := os.Getenv(constants.EnvKeyTwilioSenderFrom) - ienv, ok := envData[constants.EnvKeyEnv] if !ok || ienv == "" { envData[constants.EnvKeyEnv] = osEnv @@ -145,7 +143,7 @@ func InitAllEnv() error { if val, ok := envData[constants.EnvAwsRegion]; !ok || val == "" { envData[constants.EnvAwsRegion] = osAwsRegion } - + if osAwsRegion != "" && envData[constants.EnvAwsRegion] != osAwsRegion { envData[constants.EnvAwsRegion] = osAwsRegion } @@ -691,11 +689,11 @@ func InitAllEnv() error { envData[constants.EnvKeyIsEmailServiceEnabled] = false } - if envData[constants.EnvKeySmtpHost] != "" || envData[constants.EnvKeySmtpUsername] != "" || envData[constants.EnvKeySmtpPassword] != "" || envData[constants.EnvKeySenderEmail] != "" && envData[constants.EnvKeySmtpPort] != "" { + if envData[constants.EnvKeySmtpHost] != "" && envData[constants.EnvKeySmtpUsername] != "" && envData[constants.EnvKeySmtpPassword] != "" && envData[constants.EnvKeySenderEmail] != "" && envData[constants.EnvKeySmtpPort] != "" { envData[constants.EnvKeyIsEmailServiceEnabled] = true } - if envData[constants.EnvKeyEnforceMultiFactorAuthentication].(bool) && !envData[constants.EnvKeyIsEmailServiceEnabled].(bool) { + if envData[constants.EnvKeyEnforceMultiFactorAuthentication].(bool) && !envData[constants.EnvKeyIsEmailServiceEnabled].(bool) && !envData[constants.EnvKeyIsSMSServiceEnabled].(bool) { return errors.New("to enable multi factor authentication, please enable email service") } @@ -777,29 +775,39 @@ func InitAllEnv() error { envData[constants.EnvKeyDefaultAuthorizeResponseMode] = osAuthorizeResponseMode } + if val, ok := envData[constants.EnvKeyTwilioAPISecret]; !ok || val == "" { + envData[constants.EnvKeyTwilioAPISecret] = osTwilioApiSecret + } if osTwilioApiSecret != "" && envData[constants.EnvKeyTwilioAPISecret] != osTwilioApiSecret { envData[constants.EnvKeyTwilioAPISecret] = osTwilioApiSecret } + if val, ok := envData[constants.EnvKeyTwilioAPIKey]; !ok || val == "" { + envData[constants.EnvKeyTwilioAPIKey] = osTwilioApiKey + } if osTwilioApiKey != "" && envData[constants.EnvKeyTwilioAPIKey] != osTwilioApiKey { envData[constants.EnvKeyTwilioAPIKey] = osTwilioApiKey } + if val, ok := envData[constants.EnvKeyTwilioAccountSID]; !ok || val == "" { + envData[constants.EnvKeyTwilioAccountSID] = osTwilioAccountSid + } if osTwilioAccountSid != "" && envData[constants.EnvKeyTwilioAccountSID] != osTwilioAccountSid { envData[constants.EnvKeyTwilioAccountSID] = osTwilioAccountSid } - if osTwilioSenderFrom != "" && envData[constants.EnvKeyTwilioSenderFrom] != osTwilioSenderFrom { - envData[constants.EnvKeyTwilioSenderFrom] = osTwilioSenderFrom + if val, ok := envData[constants.EnvKeyTwilioSender]; !ok || val == "" { + envData[constants.EnvKeyTwilioSender] = osTwilioSender + } + if osTwilioSender != "" && envData[constants.EnvKeyTwilioSender] != osTwilioSender { + envData[constants.EnvKeyTwilioSender] = osTwilioSender } if _, ok := envData[constants.EnvKeyDisablePhoneVerification]; !ok { envData[constants.EnvKeyDisablePhoneVerification] = osDisablePhoneVerification == "false" } - if osDisablePhoneVerification != "" { boolValue, err := strconv.ParseBool(osDisablePhoneVerification) - if err != nil { return err } @@ -808,6 +816,15 @@ func InitAllEnv() error { } } + if envData[constants.EnvKeyTwilioAPIKey] == "" || envData[constants.EnvKeyTwilioAPISecret] == "" || envData[constants.EnvKeyTwilioAccountSID] == "" || envData[constants.EnvKeyTwilioSender] == "" { + envData[constants.EnvKeyDisablePhoneVerification] = true + envData[constants.EnvKeyIsSMSServiceEnabled] = false + } + if envData[constants.EnvKeyTwilioAPIKey] != "" && envData[constants.EnvKeyTwilioAPISecret] != "" && envData[constants.EnvKeyTwilioAccountSID] != "" && envData[constants.EnvKeyTwilioSender] != "" { + envData[constants.EnvKeyDisablePhoneVerification] = false + envData[constants.EnvKeyIsSMSServiceEnabled] = true + } + err = memorystore.Provider.UpdateEnvStore(envData) if err != nil { log.Debug("Error while updating env store: ", err) diff --git a/server/env/persist_env.go b/server/env/persist_env.go index 8ba33f2..83b5857 100644 --- a/server/env/persist_env.go +++ b/server/env/persist_env.go @@ -62,7 +62,7 @@ func GetEnvData() (map[string]interface{}, error) { ctx := context.Background() env, err := db.Provider.GetEnv(ctx) // config not found in db - if err != nil { + if err != nil || env == nil { log.Debug("Error while getting env data from db: ", err) return result, err } @@ -112,7 +112,7 @@ func PersistEnv() error { ctx := context.Background() env, err := db.Provider.GetEnv(ctx) // config not found in db - if err != nil || env.EnvData == "" { + if err != nil || env == nil { // AES encryption needs 32 bit key only, so we chop off last 4 characters from 36 bit uuid hash := uuid.New().String()[:36-4] err := memorystore.Provider.UpdateEnvVariable(constants.EnvKeyEncryptionKey, hash) @@ -121,25 +121,21 @@ func PersistEnv() error { return err } encodedHash := crypto.EncryptB64(hash) - res, err := memorystore.Provider.GetEnvStore() if err != nil { log.Debug("Error while getting env store: ", err) return err } - encryptedConfig, err := crypto.EncryptEnvData(res) if err != nil { log.Debug("Error while encrypting env data: ", err) return err } - - env = models.Env{ + env = &models.Env{ Hash: encodedHash, EnvData: encryptedConfig, } - - env, err = db.Provider.AddEnv(ctx, env) + _, err = db.Provider.AddEnv(ctx, env) if err != nil { log.Debug("Error while persisting env data to db: ", err) return err @@ -200,7 +196,7 @@ func PersistEnv() error { envValue := strings.TrimSpace(os.Getenv(key)) if envValue != "" { switch key { - case constants.EnvKeyIsProd, constants.EnvKeyDisableBasicAuthentication, constants.EnvKeyDisableMobileBasicAuthentication, constants.EnvKeyDisableEmailVerification, constants.EnvKeyDisableLoginPage, constants.EnvKeyDisableMagicLinkLogin, constants.EnvKeyDisableSignUp, constants.EnvKeyDisableRedisForEnv, constants.EnvKeyDisableStrongPassword, constants.EnvKeyIsEmailServiceEnabled, constants.EnvKeyEnforceMultiFactorAuthentication, constants.EnvKeyDisableMultiFactorAuthentication, constants.EnvKeyAdminCookieSecure, constants.EnvKeyAppCookieSecure, constants.EnvKeyDisablePhoneVerification: + case constants.EnvKeyIsProd, constants.EnvKeyDisableBasicAuthentication, constants.EnvKeyDisableMobileBasicAuthentication, constants.EnvKeyDisableEmailVerification, constants.EnvKeyDisableLoginPage, constants.EnvKeyDisableMagicLinkLogin, constants.EnvKeyDisableSignUp, constants.EnvKeyDisableRedisForEnv, constants.EnvKeyDisableStrongPassword, constants.EnvKeyIsEmailServiceEnabled, constants.EnvKeyIsSMSServiceEnabled, constants.EnvKeyEnforceMultiFactorAuthentication, constants.EnvKeyDisableMultiFactorAuthentication, constants.EnvKeyAdminCookieSecure, constants.EnvKeyAppCookieSecure, constants.EnvKeyDisablePhoneVerification: if envValueBool, err := strconv.ParseBool(envValue); err == nil { if value.(bool) != envValueBool { storeData[key] = envValueBool diff --git a/server/go.mod b/server/go.mod index 3408404..d1747e2 100644 --- a/server/go.mod +++ b/server/go.mod @@ -5,7 +5,7 @@ go 1.16 require ( github.com/99designs/gqlgen v0.17.20 github.com/arangodb/go-driver v1.2.1 - github.com/aws/aws-sdk-go v1.44.109 + github.com/aws/aws-sdk-go v1.44.298 github.com/coreos/go-oidc/v3 v3.1.0 github.com/couchbase/gocb/v2 v2.6.0 github.com/gin-gonic/gin v1.8.1 @@ -17,7 +17,7 @@ require ( github.com/golang/protobuf v1.5.2 // indirect github.com/google/go-cmp v0.5.6 // indirect github.com/google/uuid v1.3.0 - github.com/guregu/dynamo v1.16.0 + github.com/guregu/dynamo v1.20.0 github.com/joho/godotenv v1.3.0 github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/pelletier/go-toml/v2 v2.0.5 // indirect @@ -30,7 +30,7 @@ require ( go.mongodb.org/mongo-driver v1.8.1 golang.org/x/crypto v0.4.0 golang.org/x/oauth2 v0.0.0-20210628180205-a41e5a781914 - google.golang.org/appengine v1.6.7 // indirect + google.golang.org/appengine v1.6.7 google.golang.org/protobuf v1.28.1 // indirect gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect gopkg.in/mail.v2 v2.3.1 diff --git a/server/go.sum b/server/go.sum index 224e06d..4a2c928 100644 --- a/server/go.sum +++ b/server/go.sum @@ -51,9 +51,8 @@ github.com/arangodb/go-velocypack v0.0.0-20200318135517-5af53c29c67e h1:Xg+hGrY2 github.com/arangodb/go-velocypack v0.0.0-20200318135517-5af53c29c67e/go.mod h1:mq7Shfa/CaixoDxiyAAc5jZ6CVBAyPaNQCGS7mkj4Ho= github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q= github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE= -github.com/aws/aws-sdk-go v1.42.47/go.mod h1:OGr6lGMAKGlG9CVrYnWYDKIyb829c6EVBRjxqjmPepc= -github.com/aws/aws-sdk-go v1.44.109 h1:+Na5JPeS0kiEHoBp5Umcuuf+IDqXqD0lXnM920E31YI= -github.com/aws/aws-sdk-go v1.44.109/go.mod h1:y4AeaBuwd2Lk+GepC1E9v0qOiTws0MIWAX4oIKwKHZo= +github.com/aws/aws-sdk-go v1.44.298 h1:5qTxdubgV7PptZJmp/2qDwD2JL187ePL7VOxsSh1i3g= +github.com/aws/aws-sdk-go v1.44.298/go.mod h1:aVsgQcEevwlmQ7qHE9I3h+dtQgpqhFB+i8Phjh7fkwI= github.com/beevik/etree v1.1.0/go.mod h1:r8Aw8JqVegEf0w2fDnATrX9VpkMcyFeM0FhwO62wh+A= github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYEDvkta6I8/rnYM5gSdSV2tJ6XbZuEtY= github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k= @@ -63,8 +62,8 @@ github.com/bsm/ginkgo/v2 v2.7.0 h1:ItPMPH90RbmZJt5GtkcNvIRuGEdwlBItdNVoyzaNQao= github.com/bsm/ginkgo/v2 v2.7.0/go.mod h1:AiKlXPm7ItEHNc/2+OkrNG4E0ITzojb9/xWzvQ9XZ9w= github.com/bsm/gomega v1.26.0 h1:LhQm+AFcgV2M0WyKroMASzAzCAJVpAxQXv4SaI9a69Y= github.com/bsm/gomega v1.26.0/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= -github.com/cenkalti/backoff/v4 v4.1.2 h1:6Yo7N8UP2K6LWZnW94DLVSSrbobcWdVzAYOisuDPIFo= -github.com/cenkalti/backoff/v4 v4.1.2/go.mod h1:scbssz8iZGpm3xbr14ovlUdkxfGXNInqkPWOWmG2CLw= +github.com/cenkalti/backoff/v4 v4.2.1 h1:y4OZtCnogmCPw98Zjyt5a6+QwPLGkiQsYW5oUqylYbM= +github.com/cenkalti/backoff/v4 v4.2.1/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= @@ -130,8 +129,6 @@ github.com/goccy/go-json v0.9.11 h1:/pAaQDLHEoCq/5FFmSKBswWmK6H0e8g4159Kc/X/nqk= github.com/goccy/go-json v0.9.11/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/gocql/gocql v1.2.0 h1:TZhsCd7fRuye4VyHr3WCvWwIQaZUmjsqnSIXK9FcVCE= github.com/gocql/gocql v1.2.0/go.mod h1:3gM2c4D3AnkISwBxGnMMsS8Oy4y2lhbPRsH4xnJrHG8= -github.com/gofrs/uuid v4.2.0+incompatible h1:yyYWMnhkhrKwwr8gAOcOCYxOOscHgDS9yZgBrnJfGa0= -github.com/gofrs/uuid v4.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/golang-jwt/jwt v3.2.1+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= @@ -206,8 +203,8 @@ github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+ github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/guregu/dynamo v1.16.0 h1:gmI8oi1VHwYQtq7+RPBeOiSssVLgxH/Az2t+NtDtL2c= -github.com/guregu/dynamo v1.16.0/go.mod h1:W2Gqcf3MtkrS+Q6fHPGAmRtT0Dyq+TGrqfqrUC9+R/c= +github.com/guregu/dynamo v1.20.0 h1:PDdVVhRSXQFFIHlkhoKF6D8kiwI9IU6uUdz/fF6Iiy4= +github.com/guregu/dynamo v1.20.0/go.mod h1:YQ92BTYVSMIKpFEzhaVqmCJnnSIGxbNF5zvECUaEZRE= github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8= github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= @@ -438,12 +435,12 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20211216030914-fe4d6282115f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.3.0 h1:VWL6FNY2bEEmsGVKabSlHu5Irp34xmMRoqb/9lF9lxk= +golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= golang.org/x/net v0.3.0/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE= +golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= +golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -462,8 +459,9 @@ golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220923202941-7f9b1623fab7 h1:ZrnxWX62AgTKOSagEqxvb3ffipvEDX2pl7E1TdqLqIc= golang.org/x/sync v0.0.0-20220923202941-7f9b1623fab7/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= +golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -508,12 +506,16 @@ golang.org/x/sys v0.0.0-20220224120231-95c6836cb0e7/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.3.0 h1:w8ZOecv6NaNa/zC8944JTU3vz4u6Lagfk4RPQxv92NQ= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.3.0/go.mod h1:q750SLmJuPmVoN1blW3UFBPREJfb1KmY3vwxfr+nFDA= +golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -523,8 +525,10 @@ golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= -golang.org/x/text v0.5.0 h1:OLmvp0KP+FVG99Ct/qFiL/Fhk4zp4QQnZ7b2U+5piUM= +golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.5.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/server/graph/generated/generated.go b/server/graph/generated/generated.go index e849306..77eba10 100644 --- a/server/graph/generated/generated.go +++ b/server/graph/generated/generated.go @@ -45,13 +45,14 @@ type DirectiveRoot struct { type ComplexityRoot struct { AuthResponse struct { - AccessToken func(childComplexity int) int - ExpiresIn func(childComplexity int) int - IDToken func(childComplexity int) int - Message func(childComplexity int) int - RefreshToken func(childComplexity int) int - ShouldShowOtpScreen func(childComplexity int) int - User func(childComplexity int) int + AccessToken func(childComplexity int) int + ExpiresIn func(childComplexity int) int + IDToken func(childComplexity int) int + Message func(childComplexity int) int + RefreshToken func(childComplexity int) int + ShouldShowEmailOtpScreen func(childComplexity int) int + ShouldShowMobileOtpScreen func(childComplexity int) int + User func(childComplexity int) int } EmailTemplate struct { @@ -198,7 +199,6 @@ type ComplexityRoot struct { UpdateUser func(childComplexity int, params model.UpdateUserInput) int UpdateWebhook func(childComplexity int, params model.UpdateWebhookRequest) int VerifyEmail func(childComplexity int, params model.VerifyEmailInput) int - VerifyMobile func(childComplexity int, params model.VerifyMobileRequest) int VerifyOtp func(childComplexity int, params model.VerifyOTPRequest) int } @@ -219,6 +219,7 @@ type ComplexityRoot struct { User func(childComplexity int, params model.GetUserRequest) int Users func(childComplexity int, params *model.PaginatedInput) int ValidateJwtToken func(childComplexity int, params model.ValidateJWTTokenInput) int + ValidateSession func(childComplexity int, params *model.ValidateSessionInput) int VerificationRequests func(childComplexity int, params *model.PaginatedInput) int Webhook func(childComplexity int, params model.WebhookRequest) int WebhookLogs func(childComplexity int, params *model.ListWebhookLogRequest) int @@ -244,6 +245,7 @@ type ComplexityRoot struct { } User struct { + AppData func(childComplexity int) int Birthdate func(childComplexity int) int CreatedAt func(childComplexity int) int Email func(childComplexity int) int @@ -275,6 +277,11 @@ type ComplexityRoot struct { IsValid func(childComplexity int) int } + ValidateSessionResponse struct { + IsValid func(childComplexity int) int + User func(childComplexity int) int + } + VerificationRequest struct { CreatedAt func(childComplexity int) int Email func(childComplexity int) int @@ -339,7 +346,6 @@ type MutationResolver interface { Revoke(ctx context.Context, params model.OAuthRevokeInput) (*model.Response, error) VerifyOtp(ctx context.Context, params model.VerifyOTPRequest) (*model.AuthResponse, error) ResendOtp(ctx context.Context, params model.ResendOTPRequest) (*model.Response, error) - VerifyMobile(ctx context.Context, params model.VerifyMobileRequest) (*model.AuthResponse, error) DeleteUser(ctx context.Context, params model.DeleteUserInput) (*model.Response, error) UpdateUser(ctx context.Context, params model.UpdateUserInput) (*model.User, error) AdminSignup(ctx context.Context, params model.AdminSignupInput) (*model.Response, error) @@ -363,6 +369,7 @@ type QueryResolver interface { Session(ctx context.Context, params *model.SessionQueryInput) (*model.AuthResponse, error) Profile(ctx context.Context) (*model.User, error) ValidateJwtToken(ctx context.Context, params model.ValidateJWTTokenInput) (*model.ValidateJWTTokenResponse, error) + ValidateSession(ctx context.Context, params *model.ValidateSessionInput) (*model.ValidateSessionResponse, error) Users(ctx context.Context, params *model.PaginatedInput) (*model.Users, error) User(ctx context.Context, params model.GetUserRequest) (*model.User, error) VerificationRequests(ctx context.Context, params *model.PaginatedInput) (*model.VerificationRequests, error) @@ -424,12 +431,19 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.AuthResponse.RefreshToken(childComplexity), true - case "AuthResponse.should_show_otp_screen": - if e.complexity.AuthResponse.ShouldShowOtpScreen == nil { + case "AuthResponse.should_show_email_otp_screen": + if e.complexity.AuthResponse.ShouldShowEmailOtpScreen == nil { break } - return e.complexity.AuthResponse.ShouldShowOtpScreen(childComplexity), true + return e.complexity.AuthResponse.ShouldShowEmailOtpScreen(childComplexity), true + + case "AuthResponse.should_show_mobile_otp_screen": + if e.complexity.AuthResponse.ShouldShowMobileOtpScreen == nil { + break + } + + return e.complexity.AuthResponse.ShouldShowMobileOtpScreen(childComplexity), true case "AuthResponse.user": if e.complexity.AuthResponse.User == nil { @@ -1432,18 +1446,6 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Mutation.VerifyEmail(childComplexity, args["params"].(model.VerifyEmailInput)), true - case "Mutation.verify_mobile": - if e.complexity.Mutation.VerifyMobile == nil { - break - } - - args, err := ec.field_Mutation_verify_mobile_args(context.TODO(), rawArgs) - if err != nil { - return 0, false - } - - return e.complexity.Mutation.VerifyMobile(childComplexity, args["params"].(model.VerifyMobileRequest)), true - case "Mutation.verify_otp": if e.complexity.Mutation.VerifyOtp == nil { break @@ -1572,6 +1574,18 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Query.ValidateJwtToken(childComplexity, args["params"].(model.ValidateJWTTokenInput)), true + case "Query.validate_session": + if e.complexity.Query.ValidateSession == nil { + break + } + + args, err := ec.field_Query_validate_session_args(context.TODO(), rawArgs) + if err != nil { + return 0, false + } + + return e.complexity.Query.ValidateSession(childComplexity, args["params"].(*model.ValidateSessionInput)), true + case "Query._verification_requests": if e.complexity.Query.VerificationRequests == nil { break @@ -1683,6 +1697,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.TestEndpointResponse.Response(childComplexity), true + case "User.app_data": + if e.complexity.User.AppData == nil { + break + } + + return e.complexity.User.AppData(childComplexity), true + case "User.birthdate": if e.complexity.User.Birthdate == nil { break @@ -1844,6 +1865,20 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.ValidateJWTTokenResponse.IsValid(childComplexity), true + case "ValidateSessionResponse.is_valid": + if e.complexity.ValidateSessionResponse.IsValid == nil { + break + } + + return e.complexity.ValidateSessionResponse.IsValid(childComplexity), true + + case "ValidateSessionResponse.user": + if e.complexity.ValidateSessionResponse.User == nil { + break + } + + return e.complexity.ValidateSessionResponse.User(childComplexity), true + case "VerificationRequest.created_at": if e.complexity.VerificationRequest.CreatedAt == nil { break @@ -2093,8 +2128,8 @@ func (e *executableSchema) Exec(ctx context.Context) graphql.ResponseHandler { ec.unmarshalInputUpdateUserInput, ec.unmarshalInputUpdateWebhookRequest, ec.unmarshalInputValidateJWTTokenInput, + ec.unmarshalInputValidateSessionInput, ec.unmarshalInputVerifyEmailInput, - ec.unmarshalInputVerifyMobileRequest, ec.unmarshalInputVerifyOTPRequest, ec.unmarshalInputWebhookRequest, ) @@ -2210,6 +2245,7 @@ type User { updated_at: Int64 revoked_timestamp: Int64 is_multi_factor_auth_enabled: Boolean + app_data: Map } type Users { @@ -2243,11 +2279,6 @@ type SMSVerificationRequests { updated_at: Int64 } -input VerifyMobileRequest { - phone_number: String! - code: String! -} - type Error { message: String! reason: String! @@ -2255,7 +2286,8 @@ type Error { type AuthResponse { message: String! - should_show_otp_screen: Boolean + should_show_email_otp_screen: Boolean + should_show_mobile_otp_screen: Boolean access_token: String id_token: String refresh_token: String @@ -2341,6 +2373,11 @@ type ValidateJWTTokenResponse { claims: Map } +type ValidateSessionResponse { + is_valid: Boolean! + user: User! +} + type GenerateJWTKeysResponse { secret: String public_key: String @@ -2481,6 +2518,7 @@ input MobileSignUpInput { # it is used to get code for an on-going auth process during login # and use that code for setting ` + "`" + `c_hash` + "`" + ` in id_token state: String + app_data: Map } input SignUpInput { @@ -2503,6 +2541,7 @@ input SignUpInput { # it is used to get code for an on-going auth process during login # and use that code for setting ` + "`" + `c_hash` + "`" + ` in id_token state: String + app_data: Map } input LoginInput { @@ -2558,6 +2597,7 @@ input UpdateProfileInput { phone_number: String picture: String is_multi_factor_auth_enabled: Boolean + app_data: Map } input UpdateUserInput { @@ -2574,6 +2614,7 @@ input UpdateUserInput { picture: String roles: [String] is_multi_factor_auth_enabled: Boolean + app_data: Map } input ForgotPasswordInput { @@ -2633,6 +2674,11 @@ input ValidateJWTTokenInput { roles: [String!] } +input ValidateSessionInput { + cookie: String! + roles: [String!] +} + input GenerateJWTKeysInput { type: String! } @@ -2666,6 +2712,7 @@ input WebhookRequest { input TestEndpointRequest { endpoint: String! event_name: String! + event_description: String headers: Map } @@ -2693,7 +2740,9 @@ input DeleteEmailTemplateRequest { } input VerifyOTPRequest { - email: String! + # either email or phone_number is required + email: String + phone_number: String otp: String! # state is used for authorization code grant flow # it is used to get code for an on-going auth process during login @@ -2702,7 +2751,8 @@ input VerifyOTPRequest { } input ResendOTPRequest { - email: String! + email: String + phone_number: String # state is used for authorization code grant flow # it is used to get code for an on-going auth process during login # and use that code for setting ` + "`" + `c_hash` + "`" + ` in id_token @@ -2729,7 +2779,6 @@ type Mutation { revoke(params: OAuthRevokeInput!): Response! verify_otp(params: VerifyOTPRequest!): AuthResponse! resend_otp(params: ResendOTPRequest!): Response! - verify_mobile(params: VerifyMobileRequest!): AuthResponse! # admin only apis _delete_user(params: DeleteUserInput!): Response! _update_user(params: UpdateUserInput!): User! @@ -2755,6 +2804,7 @@ type Query { session(params: SessionQueryInput): AuthResponse! profile: User! validate_jwt_token(params: ValidateJWTTokenInput!): ValidateJWTTokenResponse! + validate_session(params: ValidateSessionInput): ValidateSessionResponse! # admin only apis _users(params: PaginatedInput): Users! _user(params: GetUserRequest!): User! @@ -3194,21 +3244,6 @@ func (ec *executionContext) field_Mutation_verify_email_args(ctx context.Context return args, nil } -func (ec *executionContext) field_Mutation_verify_mobile_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { - var err error - args := map[string]interface{}{} - var arg0 model.VerifyMobileRequest - if tmp, ok := rawArgs["params"]; ok { - ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("params")) - arg0, err = ec.unmarshalNVerifyMobileRequest2githubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐVerifyMobileRequest(ctx, tmp) - if err != nil { - return nil, err - } - } - args["params"] = arg0 - return args, nil -} - func (ec *executionContext) field_Mutation_verify_otp_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { var err error args := map[string]interface{}{} @@ -3374,6 +3409,21 @@ func (ec *executionContext) field_Query_validate_jwt_token_args(ctx context.Cont return args, nil } +func (ec *executionContext) field_Query_validate_session_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { + var err error + args := map[string]interface{}{} + var arg0 *model.ValidateSessionInput + if tmp, ok := rawArgs["params"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("params")) + arg0, err = ec.unmarshalOValidateSessionInput2ᚖgithubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐValidateSessionInput(ctx, tmp) + if err != nil { + return nil, err + } + } + args["params"] = arg0 + return args, nil +} + func (ec *executionContext) field___Type_enumValues_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { var err error args := map[string]interface{}{} @@ -3456,8 +3506,8 @@ func (ec *executionContext) fieldContext_AuthResponse_message(ctx context.Contex return fc, nil } -func (ec *executionContext) _AuthResponse_should_show_otp_screen(ctx context.Context, field graphql.CollectedField, obj *model.AuthResponse) (ret graphql.Marshaler) { - fc, err := ec.fieldContext_AuthResponse_should_show_otp_screen(ctx, field) +func (ec *executionContext) _AuthResponse_should_show_email_otp_screen(ctx context.Context, field graphql.CollectedField, obj *model.AuthResponse) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_AuthResponse_should_show_email_otp_screen(ctx, field) if err != nil { return graphql.Null } @@ -3470,7 +3520,7 @@ func (ec *executionContext) _AuthResponse_should_show_otp_screen(ctx context.Con }() resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { ctx = rctx // use context from middleware stack in children - return obj.ShouldShowOtpScreen, nil + return obj.ShouldShowEmailOtpScreen, nil }) if err != nil { ec.Error(ctx, err) @@ -3484,7 +3534,48 @@ func (ec *executionContext) _AuthResponse_should_show_otp_screen(ctx context.Con return ec.marshalOBoolean2ᚖbool(ctx, field.Selections, res) } -func (ec *executionContext) fieldContext_AuthResponse_should_show_otp_screen(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { +func (ec *executionContext) fieldContext_AuthResponse_should_show_email_otp_screen(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "AuthResponse", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type Boolean does not have child fields") + }, + } + return fc, nil +} + +func (ec *executionContext) _AuthResponse_should_show_mobile_otp_screen(ctx context.Context, field graphql.CollectedField, obj *model.AuthResponse) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_AuthResponse_should_show_mobile_otp_screen(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.ShouldShowMobileOtpScreen, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*bool) + fc.Result = res + return ec.marshalOBoolean2ᚖbool(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_AuthResponse_should_show_mobile_otp_screen(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { fc = &graphql.FieldContext{ Object: "AuthResponse", Field: field, @@ -3735,6 +3826,8 @@ func (ec *executionContext) fieldContext_AuthResponse_user(ctx context.Context, return ec.fieldContext_User_revoked_timestamp(ctx, field) case "is_multi_factor_auth_enabled": return ec.fieldContext_User_is_multi_factor_auth_enabled(ctx, field) + case "app_data": + return ec.fieldContext_User_app_data(ctx, field) } return nil, fmt.Errorf("no field named %q was found under type User", field.Name) }, @@ -7030,6 +7123,8 @@ func (ec *executionContext) fieldContext_InviteMembersResponse_Users(ctx context return ec.fieldContext_User_revoked_timestamp(ctx, field) case "is_multi_factor_auth_enabled": return ec.fieldContext_User_is_multi_factor_auth_enabled(ctx, field) + case "app_data": + return ec.fieldContext_User_app_data(ctx, field) } return nil, fmt.Errorf("no field named %q was found under type User", field.Name) }, @@ -7738,8 +7833,10 @@ func (ec *executionContext) fieldContext_Mutation_signup(ctx context.Context, fi switch field.Name { case "message": return ec.fieldContext_AuthResponse_message(ctx, field) - case "should_show_otp_screen": - return ec.fieldContext_AuthResponse_should_show_otp_screen(ctx, field) + case "should_show_email_otp_screen": + return ec.fieldContext_AuthResponse_should_show_email_otp_screen(ctx, field) + case "should_show_mobile_otp_screen": + return ec.fieldContext_AuthResponse_should_show_mobile_otp_screen(ctx, field) case "access_token": return ec.fieldContext_AuthResponse_access_token(ctx, field) case "id_token": @@ -7809,8 +7906,10 @@ func (ec *executionContext) fieldContext_Mutation_mobile_signup(ctx context.Cont switch field.Name { case "message": return ec.fieldContext_AuthResponse_message(ctx, field) - case "should_show_otp_screen": - return ec.fieldContext_AuthResponse_should_show_otp_screen(ctx, field) + case "should_show_email_otp_screen": + return ec.fieldContext_AuthResponse_should_show_email_otp_screen(ctx, field) + case "should_show_mobile_otp_screen": + return ec.fieldContext_AuthResponse_should_show_mobile_otp_screen(ctx, field) case "access_token": return ec.fieldContext_AuthResponse_access_token(ctx, field) case "id_token": @@ -7880,8 +7979,10 @@ func (ec *executionContext) fieldContext_Mutation_login(ctx context.Context, fie switch field.Name { case "message": return ec.fieldContext_AuthResponse_message(ctx, field) - case "should_show_otp_screen": - return ec.fieldContext_AuthResponse_should_show_otp_screen(ctx, field) + case "should_show_email_otp_screen": + return ec.fieldContext_AuthResponse_should_show_email_otp_screen(ctx, field) + case "should_show_mobile_otp_screen": + return ec.fieldContext_AuthResponse_should_show_mobile_otp_screen(ctx, field) case "access_token": return ec.fieldContext_AuthResponse_access_token(ctx, field) case "id_token": @@ -7951,8 +8052,10 @@ func (ec *executionContext) fieldContext_Mutation_mobile_login(ctx context.Conte switch field.Name { case "message": return ec.fieldContext_AuthResponse_message(ctx, field) - case "should_show_otp_screen": - return ec.fieldContext_AuthResponse_should_show_otp_screen(ctx, field) + case "should_show_email_otp_screen": + return ec.fieldContext_AuthResponse_should_show_email_otp_screen(ctx, field) + case "should_show_mobile_otp_screen": + return ec.fieldContext_AuthResponse_should_show_mobile_otp_screen(ctx, field) case "access_token": return ec.fieldContext_AuthResponse_access_token(ctx, field) case "id_token": @@ -8188,8 +8291,10 @@ func (ec *executionContext) fieldContext_Mutation_verify_email(ctx context.Conte switch field.Name { case "message": return ec.fieldContext_AuthResponse_message(ctx, field) - case "should_show_otp_screen": - return ec.fieldContext_AuthResponse_should_show_otp_screen(ctx, field) + case "should_show_email_otp_screen": + return ec.fieldContext_AuthResponse_should_show_email_otp_screen(ctx, field) + case "should_show_mobile_otp_screen": + return ec.fieldContext_AuthResponse_should_show_mobile_otp_screen(ctx, field) case "access_token": return ec.fieldContext_AuthResponse_access_token(ctx, field) case "id_token": @@ -8495,8 +8600,10 @@ func (ec *executionContext) fieldContext_Mutation_verify_otp(ctx context.Context switch field.Name { case "message": return ec.fieldContext_AuthResponse_message(ctx, field) - case "should_show_otp_screen": - return ec.fieldContext_AuthResponse_should_show_otp_screen(ctx, field) + case "should_show_email_otp_screen": + return ec.fieldContext_AuthResponse_should_show_email_otp_screen(ctx, field) + case "should_show_mobile_otp_screen": + return ec.fieldContext_AuthResponse_should_show_mobile_otp_screen(ctx, field) case "access_token": return ec.fieldContext_AuthResponse_access_token(ctx, field) case "id_token": @@ -8584,77 +8691,6 @@ func (ec *executionContext) fieldContext_Mutation_resend_otp(ctx context.Context return fc, nil } -func (ec *executionContext) _Mutation_verify_mobile(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { - fc, err := ec.fieldContext_Mutation_verify_mobile(ctx, field) - if err != nil { - return graphql.Null - } - ctx = graphql.WithFieldContext(ctx, fc) - defer func() { - if r := recover(); r != nil { - ec.Error(ctx, ec.Recover(ctx, r)) - ret = graphql.Null - } - }() - resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { - ctx = rctx // use context from middleware stack in children - return ec.resolvers.Mutation().VerifyMobile(rctx, fc.Args["params"].(model.VerifyMobileRequest)) - }) - if err != nil { - ec.Error(ctx, err) - return graphql.Null - } - if resTmp == nil { - if !graphql.HasFieldError(ctx, fc) { - ec.Errorf(ctx, "must not be null") - } - return graphql.Null - } - res := resTmp.(*model.AuthResponse) - fc.Result = res - return ec.marshalNAuthResponse2ᚖgithubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐAuthResponse(ctx, field.Selections, res) -} - -func (ec *executionContext) fieldContext_Mutation_verify_mobile(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { - fc = &graphql.FieldContext{ - Object: "Mutation", - Field: field, - IsMethod: true, - IsResolver: true, - Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { - switch field.Name { - case "message": - return ec.fieldContext_AuthResponse_message(ctx, field) - case "should_show_otp_screen": - return ec.fieldContext_AuthResponse_should_show_otp_screen(ctx, field) - case "access_token": - return ec.fieldContext_AuthResponse_access_token(ctx, field) - case "id_token": - return ec.fieldContext_AuthResponse_id_token(ctx, field) - case "refresh_token": - return ec.fieldContext_AuthResponse_refresh_token(ctx, field) - case "expires_in": - return ec.fieldContext_AuthResponse_expires_in(ctx, field) - case "user": - return ec.fieldContext_AuthResponse_user(ctx, field) - } - return nil, fmt.Errorf("no field named %q was found under type AuthResponse", field.Name) - }, - } - defer func() { - if r := recover(); r != nil { - err = ec.Recover(ctx, r) - ec.Error(ctx, err) - } - }() - ctx = graphql.WithFieldContext(ctx, fc) - if fc.Args, err = ec.field_Mutation_verify_mobile_args(ctx, field.ArgumentMap(ec.Variables)); err != nil { - ec.Error(ctx, err) - return - } - return fc, nil -} - func (ec *executionContext) _Mutation__delete_user(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { fc, err := ec.fieldContext_Mutation__delete_user(ctx, field) if err != nil { @@ -8791,6 +8827,8 @@ func (ec *executionContext) fieldContext_Mutation__update_user(ctx context.Conte return ec.fieldContext_User_revoked_timestamp(ctx, field) case "is_multi_factor_auth_enabled": return ec.fieldContext_User_is_multi_factor_auth_enabled(ctx, field) + case "app_data": + return ec.fieldContext_User_app_data(ctx, field) } return nil, fmt.Errorf("no field named %q was found under type User", field.Name) }, @@ -9984,8 +10022,10 @@ func (ec *executionContext) fieldContext_Query_session(ctx context.Context, fiel switch field.Name { case "message": return ec.fieldContext_AuthResponse_message(ctx, field) - case "should_show_otp_screen": - return ec.fieldContext_AuthResponse_should_show_otp_screen(ctx, field) + case "should_show_email_otp_screen": + return ec.fieldContext_AuthResponse_should_show_email_otp_screen(ctx, field) + case "should_show_mobile_otp_screen": + return ec.fieldContext_AuthResponse_should_show_mobile_otp_screen(ctx, field) case "access_token": return ec.fieldContext_AuthResponse_access_token(ctx, field) case "id_token": @@ -10091,6 +10131,8 @@ func (ec *executionContext) fieldContext_Query_profile(ctx context.Context, fiel return ec.fieldContext_User_revoked_timestamp(ctx, field) case "is_multi_factor_auth_enabled": return ec.fieldContext_User_is_multi_factor_auth_enabled(ctx, field) + case "app_data": + return ec.fieldContext_User_app_data(ctx, field) } return nil, fmt.Errorf("no field named %q was found under type User", field.Name) }, @@ -10159,6 +10201,67 @@ func (ec *executionContext) fieldContext_Query_validate_jwt_token(ctx context.Co return fc, nil } +func (ec *executionContext) _Query_validate_session(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_Query_validate_session(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.Query().ValidateSession(rctx, fc.Args["params"].(*model.ValidateSessionInput)) + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(*model.ValidateSessionResponse) + fc.Result = res + return ec.marshalNValidateSessionResponse2ᚖgithubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐValidateSessionResponse(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_Query_validate_session(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "Query", + Field: field, + IsMethod: true, + IsResolver: true, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + switch field.Name { + case "is_valid": + return ec.fieldContext_ValidateSessionResponse_is_valid(ctx, field) + case "user": + return ec.fieldContext_ValidateSessionResponse_user(ctx, field) + } + return nil, fmt.Errorf("no field named %q was found under type ValidateSessionResponse", field.Name) + }, + } + defer func() { + if r := recover(); r != nil { + err = ec.Recover(ctx, r) + ec.Error(ctx, err) + } + }() + ctx = graphql.WithFieldContext(ctx, fc) + if fc.Args, err = ec.field_Query_validate_session_args(ctx, field.ArgumentMap(ec.Variables)); err != nil { + ec.Error(ctx, err) + return + } + return fc, nil +} + func (ec *executionContext) _Query__users(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { fc, err := ec.fieldContext_Query__users(ctx, field) if err != nil { @@ -10297,6 +10400,8 @@ func (ec *executionContext) fieldContext_Query__user(ctx context.Context, field return ec.fieldContext_User_revoked_timestamp(ctx, field) case "is_multi_factor_auth_enabled": return ec.fieldContext_User_is_multi_factor_auth_enabled(ctx, field) + case "app_data": + return ec.fieldContext_User_app_data(ctx, field) } return nil, fmt.Errorf("no field named %q was found under type User", field.Name) }, @@ -12158,6 +12263,47 @@ func (ec *executionContext) fieldContext_User_is_multi_factor_auth_enabled(ctx c return fc, nil } +func (ec *executionContext) _User_app_data(ctx context.Context, field graphql.CollectedField, obj *model.User) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_User_app_data(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.AppData, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.(map[string]interface{}) + fc.Result = res + return ec.marshalOMap2map(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_User_app_data(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "User", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type Map does not have child fields") + }, + } + return fc, nil +} + func (ec *executionContext) _Users_pagination(ctx context.Context, field graphql.CollectedField, obj *model.Users) (ret graphql.Marshaler) { fc, err := ec.fieldContext_Users_pagination(ctx, field) if err != nil { @@ -12289,6 +12435,8 @@ func (ec *executionContext) fieldContext_Users_users(ctx context.Context, field return ec.fieldContext_User_revoked_timestamp(ctx, field) case "is_multi_factor_auth_enabled": return ec.fieldContext_User_is_multi_factor_auth_enabled(ctx, field) + case "app_data": + return ec.fieldContext_User_app_data(ctx, field) } return nil, fmt.Errorf("no field named %q was found under type User", field.Name) }, @@ -12381,6 +12529,136 @@ func (ec *executionContext) fieldContext_ValidateJWTTokenResponse_claims(ctx con return fc, nil } +func (ec *executionContext) _ValidateSessionResponse_is_valid(ctx context.Context, field graphql.CollectedField, obj *model.ValidateSessionResponse) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_ValidateSessionResponse_is_valid(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.IsValid, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(bool) + fc.Result = res + return ec.marshalNBoolean2bool(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_ValidateSessionResponse_is_valid(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "ValidateSessionResponse", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type Boolean does not have child fields") + }, + } + return fc, nil +} + +func (ec *executionContext) _ValidateSessionResponse_user(ctx context.Context, field graphql.CollectedField, obj *model.ValidateSessionResponse) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_ValidateSessionResponse_user(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.User, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(*model.User) + fc.Result = res + return ec.marshalNUser2ᚖgithubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐUser(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_ValidateSessionResponse_user(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "ValidateSessionResponse", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + switch field.Name { + case "id": + return ec.fieldContext_User_id(ctx, field) + case "email": + return ec.fieldContext_User_email(ctx, field) + case "email_verified": + return ec.fieldContext_User_email_verified(ctx, field) + case "signup_methods": + return ec.fieldContext_User_signup_methods(ctx, field) + case "given_name": + return ec.fieldContext_User_given_name(ctx, field) + case "family_name": + return ec.fieldContext_User_family_name(ctx, field) + case "middle_name": + return ec.fieldContext_User_middle_name(ctx, field) + case "nickname": + return ec.fieldContext_User_nickname(ctx, field) + case "preferred_username": + return ec.fieldContext_User_preferred_username(ctx, field) + case "gender": + return ec.fieldContext_User_gender(ctx, field) + case "birthdate": + return ec.fieldContext_User_birthdate(ctx, field) + case "phone_number": + return ec.fieldContext_User_phone_number(ctx, field) + case "phone_number_verified": + return ec.fieldContext_User_phone_number_verified(ctx, field) + case "picture": + return ec.fieldContext_User_picture(ctx, field) + case "roles": + return ec.fieldContext_User_roles(ctx, field) + case "created_at": + return ec.fieldContext_User_created_at(ctx, field) + case "updated_at": + return ec.fieldContext_User_updated_at(ctx, field) + case "revoked_timestamp": + return ec.fieldContext_User_revoked_timestamp(ctx, field) + case "is_multi_factor_auth_enabled": + return ec.fieldContext_User_is_multi_factor_auth_enabled(ctx, field) + case "app_data": + return ec.fieldContext_User_app_data(ctx, field) + } + return nil, fmt.Errorf("no field named %q was found under type User", field.Name) + }, + } + return fc, nil +} + func (ec *executionContext) _VerificationRequest_id(ctx context.Context, field graphql.CollectedField, obj *model.VerificationRequest) (ret graphql.Marshaler) { fc, err := ec.fieldContext_VerificationRequest_id(ctx, field) if err != nil { @@ -16086,7 +16364,7 @@ func (ec *executionContext) unmarshalInputMobileSignUpInput(ctx context.Context, asMap[k] = v } - fieldsInOrder := [...]string{"email", "given_name", "family_name", "middle_name", "nickname", "gender", "birthdate", "phone_number", "picture", "password", "confirm_password", "roles", "scope", "redirect_uri", "is_multi_factor_auth_enabled", "state"} + fieldsInOrder := [...]string{"email", "given_name", "family_name", "middle_name", "nickname", "gender", "birthdate", "phone_number", "picture", "password", "confirm_password", "roles", "scope", "redirect_uri", "is_multi_factor_auth_enabled", "state", "app_data"} for _, k := range fieldsInOrder { v, ok := asMap[k] if !ok { @@ -16221,6 +16499,14 @@ func (ec *executionContext) unmarshalInputMobileSignUpInput(ctx context.Context, if err != nil { return it, err } + case "app_data": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("app_data")) + it.AppData, err = ec.unmarshalOMap2map(ctx, v) + if err != nil { + return it, err + } } } @@ -16326,7 +16612,7 @@ func (ec *executionContext) unmarshalInputResendOTPRequest(ctx context.Context, asMap[k] = v } - fieldsInOrder := [...]string{"email", "state"} + fieldsInOrder := [...]string{"email", "phone_number", "state"} for _, k := range fieldsInOrder { v, ok := asMap[k] if !ok { @@ -16337,7 +16623,15 @@ func (ec *executionContext) unmarshalInputResendOTPRequest(ctx context.Context, var err error ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("email")) - it.Email, err = ec.unmarshalNString2string(ctx, v) + it.Email, err = ec.unmarshalOString2ᚖstring(ctx, v) + if err != nil { + return it, err + } + case "phone_number": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("phone_number")) + it.PhoneNumber, err = ec.unmarshalOString2ᚖstring(ctx, v) if err != nil { return it, err } @@ -16486,7 +16780,7 @@ func (ec *executionContext) unmarshalInputSignUpInput(ctx context.Context, obj i asMap[k] = v } - fieldsInOrder := [...]string{"email", "given_name", "family_name", "middle_name", "nickname", "gender", "birthdate", "phone_number", "picture", "password", "confirm_password", "roles", "scope", "redirect_uri", "is_multi_factor_auth_enabled", "state"} + fieldsInOrder := [...]string{"email", "given_name", "family_name", "middle_name", "nickname", "gender", "birthdate", "phone_number", "picture", "password", "confirm_password", "roles", "scope", "redirect_uri", "is_multi_factor_auth_enabled", "state", "app_data"} for _, k := range fieldsInOrder { v, ok := asMap[k] if !ok { @@ -16621,6 +16915,14 @@ func (ec *executionContext) unmarshalInputSignUpInput(ctx context.Context, obj i if err != nil { return it, err } + case "app_data": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("app_data")) + it.AppData, err = ec.unmarshalOMap2map(ctx, v) + if err != nil { + return it, err + } } } @@ -16634,7 +16936,7 @@ func (ec *executionContext) unmarshalInputTestEndpointRequest(ctx context.Contex asMap[k] = v } - fieldsInOrder := [...]string{"endpoint", "event_name", "headers"} + fieldsInOrder := [...]string{"endpoint", "event_name", "event_description", "headers"} for _, k := range fieldsInOrder { v, ok := asMap[k] if !ok { @@ -16657,6 +16959,14 @@ func (ec *executionContext) unmarshalInputTestEndpointRequest(ctx context.Contex if err != nil { return it, err } + case "event_description": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("event_description")) + it.EventDescription, err = ec.unmarshalOString2ᚖstring(ctx, v) + if err != nil { + return it, err + } case "headers": var err error @@ -17202,7 +17512,7 @@ func (ec *executionContext) unmarshalInputUpdateProfileInput(ctx context.Context asMap[k] = v } - fieldsInOrder := [...]string{"old_password", "new_password", "confirm_new_password", "email", "given_name", "family_name", "middle_name", "nickname", "gender", "birthdate", "phone_number", "picture", "is_multi_factor_auth_enabled"} + fieldsInOrder := [...]string{"old_password", "new_password", "confirm_new_password", "email", "given_name", "family_name", "middle_name", "nickname", "gender", "birthdate", "phone_number", "picture", "is_multi_factor_auth_enabled", "app_data"} for _, k := range fieldsInOrder { v, ok := asMap[k] if !ok { @@ -17313,6 +17623,14 @@ func (ec *executionContext) unmarshalInputUpdateProfileInput(ctx context.Context if err != nil { return it, err } + case "app_data": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("app_data")) + it.AppData, err = ec.unmarshalOMap2map(ctx, v) + if err != nil { + return it, err + } } } @@ -17326,7 +17644,7 @@ func (ec *executionContext) unmarshalInputUpdateUserInput(ctx context.Context, o asMap[k] = v } - fieldsInOrder := [...]string{"id", "email", "email_verified", "given_name", "family_name", "middle_name", "nickname", "gender", "birthdate", "phone_number", "picture", "roles", "is_multi_factor_auth_enabled"} + fieldsInOrder := [...]string{"id", "email", "email_verified", "given_name", "family_name", "middle_name", "nickname", "gender", "birthdate", "phone_number", "picture", "roles", "is_multi_factor_auth_enabled", "app_data"} for _, k := range fieldsInOrder { v, ok := asMap[k] if !ok { @@ -17437,6 +17755,14 @@ func (ec *executionContext) unmarshalInputUpdateUserInput(ctx context.Context, o if err != nil { return it, err } + case "app_data": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("app_data")) + it.AppData, err = ec.unmarshalOMap2map(ctx, v) + if err != nil { + return it, err + } } } @@ -17555,6 +17881,42 @@ func (ec *executionContext) unmarshalInputValidateJWTTokenInput(ctx context.Cont return it, nil } +func (ec *executionContext) unmarshalInputValidateSessionInput(ctx context.Context, obj interface{}) (model.ValidateSessionInput, error) { + var it model.ValidateSessionInput + asMap := map[string]interface{}{} + for k, v := range obj.(map[string]interface{}) { + asMap[k] = v + } + + fieldsInOrder := [...]string{"cookie", "roles"} + for _, k := range fieldsInOrder { + v, ok := asMap[k] + if !ok { + continue + } + switch k { + case "cookie": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("cookie")) + it.Cookie, err = ec.unmarshalNString2string(ctx, v) + if err != nil { + return it, err + } + case "roles": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("roles")) + it.Roles, err = ec.unmarshalOString2ᚕstringᚄ(ctx, v) + if err != nil { + return it, err + } + } + } + + return it, nil +} + func (ec *executionContext) unmarshalInputVerifyEmailInput(ctx context.Context, obj interface{}) (model.VerifyEmailInput, error) { var it model.VerifyEmailInput asMap := map[string]interface{}{} @@ -17591,42 +17953,6 @@ func (ec *executionContext) unmarshalInputVerifyEmailInput(ctx context.Context, return it, nil } -func (ec *executionContext) unmarshalInputVerifyMobileRequest(ctx context.Context, obj interface{}) (model.VerifyMobileRequest, error) { - var it model.VerifyMobileRequest - asMap := map[string]interface{}{} - for k, v := range obj.(map[string]interface{}) { - asMap[k] = v - } - - fieldsInOrder := [...]string{"phone_number", "code"} - for _, k := range fieldsInOrder { - v, ok := asMap[k] - if !ok { - continue - } - switch k { - case "phone_number": - var err error - - ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("phone_number")) - it.PhoneNumber, err = ec.unmarshalNString2string(ctx, v) - if err != nil { - return it, err - } - case "code": - var err error - - ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("code")) - it.Code, err = ec.unmarshalNString2string(ctx, v) - if err != nil { - return it, err - } - } - } - - return it, nil -} - func (ec *executionContext) unmarshalInputVerifyOTPRequest(ctx context.Context, obj interface{}) (model.VerifyOTPRequest, error) { var it model.VerifyOTPRequest asMap := map[string]interface{}{} @@ -17634,7 +17960,7 @@ func (ec *executionContext) unmarshalInputVerifyOTPRequest(ctx context.Context, asMap[k] = v } - fieldsInOrder := [...]string{"email", "otp", "state"} + fieldsInOrder := [...]string{"email", "phone_number", "otp", "state"} for _, k := range fieldsInOrder { v, ok := asMap[k] if !ok { @@ -17645,7 +17971,15 @@ func (ec *executionContext) unmarshalInputVerifyOTPRequest(ctx context.Context, var err error ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("email")) - it.Email, err = ec.unmarshalNString2string(ctx, v) + it.Email, err = ec.unmarshalOString2ᚖstring(ctx, v) + if err != nil { + return it, err + } + case "phone_number": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("phone_number")) + it.PhoneNumber, err = ec.unmarshalOString2ᚖstring(ctx, v) if err != nil { return it, err } @@ -17724,9 +18058,13 @@ func (ec *executionContext) _AuthResponse(ctx context.Context, sel ast.Selection if out.Values[i] == graphql.Null { invalids++ } - case "should_show_otp_screen": + case "should_show_email_otp_screen": - out.Values[i] = ec._AuthResponse_should_show_otp_screen(ctx, field, obj) + out.Values[i] = ec._AuthResponse_should_show_email_otp_screen(ctx, field, obj) + + case "should_show_mobile_otp_screen": + + out.Values[i] = ec._AuthResponse_should_show_mobile_otp_screen(ctx, field, obj) case "access_token": @@ -18533,15 +18871,6 @@ func (ec *executionContext) _Mutation(ctx context.Context, sel ast.SelectionSet) return ec._Mutation_resend_otp(ctx, field) }) - if out.Values[i] == graphql.Null { - invalids++ - } - case "verify_mobile": - - out.Values[i] = ec.OperationContext.RootResolverMiddleware(innerCtx, func(ctx context.Context) (res graphql.Marshaler) { - return ec._Mutation_verify_mobile(ctx, field) - }) - if out.Values[i] == graphql.Null { invalids++ } @@ -18866,6 +19195,29 @@ func (ec *executionContext) _Query(ctx context.Context, sel ast.SelectionSet) gr return ec.OperationContext.RootResolverMiddleware(ctx, innerFunc) } + out.Concurrently(i, func() graphql.Marshaler { + return rrm(innerCtx) + }) + case "validate_session": + field := field + + innerFunc := func(ctx context.Context) (res graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + } + }() + res = ec._Query_validate_session(ctx, field) + if res == graphql.Null { + atomic.AddUint32(&invalids, 1) + } + return res + } + + rrm := func(ctx context.Context) graphql.Marshaler { + return ec.OperationContext.RootResolverMiddleware(ctx, innerFunc) + } + out.Concurrently(i, func() graphql.Marshaler { return rrm(innerCtx) }) @@ -19317,6 +19669,10 @@ func (ec *executionContext) _User(ctx context.Context, sel ast.SelectionSet, obj out.Values[i] = ec._User_is_multi_factor_auth_enabled(ctx, field, obj) + case "app_data": + + out.Values[i] = ec._User_app_data(ctx, field, obj) + default: panic("unknown field " + strconv.Quote(field.Name)) } @@ -19395,6 +19751,41 @@ func (ec *executionContext) _ValidateJWTTokenResponse(ctx context.Context, sel a return out } +var validateSessionResponseImplementors = []string{"ValidateSessionResponse"} + +func (ec *executionContext) _ValidateSessionResponse(ctx context.Context, sel ast.SelectionSet, obj *model.ValidateSessionResponse) graphql.Marshaler { + fields := graphql.CollectFields(ec.OperationContext, sel, validateSessionResponseImplementors) + out := graphql.NewFieldSet(fields) + var invalids uint32 + for i, field := range fields { + switch field.Name { + case "__typename": + out.Values[i] = graphql.MarshalString("ValidateSessionResponse") + case "is_valid": + + out.Values[i] = ec._ValidateSessionResponse_is_valid(ctx, field, obj) + + if out.Values[i] == graphql.Null { + invalids++ + } + case "user": + + out.Values[i] = ec._ValidateSessionResponse_user(ctx, field, obj) + + if out.Values[i] == graphql.Null { + invalids++ + } + default: + panic("unknown field " + strconv.Quote(field.Name)) + } + } + out.Dispatch() + if invalids > 0 { + return graphql.Null + } + return out +} + var verificationRequestImplementors = []string{"VerificationRequest"} func (ec *executionContext) _VerificationRequest(ctx context.Context, sel ast.SelectionSet, obj *model.VerificationRequest) graphql.Marshaler { @@ -20470,6 +20861,20 @@ func (ec *executionContext) marshalNValidateJWTTokenResponse2ᚖgithubᚗcomᚋa return ec._ValidateJWTTokenResponse(ctx, sel, v) } +func (ec *executionContext) marshalNValidateSessionResponse2githubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐValidateSessionResponse(ctx context.Context, sel ast.SelectionSet, v model.ValidateSessionResponse) graphql.Marshaler { + return ec._ValidateSessionResponse(ctx, sel, &v) +} + +func (ec *executionContext) marshalNValidateSessionResponse2ᚖgithubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐValidateSessionResponse(ctx context.Context, sel ast.SelectionSet, v *model.ValidateSessionResponse) graphql.Marshaler { + if v == nil { + if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) { + ec.Errorf(ctx, "the requested element is null which the schema does not allow") + } + return graphql.Null + } + return ec._ValidateSessionResponse(ctx, sel, v) +} + func (ec *executionContext) marshalNVerificationRequest2ᚕᚖgithubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐVerificationRequestᚄ(ctx context.Context, sel ast.SelectionSet, v []*model.VerificationRequest) graphql.Marshaler { ret := make(graphql.Array, len(v)) var wg sync.WaitGroup @@ -20543,11 +20948,6 @@ func (ec *executionContext) unmarshalNVerifyEmailInput2githubᚗcomᚋauthorizer return res, graphql.ErrorOnPath(ctx, err) } -func (ec *executionContext) unmarshalNVerifyMobileRequest2githubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐVerifyMobileRequest(ctx context.Context, v interface{}) (model.VerifyMobileRequest, error) { - res, err := ec.unmarshalInputVerifyMobileRequest(ctx, v) - return res, graphql.ErrorOnPath(ctx, err) -} - func (ec *executionContext) unmarshalNVerifyOTPRequest2githubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐVerifyOTPRequest(ctx context.Context, v interface{}) (model.VerifyOTPRequest, error) { res, err := ec.unmarshalInputVerifyOTPRequest(ctx, v) return res, graphql.ErrorOnPath(ctx, err) @@ -21158,6 +21558,14 @@ func (ec *executionContext) marshalOUser2ᚖgithubᚗcomᚋauthorizerdevᚋautho return ec._User(ctx, sel, v) } +func (ec *executionContext) unmarshalOValidateSessionInput2ᚖgithubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐValidateSessionInput(ctx context.Context, v interface{}) (*model.ValidateSessionInput, error) { + if v == nil { + return nil, nil + } + res, err := ec.unmarshalInputValidateSessionInput(ctx, v) + return &res, graphql.ErrorOnPath(ctx, err) +} + func (ec *executionContext) marshalO__EnumValue2ᚕgithubᚗcomᚋ99designsᚋgqlgenᚋgraphqlᚋintrospectionᚐEnumValueᚄ(ctx context.Context, sel ast.SelectionSet, v []introspection.EnumValue) graphql.Marshaler { if v == nil { return graphql.Null diff --git a/server/graph/model/models_gen.go b/server/graph/model/models_gen.go index 7a1e376..a76b0e3 100644 --- a/server/graph/model/models_gen.go +++ b/server/graph/model/models_gen.go @@ -26,13 +26,14 @@ type AdminSignupInput struct { } type AuthResponse struct { - Message string `json:"message"` - ShouldShowOtpScreen *bool `json:"should_show_otp_screen"` - AccessToken *string `json:"access_token"` - IDToken *string `json:"id_token"` - RefreshToken *string `json:"refresh_token"` - ExpiresIn *int64 `json:"expires_in"` - User *User `json:"user"` + Message string `json:"message"` + ShouldShowEmailOtpScreen *bool `json:"should_show_email_otp_screen"` + ShouldShowMobileOtpScreen *bool `json:"should_show_mobile_otp_screen"` + AccessToken *string `json:"access_token"` + IDToken *string `json:"id_token"` + RefreshToken *string `json:"refresh_token"` + ExpiresIn *int64 `json:"expires_in"` + User *User `json:"user"` } type DeleteEmailTemplateRequest struct { @@ -120,7 +121,6 @@ type Env struct { AdminCookieSecure bool `json:"ADMIN_COOKIE_SECURE"` DefaultAuthorizeResponseType *string `json:"DEFAULT_AUTHORIZE_RESPONSE_TYPE"` DefaultAuthorizeResponseMode *string `json:"DEFAULT_AUTHORIZE_RESPONSE_MODE"` - SmsCodeExpiryTime *string `json:"SMS_CODE_EXPIRY_TIME"` } type Error struct { @@ -207,22 +207,23 @@ type MobileLoginInput struct { } type MobileSignUpInput struct { - Email *string `json:"email"` - GivenName *string `json:"given_name"` - FamilyName *string `json:"family_name"` - MiddleName *string `json:"middle_name"` - Nickname *string `json:"nickname"` - Gender *string `json:"gender"` - Birthdate *string `json:"birthdate"` - PhoneNumber string `json:"phone_number"` - Picture *string `json:"picture"` - Password string `json:"password"` - ConfirmPassword string `json:"confirm_password"` - Roles []string `json:"roles"` - Scope []string `json:"scope"` - RedirectURI *string `json:"redirect_uri"` - IsMultiFactorAuthEnabled *bool `json:"is_multi_factor_auth_enabled"` - State *string `json:"state"` + Email *string `json:"email"` + GivenName *string `json:"given_name"` + FamilyName *string `json:"family_name"` + MiddleName *string `json:"middle_name"` + Nickname *string `json:"nickname"` + Gender *string `json:"gender"` + Birthdate *string `json:"birthdate"` + PhoneNumber string `json:"phone_number"` + Picture *string `json:"picture"` + Password string `json:"password"` + ConfirmPassword string `json:"confirm_password"` + Roles []string `json:"roles"` + Scope []string `json:"scope"` + RedirectURI *string `json:"redirect_uri"` + IsMultiFactorAuthEnabled *bool `json:"is_multi_factor_auth_enabled"` + State *string `json:"state"` + AppData map[string]interface{} `json:"app_data"` } type OAuthRevokeInput struct { @@ -246,8 +247,9 @@ type PaginationInput struct { } type ResendOTPRequest struct { - Email string `json:"email"` - State *string `json:"state"` + Email *string `json:"email"` + PhoneNumber *string `json:"phone_number"` + State *string `json:"state"` } type ResendVerifyEmailInput struct { @@ -281,28 +283,30 @@ type SessionQueryInput struct { } type SignUpInput struct { - Email string `json:"email"` - GivenName *string `json:"given_name"` - FamilyName *string `json:"family_name"` - MiddleName *string `json:"middle_name"` - Nickname *string `json:"nickname"` - Gender *string `json:"gender"` - Birthdate *string `json:"birthdate"` - PhoneNumber *string `json:"phone_number"` - Picture *string `json:"picture"` - Password string `json:"password"` - ConfirmPassword string `json:"confirm_password"` - Roles []string `json:"roles"` - Scope []string `json:"scope"` - RedirectURI *string `json:"redirect_uri"` - IsMultiFactorAuthEnabled *bool `json:"is_multi_factor_auth_enabled"` - State *string `json:"state"` + Email string `json:"email"` + GivenName *string `json:"given_name"` + FamilyName *string `json:"family_name"` + MiddleName *string `json:"middle_name"` + Nickname *string `json:"nickname"` + Gender *string `json:"gender"` + Birthdate *string `json:"birthdate"` + PhoneNumber *string `json:"phone_number"` + Picture *string `json:"picture"` + Password string `json:"password"` + ConfirmPassword string `json:"confirm_password"` + Roles []string `json:"roles"` + Scope []string `json:"scope"` + RedirectURI *string `json:"redirect_uri"` + IsMultiFactorAuthEnabled *bool `json:"is_multi_factor_auth_enabled"` + State *string `json:"state"` + AppData map[string]interface{} `json:"app_data"` } type TestEndpointRequest struct { - Endpoint string `json:"endpoint"` - EventName string `json:"event_name"` - Headers map[string]interface{} `json:"headers"` + Endpoint string `json:"endpoint"` + EventName string `json:"event_name"` + EventDescription *string `json:"event_description"` + Headers map[string]interface{} `json:"headers"` } type TestEndpointResponse struct { @@ -378,35 +382,37 @@ type UpdateEnvInput struct { } type UpdateProfileInput struct { - OldPassword *string `json:"old_password"` - NewPassword *string `json:"new_password"` - ConfirmNewPassword *string `json:"confirm_new_password"` - Email *string `json:"email"` - GivenName *string `json:"given_name"` - FamilyName *string `json:"family_name"` - MiddleName *string `json:"middle_name"` - Nickname *string `json:"nickname"` - Gender *string `json:"gender"` - Birthdate *string `json:"birthdate"` - PhoneNumber *string `json:"phone_number"` - Picture *string `json:"picture"` - IsMultiFactorAuthEnabled *bool `json:"is_multi_factor_auth_enabled"` + OldPassword *string `json:"old_password"` + NewPassword *string `json:"new_password"` + ConfirmNewPassword *string `json:"confirm_new_password"` + Email *string `json:"email"` + GivenName *string `json:"given_name"` + FamilyName *string `json:"family_name"` + MiddleName *string `json:"middle_name"` + Nickname *string `json:"nickname"` + Gender *string `json:"gender"` + Birthdate *string `json:"birthdate"` + PhoneNumber *string `json:"phone_number"` + Picture *string `json:"picture"` + IsMultiFactorAuthEnabled *bool `json:"is_multi_factor_auth_enabled"` + AppData map[string]interface{} `json:"app_data"` } type UpdateUserInput struct { - ID string `json:"id"` - Email *string `json:"email"` - EmailVerified *bool `json:"email_verified"` - GivenName *string `json:"given_name"` - FamilyName *string `json:"family_name"` - MiddleName *string `json:"middle_name"` - Nickname *string `json:"nickname"` - Gender *string `json:"gender"` - Birthdate *string `json:"birthdate"` - PhoneNumber *string `json:"phone_number"` - Picture *string `json:"picture"` - Roles []*string `json:"roles"` - IsMultiFactorAuthEnabled *bool `json:"is_multi_factor_auth_enabled"` + ID string `json:"id"` + Email *string `json:"email"` + EmailVerified *bool `json:"email_verified"` + GivenName *string `json:"given_name"` + FamilyName *string `json:"family_name"` + MiddleName *string `json:"middle_name"` + Nickname *string `json:"nickname"` + Gender *string `json:"gender"` + Birthdate *string `json:"birthdate"` + PhoneNumber *string `json:"phone_number"` + Picture *string `json:"picture"` + Roles []*string `json:"roles"` + IsMultiFactorAuthEnabled *bool `json:"is_multi_factor_auth_enabled"` + AppData map[string]interface{} `json:"app_data"` } type UpdateWebhookRequest struct { @@ -419,25 +425,26 @@ type UpdateWebhookRequest struct { } type User struct { - ID string `json:"id"` - Email string `json:"email"` - EmailVerified bool `json:"email_verified"` - SignupMethods string `json:"signup_methods"` - GivenName *string `json:"given_name"` - FamilyName *string `json:"family_name"` - MiddleName *string `json:"middle_name"` - Nickname *string `json:"nickname"` - PreferredUsername *string `json:"preferred_username"` - Gender *string `json:"gender"` - Birthdate *string `json:"birthdate"` - PhoneNumber *string `json:"phone_number"` - PhoneNumberVerified *bool `json:"phone_number_verified"` - Picture *string `json:"picture"` - Roles []string `json:"roles"` - CreatedAt *int64 `json:"created_at"` - UpdatedAt *int64 `json:"updated_at"` - RevokedTimestamp *int64 `json:"revoked_timestamp"` - IsMultiFactorAuthEnabled *bool `json:"is_multi_factor_auth_enabled"` + ID string `json:"id"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + SignupMethods string `json:"signup_methods"` + GivenName *string `json:"given_name"` + FamilyName *string `json:"family_name"` + MiddleName *string `json:"middle_name"` + Nickname *string `json:"nickname"` + PreferredUsername *string `json:"preferred_username"` + Gender *string `json:"gender"` + Birthdate *string `json:"birthdate"` + PhoneNumber *string `json:"phone_number"` + PhoneNumberVerified *bool `json:"phone_number_verified"` + Picture *string `json:"picture"` + Roles []string `json:"roles"` + CreatedAt *int64 `json:"created_at"` + UpdatedAt *int64 `json:"updated_at"` + RevokedTimestamp *int64 `json:"revoked_timestamp"` + IsMultiFactorAuthEnabled *bool `json:"is_multi_factor_auth_enabled"` + AppData map[string]interface{} `json:"app_data"` } type Users struct { @@ -456,6 +463,16 @@ type ValidateJWTTokenResponse struct { Claims map[string]interface{} `json:"claims"` } +type ValidateSessionInput struct { + Cookie string `json:"cookie"` + Roles []string `json:"roles"` +} + +type ValidateSessionResponse struct { + IsValid bool `json:"is_valid"` + User *User `json:"user"` +} + type VerificationRequest struct { ID string `json:"id"` Identifier *string `json:"identifier"` @@ -478,15 +495,11 @@ type VerifyEmailInput struct { State *string `json:"state"` } -type VerifyMobileRequest struct { - PhoneNumber string `json:"phone_number"` - Code string `json:"code"` -} - type VerifyOTPRequest struct { - Email string `json:"email"` - Otp string `json:"otp"` - State *string `json:"state"` + Email *string `json:"email"` + PhoneNumber *string `json:"phone_number"` + Otp string `json:"otp"` + State *string `json:"state"` } type Webhook struct { diff --git a/server/graph/schema.graphqls b/server/graph/schema.graphqls index 4013b12..9125acd 100644 --- a/server/graph/schema.graphqls +++ b/server/graph/schema.graphqls @@ -51,6 +51,7 @@ type User { updated_at: Int64 revoked_timestamp: Int64 is_multi_factor_auth_enabled: Boolean + app_data: Map } type Users { @@ -84,11 +85,6 @@ type SMSVerificationRequests { updated_at: Int64 } -input VerifyMobileRequest { - phone_number: String! - code: String! -} - type Error { message: String! reason: String! @@ -96,7 +92,8 @@ type Error { type AuthResponse { message: String! - should_show_otp_screen: Boolean + should_show_email_otp_screen: Boolean + should_show_mobile_otp_screen: Boolean access_token: String id_token: String refresh_token: String @@ -182,6 +179,11 @@ type ValidateJWTTokenResponse { claims: Map } +type ValidateSessionResponse { + is_valid: Boolean! + user: User! +} + type GenerateJWTKeysResponse { secret: String public_key: String @@ -322,6 +324,7 @@ input MobileSignUpInput { # it is used to get code for an on-going auth process during login # and use that code for setting `c_hash` in id_token state: String + app_data: Map } input SignUpInput { @@ -344,6 +347,7 @@ input SignUpInput { # it is used to get code for an on-going auth process during login # and use that code for setting `c_hash` in id_token state: String + app_data: Map } input LoginInput { @@ -399,6 +403,7 @@ input UpdateProfileInput { phone_number: String picture: String is_multi_factor_auth_enabled: Boolean + app_data: Map } input UpdateUserInput { @@ -415,6 +420,7 @@ input UpdateUserInput { picture: String roles: [String] is_multi_factor_auth_enabled: Boolean + app_data: Map } input ForgotPasswordInput { @@ -474,6 +480,11 @@ input ValidateJWTTokenInput { roles: [String!] } +input ValidateSessionInput { + cookie: String! + roles: [String!] +} + input GenerateJWTKeysInput { type: String! } @@ -507,6 +518,7 @@ input WebhookRequest { input TestEndpointRequest { endpoint: String! event_name: String! + event_description: String headers: Map } @@ -534,7 +546,9 @@ input DeleteEmailTemplateRequest { } input VerifyOTPRequest { - email: String! + # either email or phone_number is required + email: String + phone_number: String otp: String! # state is used for authorization code grant flow # it is used to get code for an on-going auth process during login @@ -543,7 +557,8 @@ input VerifyOTPRequest { } input ResendOTPRequest { - email: String! + email: String + phone_number: String # state is used for authorization code grant flow # it is used to get code for an on-going auth process during login # and use that code for setting `c_hash` in id_token @@ -570,7 +585,6 @@ type Mutation { revoke(params: OAuthRevokeInput!): Response! verify_otp(params: VerifyOTPRequest!): AuthResponse! resend_otp(params: ResendOTPRequest!): Response! - verify_mobile(params: VerifyMobileRequest!): AuthResponse! # admin only apis _delete_user(params: DeleteUserInput!): Response! _update_user(params: UpdateUserInput!): User! @@ -596,6 +610,7 @@ type Query { session(params: SessionQueryInput): AuthResponse! profile: User! validate_jwt_token(params: ValidateJWTTokenInput!): ValidateJWTTokenResponse! + validate_session(params: ValidateSessionInput): ValidateSessionResponse! # admin only apis _users(params: PaginatedInput): Users! _user(params: GetUserRequest!): User! diff --git a/server/graph/schema.resolvers.go b/server/graph/schema.resolvers.go index 75dad85..eecb6b2 100644 --- a/server/graph/schema.resolvers.go +++ b/server/graph/schema.resolvers.go @@ -81,11 +81,6 @@ func (r *mutationResolver) ResendOtp(ctx context.Context, params model.ResendOTP return resolvers.ResendOTPResolver(ctx, params) } -// VerifyMobile is the resolver for the verify_mobile field. -func (r *mutationResolver) VerifyMobile(ctx context.Context, params model.VerifyMobileRequest) (*model.AuthResponse, error) { - return resolvers.VerifyMobileResolver(ctx, params) -} - // DeleteUser is the resolver for the _delete_user field. func (r *mutationResolver) DeleteUser(ctx context.Context, params model.DeleteUserInput) (*model.Response, error) { return resolvers.DeleteUserResolver(ctx, params) @@ -191,6 +186,11 @@ func (r *queryResolver) ValidateJwtToken(ctx context.Context, params model.Valid return resolvers.ValidateJwtTokenResolver(ctx, params) } +// ValidateSession is the resolver for the validate_session field. +func (r *queryResolver) ValidateSession(ctx context.Context, params *model.ValidateSessionInput) (*model.ValidateSessionResponse, error) { + return resolvers.ValidateSessionResolver(ctx, params) +} + // Users is the resolver for the _users field. func (r *queryResolver) Users(ctx context.Context, params *model.PaginatedInput) (*model.Users, error) { return resolvers.UsersResolver(ctx, params) diff --git a/server/handlers/oauth_callback.go b/server/handlers/oauth_callback.go index 1c547e2..782da41 100644 --- a/server/handlers/oauth_callback.go +++ b/server/handlers/oauth_callback.go @@ -32,11 +32,11 @@ func OAuthCallbackHandler() gin.HandlerFunc { return func(ctx *gin.Context) { provider := ctx.Param("oauth_provider") state := ctx.Request.FormValue("state") - sessionState, err := memorystore.Provider.GetState(state) if sessionState == "" || err != nil { log.Debug("Invalid oauth state: ", state) ctx.JSON(400, gin.H{"error": "invalid oauth state"}) + return } // contains random token, redirect url, role sessionSplit := strings.Split(state, "___") @@ -46,32 +46,34 @@ func OAuthCallbackHandler() gin.HandlerFunc { ctx.JSON(400, gin.H{"error": "invalid redirect url"}) return } - // remove state from store go memorystore.Provider.RemoveState(state) - stateValue := sessionSplit[0] redirectURL := sessionSplit[1] inputRoles := strings.Split(sessionSplit[2], ",") scopes := strings.Split(sessionSplit[3], ",") - - user := models.User{} + var user *models.User oauthCode := ctx.Request.FormValue("code") + if oauthCode == "" { + log.Debug("Invalid oauth code: ", oauthCode) + ctx.JSON(400, gin.H{"error": "invalid oauth code"}) + return + } switch provider { case constants.AuthRecipeMethodGoogle: - user, err = processGoogleUserInfo(oauthCode) + user, err = processGoogleUserInfo(ctx, oauthCode) case constants.AuthRecipeMethodGithub: - user, err = processGithubUserInfo(oauthCode) + user, err = processGithubUserInfo(ctx, oauthCode) case constants.AuthRecipeMethodFacebook: - user, err = processFacebookUserInfo(oauthCode) + user, err = processFacebookUserInfo(ctx, oauthCode) case constants.AuthRecipeMethodLinkedIn: - user, err = processLinkedInUserInfo(oauthCode) + user, err = processLinkedInUserInfo(ctx, oauthCode) case constants.AuthRecipeMethodApple: - user, err = processAppleUserInfo(oauthCode) + user, err = processAppleUserInfo(ctx, oauthCode) case constants.AuthRecipeMethodTwitter: - user, err = processTwitterUserInfo(oauthCode, sessionState) + user, err = processTwitterUserInfo(ctx, oauthCode, sessionState) case constants.AuthRecipeMethodMicrosoft: - user, err = processMicrosoftUserInfo(oauthCode) + user, err = processMicrosoftUserInfo(ctx, oauthCode) default: log.Info("Invalid oauth provider") err = fmt.Errorf(`invalid oauth provider`) @@ -260,10 +262,12 @@ func OAuthCallbackHandler() gin.HandlerFunc { go func() { if isSignUp { utils.RegisterEvent(ctx, constants.UserSignUpWebhookEvent, provider, user) + // User is also logged in with signup + utils.RegisterEvent(ctx, constants.UserLoginWebhookEvent, provider, user) } else { utils.RegisterEvent(ctx, constants.UserLoginWebhookEvent, provider, user) } - db.Provider.AddSession(ctx, models.Session{ + db.Provider.AddSession(ctx, &models.Session{ UserID: user.ID, UserAgent: utils.GetUserAgent(ctx.Request), IP: utils.GetIP(ctx.Request), @@ -279,15 +283,13 @@ func OAuthCallbackHandler() gin.HandlerFunc { } } -func processGoogleUserInfo(code string) (models.User, error) { - user := models.User{} - ctx := context.Background() +func processGoogleUserInfo(ctx context.Context, code string) (*models.User, error) { + var user *models.User oauth2Token, err := oauth.OAuthProviders.GoogleConfig.Exchange(ctx, code) if err != nil { log.Debug("Failed to exchange code for token: ", err) return user, fmt.Errorf("invalid google exchange code: %s", err.Error()) } - verifier := oauth.OIDCProviders.GoogleOIDC.Verifier(&oidc.Config{ClientID: oauth.OAuthProviders.GoogleConfig.ClientID}) // Extract the ID Token from OAuth2 token. @@ -312,9 +314,9 @@ func processGoogleUserInfo(code string) (models.User, error) { return user, nil } -func processGithubUserInfo(code string) (models.User, error) { - user := models.User{} - oauth2Token, err := oauth.OAuthProviders.GithubConfig.Exchange(context.TODO(), code) +func processGithubUserInfo(ctx context.Context, code string) (*models.User, error) { + var user *models.User + oauth2Token, err := oauth.OAuthProviders.GithubConfig.Exchange(ctx, code) if err != nil { log.Debug("Failed to exchange code for token: ", err) return user, fmt.Errorf("invalid github exchange code: %s", err.Error()) @@ -409,7 +411,7 @@ func processGithubUserInfo(code string) (models.User, error) { } } - user = models.User{ + user = &models.User{ GivenName: &firstName, FamilyName: &lastName, Picture: &picture, @@ -419,9 +421,9 @@ func processGithubUserInfo(code string) (models.User, error) { return user, nil } -func processFacebookUserInfo(code string) (models.User, error) { - user := models.User{} - oauth2Token, err := oauth.OAuthProviders.FacebookConfig.Exchange(context.TODO(), code) +func processFacebookUserInfo(ctx context.Context, code string) (*models.User, error) { + var user *models.User + oauth2Token, err := oauth.OAuthProviders.FacebookConfig.Exchange(ctx, code) if err != nil { log.Debug("Invalid facebook exchange code: ", err) return user, fmt.Errorf("invalid facebook exchange code: %s", err.Error()) @@ -460,7 +462,7 @@ func processFacebookUserInfo(code string) (models.User, error) { lastName := fmt.Sprintf("%v", userRawData["last_name"]) picture := fmt.Sprintf("%v", picDataObject["url"]) - user = models.User{ + user = &models.User{ GivenName: &firstName, FamilyName: &lastName, Picture: &picture, @@ -470,9 +472,9 @@ func processFacebookUserInfo(code string) (models.User, error) { return user, nil } -func processLinkedInUserInfo(code string) (models.User, error) { - user := models.User{} - oauth2Token, err := oauth.OAuthProviders.LinkedInConfig.Exchange(context.TODO(), code) +func processLinkedInUserInfo(ctx context.Context, code string) (*models.User, error) { + var user *models.User + oauth2Token, err := oauth.OAuthProviders.LinkedInConfig.Exchange(ctx, code) if err != nil { log.Debug("Failed to exchange code for token: ", err) return user, fmt.Errorf("invalid linkedin exchange code: %s", err.Error()) @@ -542,7 +544,7 @@ func processLinkedInUserInfo(code string) (models.User, error) { profilePicture := userRawData["profilePicture"].(map[string]interface{})["displayImage~"].(map[string]interface{})["elements"].([]interface{})[0].(map[string]interface{})["identifiers"].([]interface{})[0].(map[string]interface{})["identifier"].(string) emailAddress := emailRawData["elements"].([]interface{})[0].(map[string]interface{})["handle~"].(map[string]interface{})["emailAddress"].(string) - user = models.User{ + user = &models.User{ GivenName: &firstName, FamilyName: &lastName, Picture: &profilePicture, @@ -552,9 +554,9 @@ func processLinkedInUserInfo(code string) (models.User, error) { return user, nil } -func processAppleUserInfo(code string) (models.User, error) { - user := models.User{} - oauth2Token, err := oauth.OAuthProviders.AppleConfig.Exchange(context.TODO(), code) +func processAppleUserInfo(ctx context.Context, code string) (*models.User, error) { + var user *models.User + oauth2Token, err := oauth.OAuthProviders.AppleConfig.Exchange(ctx, code) if err != nil { log.Debug("Failed to exchange code for token: ", err) return user, fmt.Errorf("invalid apple exchange code: %s", err.Error()) @@ -605,9 +607,9 @@ func processAppleUserInfo(code string) (models.User, error) { return user, err } -func processTwitterUserInfo(code, verifier string) (models.User, error) { - user := models.User{} - oauth2Token, err := oauth.OAuthProviders.TwitterConfig.Exchange(context.TODO(), code, oauth2.SetAuthURLParam("code_verifier", verifier)) +func processTwitterUserInfo(ctx context.Context, code, verifier string) (*models.User, error) { + var user *models.User + oauth2Token, err := oauth.OAuthProviders.TwitterConfig.Exchange(ctx, code, oauth2.SetAuthURLParam("code_verifier", verifier)) if err != nil { log.Debug("Failed to exchange code for token: ", err) return user, fmt.Errorf("invalid twitter exchange code: %s", err.Error()) @@ -662,7 +664,7 @@ func processTwitterUserInfo(code, verifier string) (models.User, error) { nickname := userRawData["username"].(string) profilePicture := userRawData["profile_image_url"].(string) - user = models.User{ + user = &models.User{ GivenName: &firstName, FamilyName: &lastName, Picture: &profilePicture, @@ -673,24 +675,24 @@ func processTwitterUserInfo(code, verifier string) (models.User, error) { } // process microsoft user information -func processMicrosoftUserInfo(code string) (models.User, error) { - user := models.User{} - ctx := context.Background() +func processMicrosoftUserInfo(ctx context.Context, code string) (*models.User, error) { + var user *models.User oauth2Token, err := oauth.OAuthProviders.MicrosoftConfig.Exchange(ctx, code) if err != nil { log.Debug("Failed to exchange code for token: ", err) - return user, fmt.Errorf("invalid google exchange code: %s", err.Error()) + return user, fmt.Errorf("invalid microsoft exchange code: %s", err.Error()) } - - verifier := oauth.OIDCProviders.MicrosoftOIDC.Verifier(&oidc.Config{ClientID: oauth.OAuthProviders.MicrosoftConfig.ClientID}) - + // we need to skip issuer check because for common tenant it will return internal issuer which does not match + verifier := oauth.OIDCProviders.MicrosoftOIDC.Verifier(&oidc.Config{ + ClientID: oauth.OAuthProviders.MicrosoftConfig.ClientID, + SkipIssuerCheck: true, + }) // Extract the ID Token from OAuth2 token. rawIDToken, ok := oauth2Token.Extra("id_token").(string) if !ok { log.Debug("Failed to extract ID Token from OAuth2 token") return user, fmt.Errorf("unable to extract id_token") } - // Parse and verify ID Token payload. idToken, err := verifier.Verify(ctx, rawIDToken) if err != nil { diff --git a/server/handlers/verify_email.go b/server/handlers/verify_email.go index 452820f..47f61d3 100644 --- a/server/handlers/verify_email.go +++ b/server/handlers/verify_email.go @@ -175,11 +175,12 @@ func VerifyEmailHandler() gin.HandlerFunc { go func() { if isSignUp { utils.RegisterEvent(c, constants.UserSignUpWebhookEvent, loginMethod, user) + // User is also logged in with signup + utils.RegisterEvent(c, constants.UserLoginWebhookEvent, loginMethod, user) } else { utils.RegisterEvent(c, constants.UserLoginWebhookEvent, loginMethod, user) } - - db.Provider.AddSession(c, models.Session{ + db.Provider.AddSession(c, &models.Session{ UserID: user.ID, UserAgent: utils.GetUserAgent(c.Request), IP: utils.GetIP(c.Request), diff --git a/server/memorystore/memory_store.go b/server/memorystore/memory_store.go index 4285860..2b3c9b2 100644 --- a/server/memorystore/memory_store.go +++ b/server/memorystore/memory_store.go @@ -33,6 +33,7 @@ func InitMemStore() error { constants.EnvKeyDisableSignUp: false, constants.EnvKeyDisableStrongPassword: false, constants.EnvKeyIsEmailServiceEnabled: false, + constants.EnvKeyIsSMSServiceEnabled: false, constants.EnvKeyEnforceMultiFactorAuthentication: false, constants.EnvKeyDisableMultiFactorAuthentication: false, constants.EnvKeyAppCookieSecure: true, diff --git a/server/memorystore/providers/inmemory/provider.go b/server/memorystore/providers/inmemory/provider.go index 952092d..e726502 100644 --- a/server/memorystore/providers/inmemory/provider.go +++ b/server/memorystore/providers/inmemory/provider.go @@ -7,18 +7,20 @@ import ( ) type provider struct { - mutex sync.Mutex - sessionStore *stores.SessionStore - stateStore *stores.StateStore - envStore *stores.EnvStore + mutex sync.Mutex + sessionStore *stores.SessionStore + mfasessionStore *stores.SessionStore + stateStore *stores.StateStore + envStore *stores.EnvStore } // NewInMemoryStore returns a new in-memory store. func NewInMemoryProvider() (*provider, error) { return &provider{ - mutex: sync.Mutex{}, - envStore: stores.NewEnvStore(), - sessionStore: stores.NewSessionStore(), - stateStore: stores.NewStateStore(), + mutex: sync.Mutex{}, + envStore: stores.NewEnvStore(), + sessionStore: stores.NewSessionStore(), + mfasessionStore: stores.NewSessionStore(), + stateStore: stores.NewStateStore(), }, nil } diff --git a/server/memorystore/providers/inmemory/store.go b/server/memorystore/providers/inmemory/store.go index 4a8e8ce..b20fb62 100644 --- a/server/memorystore/providers/inmemory/store.go +++ b/server/memorystore/providers/inmemory/store.go @@ -42,6 +42,27 @@ func (c *provider) DeleteSessionForNamespace(namespace string) error { return nil } +// SetMfaSession sets the mfa session with key and value of userId +func (c *provider) SetMfaSession(userId, key string, expiration int64) error { + c.mfasessionStore.Set(userId, key, userId, expiration) + return nil +} + +// GetMfaSession returns value of given mfa session +func (c *provider) GetMfaSession(userId, key string) (string, error) { + val := c.mfasessionStore.Get(userId, key) + if val == "" { + return "", fmt.Errorf("Not found") + } + return val, nil +} + +// DeleteMfaSession deletes given mfa session from in-memory store. +func (c *provider) DeleteMfaSession(userId, key string) error { + c.mfasessionStore.Remove(userId, key) + return nil +} + // SetState sets the state in the in-memory store. func (c *provider) SetState(key, state string) error { if os.Getenv("ENV") != constants.TestEnv { diff --git a/server/memorystore/providers/provider_tests.go b/server/memorystore/providers/provider_tests.go index e569fe8..47f4dba 100644 --- a/server/memorystore/providers/provider_tests.go +++ b/server/memorystore/providers/provider_tests.go @@ -112,4 +112,15 @@ func ProviderTests(t *testing.T, p Provider) { key, err = p.GetUserSession("auth_provider1:124", "access_token_key") assert.Empty(t, key) assert.Error(t, err) + + err = p.SetMfaSession("auth_provider:123", "session123", time.Now().Add(60*time.Second).Unix()) + assert.NoError(t, err) + key, err = p.GetMfaSession("auth_provider:123", "session123") + assert.NoError(t, err) + assert.Equal(t, "auth_provider:123", key) + err = p.DeleteMfaSession("auth_provider:123", "session123") + assert.NoError(t, err) + key, err = p.GetMfaSession("auth_provider:123", "session123") + assert.Error(t, err) + assert.Empty(t, key) } diff --git a/server/memorystore/providers/providers.go b/server/memorystore/providers/providers.go index db58aa7..331e34a 100644 --- a/server/memorystore/providers/providers.go +++ b/server/memorystore/providers/providers.go @@ -12,6 +12,12 @@ type Provider interface { DeleteAllUserSessions(userId string) error // DeleteSessionForNamespace deletes the session for a given namespace DeleteSessionForNamespace(namespace string) error + // SetMfaSession sets the mfa session with key and value of userId + SetMfaSession(userId, key string, expiration int64) error + // GetMfaSession returns value of given mfa session + GetMfaSession(userId, key string) (string, error) + // DeleteMfaSession deletes given mfa session from in-memory store. + DeleteMfaSession(userId, key string) error // SetState sets the login state (key, value form) in the session store SetState(key, state string) error diff --git a/server/memorystore/providers/redis/provider.go b/server/memorystore/providers/redis/provider.go index 894a75e..17fb475 100644 --- a/server/memorystore/providers/redis/provider.go +++ b/server/memorystore/providers/redis/provider.go @@ -9,6 +9,10 @@ import ( log "github.com/sirupsen/logrus" ) +const ( + dialTimeout = 60 * time.Second +) + // RedisClient is the interface for redis client & redis cluster client type RedisClient interface { HMSet(ctx context.Context, key string, values ...interface{}) *redis.BoolCmd @@ -41,8 +45,7 @@ func NewRedisProvider(redisURL string) (*provider, error) { urls := []string{opt.Addr} urlList := redisURLHostPortsList[1:] urls = append(urls, urlList...) - clusterOpt := &redis.ClusterOptions{Addrs: urls} - + clusterOpt := &redis.ClusterOptions{Addrs: urls, DialTimeout: dialTimeout} rdb := redis.NewClusterClient(clusterOpt) ctx := context.Background() _, err = rdb.Ping(ctx).Result() @@ -62,7 +65,7 @@ func NewRedisProvider(redisURL string) (*provider, error) { log.Debug("error parsing redis url: ", err) return nil, err } - + opt.DialTimeout = dialTimeout rdb := redis.NewClient(opt) ctx := context.Background() _, err = rdb.Ping(ctx).Result() diff --git a/server/memorystore/providers/redis/store.go b/server/memorystore/providers/redis/store.go index 058e95e..a6ff08f 100644 --- a/server/memorystore/providers/redis/store.go +++ b/server/memorystore/providers/redis/store.go @@ -16,6 +16,8 @@ var ( envStorePrefix = "authorizer_env" ) +const mfaSessionPrefix = "mfa_sess_" + // SetUserSession sets the user session for given user identifier in form recipe:user_id func (c *provider) SetUserSession(userId, key, token string, expiration int64) error { currentTime := time.Now() @@ -91,6 +93,37 @@ func (c *provider) DeleteSessionForNamespace(namespace string) error { return nil } +// SetMfaSession sets the mfa session with key and value of userId +func (c *provider) SetMfaSession(userId, key string, expiration int64) error { + currentTime := time.Now() + expireTime := time.Unix(expiration, 0) + duration := expireTime.Sub(currentTime) + err := c.store.Set(c.ctx, fmt.Sprintf("%s%s:%s", mfaSessionPrefix, userId, key), userId, duration).Err() + if err != nil { + log.Debug("Error saving user session to redis: ", err) + return err + } + return nil +} + +// GetMfaSession returns value of given mfa session +func (c *provider) GetMfaSession(userId, key string) (string, error) { + data, err := c.store.Get(c.ctx, fmt.Sprintf("%s%s:%s", mfaSessionPrefix, userId, key)).Result() + if err != nil { + return "", err + } + return data, nil +} + +// DeleteMfaSession deletes given mfa session from in-memory store. +func (c *provider) DeleteMfaSession(userId, key string) error { + if err := c.store.Del(c.ctx, fmt.Sprintf("%s%s:%s", mfaSessionPrefix, userId, key)).Err(); err != nil { + log.Debug("Error deleting user session from redis: ", err) + // continue + } + return nil +} + // SetState sets the state in redis store. func (c *provider) SetState(key, value string) error { err := c.store.Set(c.ctx, stateStorePrefix+key, value, 0).Err() @@ -143,7 +176,7 @@ func (c *provider) GetEnvStore() (map[string]interface{}, error) { return nil, err } for key, value := range data { - if key == constants.EnvKeyDisableBasicAuthentication || key == constants.EnvKeyDisableMobileBasicAuthentication || key == constants.EnvKeyDisableEmailVerification || key == constants.EnvKeyDisableLoginPage || key == constants.EnvKeyDisableMagicLinkLogin || key == constants.EnvKeyDisableRedisForEnv || key == constants.EnvKeyDisableSignUp || key == constants.EnvKeyDisableStrongPassword || key == constants.EnvKeyIsEmailServiceEnabled || key == constants.EnvKeyEnforceMultiFactorAuthentication || key == constants.EnvKeyDisableMultiFactorAuthentication || key == constants.EnvKeyAppCookieSecure || key == constants.EnvKeyAdminCookieSecure { + if key == constants.EnvKeyDisableBasicAuthentication || key == constants.EnvKeyDisableMobileBasicAuthentication || key == constants.EnvKeyDisableEmailVerification || key == constants.EnvKeyDisableLoginPage || key == constants.EnvKeyDisableMagicLinkLogin || key == constants.EnvKeyDisableRedisForEnv || key == constants.EnvKeyDisableSignUp || key == constants.EnvKeyDisableStrongPassword || key == constants.EnvKeyIsEmailServiceEnabled || key == constants.EnvKeyIsSMSServiceEnabled || key == constants.EnvKeyEnforceMultiFactorAuthentication || key == constants.EnvKeyDisableMultiFactorAuthentication || key == constants.EnvKeyAppCookieSecure || key == constants.EnvKeyAdminCookieSecure { boolValue, err := strconv.ParseBool(value) if err != nil { return res, err diff --git a/server/oauth/oauth.go b/server/oauth/oauth.go index 7841909..3f02916 100644 --- a/server/oauth/oauth.go +++ b/server/oauth/oauth.go @@ -10,11 +10,16 @@ import ( githubOAuth2 "golang.org/x/oauth2/github" linkedInOAuth2 "golang.org/x/oauth2/linkedin" microsoftOAuth2 "golang.org/x/oauth2/microsoft" + "google.golang.org/appengine/log" "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/memorystore" ) +const ( + microsoftCommonTenant = "common" +) + // OAuthProviders is a struct that contains reference all the OAuth providers type OAuthProvider struct { GoogleConfig *oauth2.Config @@ -171,12 +176,16 @@ func InitOAuth() error { microsoftClientSecret = "" } microsoftActiveDirTenantID, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyMicrosoftActiveDirectoryTenantID) - if err != nil { - microsoftActiveDirTenantID = "" + if err != nil || microsoftActiveDirTenantID == "" { + microsoftActiveDirTenantID = microsoftCommonTenant } - if microsoftClientID != "" && microsoftClientSecret != "" && microsoftActiveDirTenantID != "" { + if microsoftClientID != "" && microsoftClientSecret != "" { + if microsoftActiveDirTenantID == microsoftCommonTenant { + ctx = oidc.InsecureIssuerURLContext(ctx, fmt.Sprintf("https://login.microsoftonline.com/%s/v2.0", microsoftActiveDirTenantID)) + } p, err := oidc.NewProvider(ctx, fmt.Sprintf("https://login.microsoftonline.com/%s/v2.0", microsoftActiveDirTenantID)) if err != nil { + log.Debugf(ctx, "Error while creating OIDC provider for Microsoft: %v", err) return err } OIDCProviders.MicrosoftOIDC = p diff --git a/server/resolvers/add_email_template.go b/server/resolvers/add_email_template.go index 311b18a..487edc2 100644 --- a/server/resolvers/add_email_template.go +++ b/server/resolvers/add_email_template.go @@ -47,7 +47,7 @@ func AddEmailTemplateResolver(ctx context.Context, params model.AddEmailTemplate design = "" } - _, err = db.Provider.AddEmailTemplate(ctx, models.EmailTemplate{ + _, err = db.Provider.AddEmailTemplate(ctx, &models.EmailTemplate{ EventName: params.EventName, Template: params.Template, Subject: params.Subject, diff --git a/server/resolvers/add_webhook.go b/server/resolvers/add_webhook.go index 596b1e0..3380779 100644 --- a/server/resolvers/add_webhook.go +++ b/server/resolvers/add_webhook.go @@ -43,7 +43,7 @@ func AddWebhookResolver(ctx context.Context, params model.AddWebhookRequest) (*m if params.EventDescription == nil { params.EventDescription = refs.NewStringRef(strings.Join(strings.Split(params.EventName, "."), " ")) } - _, err = db.Provider.AddWebhook(ctx, models.Webhook{ + _, err = db.Provider.AddWebhook(ctx, &models.Webhook{ EventDescription: refs.StringValue(params.EventDescription), EventName: params.EventName, EndPoint: params.Endpoint, diff --git a/server/resolvers/email_templates.go b/server/resolvers/email_templates.go index 7230400..0e1ee66 100644 --- a/server/resolvers/email_templates.go +++ b/server/resolvers/email_templates.go @@ -25,7 +25,6 @@ func EmailTemplatesResolver(ctx context.Context, params *model.PaginatedInput) ( } pagination := utils.GetPagination(params) - emailTemplates, err := db.Provider.ListEmailTemplate(ctx, pagination) if err != nil { log.Debug("failed to get email templates: ", err) diff --git a/server/resolvers/forgot_password.go b/server/resolvers/forgot_password.go index f497b31..028ff11 100644 --- a/server/resolvers/forgot_password.go +++ b/server/resolvers/forgot_password.go @@ -81,7 +81,7 @@ func ForgotPasswordResolver(ctx context.Context, params model.ForgotPasswordInpu log.Debug("Failed to create verification token", err) return res, err } - _, err = db.Provider.AddVerificationRequest(ctx, models.VerificationRequest{ + _, err = db.Provider.AddVerificationRequest(ctx, &models.VerificationRequest{ Token: verificationToken, Identifier: constants.VerificationTypeForgotPassword, ExpiresAt: time.Now().Add(time.Minute * 30).Unix(), diff --git a/server/resolvers/invite_members.go b/server/resolvers/invite_members.go index c15ca7e..ee5e0a1 100644 --- a/server/resolvers/invite_members.go +++ b/server/resolvers/invite_members.go @@ -105,7 +105,7 @@ func InviteMembersResolver(ctx context.Context, params model.InviteMemberInput) defaultRoles = strings.Split(defaultRolesString, ",") } - user := models.User{ + user := &models.User{ Email: email, Roles: strings.Join(defaultRoles, ","), } @@ -128,7 +128,7 @@ func InviteMembersResolver(ctx context.Context, params model.InviteMemberInput) log.Debug("Failed to create verification token: ", err) } - verificationRequest := models.VerificationRequest{ + verificationRequest := &models.VerificationRequest{ Token: verificationToken, ExpiresAt: time.Now().Add(time.Minute * 30).Unix(), Email: email, diff --git a/server/resolvers/login.go b/server/resolvers/login.go index 28a2289..9dcec5a 100644 --- a/server/resolvers/login.go +++ b/server/resolvers/login.go @@ -106,23 +106,32 @@ func LoginResolver(ctx context.Context, params model.LoginInput) (*model.AuthRes } isMFADisabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyDisableMultiFactorAuthentication) - if err != nil || !isEmailServiceEnabled { + if err != nil || !isMFADisabled { log.Debug("MFA service not enabled: ", err) } // If email service is not enabled continue the process in any way if refs.BoolValue(user.IsMultiFactorAuthEnabled) && isEmailServiceEnabled && !isMFADisabled { otp := utils.GenerateOTP() + expires := time.Now().Add(1 * time.Minute).Unix() otpData, err := db.Provider.UpsertOTP(ctx, &models.OTP{ Email: user.Email, Otp: otp, - ExpiresAt: time.Now().Add(1 * time.Minute).Unix(), + ExpiresAt: expires, }) if err != nil { log.Debug("Failed to add otp: ", err) return nil, err } + mfaSession := uuid.NewString() + err = memorystore.Provider.SetMfaSession(user.ID, mfaSession, expires) + if err != nil { + log.Debug("Failed to add mfasession: ", err) + return nil, err + } + cookie.SetMfaSession(gc, mfaSession) + go func() { // exec it as go routine so that we can reduce the api latency go email.SendEmail([]string{params.Email}, constants.VerificationTypeOTP, map[string]interface{}{ @@ -136,8 +145,8 @@ func LoginResolver(ctx context.Context, params model.LoginInput) (*model.AuthRes }() return &model.AuthResponse{ - Message: "Please check the OTP in your inbox", - ShouldShowOtpScreen: refs.NewBoolRef(true), + Message: "Please check the OTP in your inbox", + ShouldShowEmailOtpScreen: refs.NewBoolRef(true), }, nil } @@ -162,7 +171,6 @@ func LoginResolver(ctx context.Context, params model.LoginInput) (*model.AuthRes if nonce == "" { nonce = uuid.New().String() } - authToken, err := token.CreateAuthToken(gc, user, roles, scope, constants.AuthRecipeMethodBasicAuth, nonce, code) if err != nil { log.Debug("Failed to create auth token", err) @@ -203,7 +211,7 @@ func LoginResolver(ctx context.Context, params model.LoginInput) (*model.AuthRes go func() { utils.RegisterEvent(ctx, constants.UserLoginWebhookEvent, constants.AuthRecipeMethodBasicAuth, user) - db.Provider.AddSession(ctx, models.Session{ + db.Provider.AddSession(ctx, &models.Session{ UserID: user.ID, UserAgent: utils.GetUserAgent(gc.Request), IP: utils.GetIP(gc.Request), diff --git a/server/resolvers/magic_link_login.go b/server/resolvers/magic_link_login.go index 5ce90c8..a500c27 100644 --- a/server/resolvers/magic_link_login.go +++ b/server/resolvers/magic_link_login.go @@ -55,7 +55,7 @@ func MagicLinkLoginResolver(ctx context.Context, params model.MagicLinkLoginInpu inputRoles := []string{} - user := models.User{ + user := &models.User{ Email: params.Email, } @@ -207,7 +207,7 @@ func MagicLinkLoginResolver(ctx context.Context, params model.MagicLinkLoginInpu if err != nil { log.Debug("Failed to create verification token: ", err) } - _, err = db.Provider.AddVerificationRequest(ctx, models.VerificationRequest{ + _, err = db.Provider.AddVerificationRequest(ctx, &models.VerificationRequest{ Token: verificationToken, Identifier: verificationType, ExpiresAt: time.Now().Add(time.Minute * 30).Unix(), diff --git a/server/resolvers/meta.go b/server/resolvers/meta.go index 5322517..9290a41 100644 --- a/server/resolvers/meta.go +++ b/server/resolvers/meta.go @@ -101,12 +101,6 @@ func MetaResolver(ctx context.Context) (*model.Meta, error) { microsoftClientSecret = "" } - microsoftActiveDirTenantID, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyMicrosoftActiveDirectoryTenantID) - if err != nil { - log.Debug("Failed to get Microsoft Active Directory Tenant ID from environment variable", err) - microsoftActiveDirTenantID = "" - } - isBasicAuthDisabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyDisableBasicAuthentication) if err != nil { log.Debug("Failed to get Disable Basic Authentication from environment variable", err) @@ -152,7 +146,7 @@ func MetaResolver(ctx context.Context) (*model.Meta, error) { IsLinkedinLoginEnabled: linkedClientID != "" && linkedInClientSecret != "", IsAppleLoginEnabled: appleClientID != "" && appleClientSecret != "", IsTwitterLoginEnabled: twitterClientID != "" && twitterClientSecret != "", - IsMicrosoftLoginEnabled: microsoftClientID != "" && microsoftClientSecret != "" && microsoftActiveDirTenantID != "", + IsMicrosoftLoginEnabled: microsoftClientID != "" && microsoftClientSecret != "", IsBasicAuthenticationEnabled: !isBasicAuthDisabled, IsEmailVerificationEnabled: !isEmailVerificationDisabled, IsMagicLinkLoginEnabled: !isMagicLinkLoginDisabled, diff --git a/server/resolvers/mobile_login.go b/server/resolvers/mobile_login.go index 9da0a53..381f9a4 100644 --- a/server/resolvers/mobile_login.go +++ b/server/resolvers/mobile_login.go @@ -17,6 +17,7 @@ import ( "github.com/authorizerdev/authorizer/server/graph/model" "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/refs" + "github.com/authorizerdev/authorizer/server/smsproviders" "github.com/authorizerdev/authorizer/server/token" "github.com/authorizerdev/authorizer/server/utils" "github.com/authorizerdev/authorizer/server/validators" @@ -94,55 +95,67 @@ func MobileLoginResolver(ctx context.Context, params model.MobileLoginInput) (*m roles = params.Roles } + disablePhoneVerification, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyDisablePhoneVerification) + if err != nil { + log.Debug("Error getting disable phone verification: ", err) + } + if disablePhoneVerification { + now := time.Now().Unix() + user.PhoneNumberVerifiedAt = &now + } + isSMSServiceEnabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyIsSMSServiceEnabled) + if err != nil || !isSMSServiceEnabled { + log.Debug("SMS service not enabled: ", err) + } + if disablePhoneVerification { + now := time.Now().Unix() + user.PhoneNumberVerifiedAt = &now + } + isMFADisabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyDisableMultiFactorAuthentication) + if err != nil || !isMFADisabled { + log.Debug("MFA service not enabled: ", err) + } + if !disablePhoneVerification && isSMSServiceEnabled && !isMFADisabled { + duration, _ := time.ParseDuration("10m") + smsCode := utils.GenerateOTP() + + smsBody := strings.Builder{} + smsBody.WriteString("Your verification code is: ") + smsBody.WriteString(smsCode) + expires := time.Now().Add(duration).Unix() + _, err := db.Provider.UpsertOTP(ctx, &models.OTP{ + PhoneNumber: params.PhoneNumber, + Otp: smsCode, + ExpiresAt: expires, + }) + if err != nil { + log.Debug("error while upserting OTP: ", err.Error()) + return nil, err + } + + mfaSession := uuid.NewString() + err = memorystore.Provider.SetMfaSession(user.ID, mfaSession, expires) + if err != nil { + log.Debug("Failed to add mfasession: ", err) + return nil, err + } + cookie.SetMfaSession(gc, mfaSession) + + go func() { + utils.RegisterEvent(ctx, constants.UserLoginWebhookEvent, constants.AuthRecipeMethodMobileBasicAuth, user) + smsproviders.SendSMS(params.PhoneNumber, smsBody.String()) + }() + return &model.AuthResponse{ + Message: "Please check the OTP", + ShouldShowMobileOtpScreen: refs.NewBoolRef(true), + }, nil + } + scope := []string{"openid", "email", "profile"} if params.Scope != nil && len(scope) > 0 { scope = params.Scope } - /* - // TODO use sms authentication for MFA - isEmailServiceEnabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyIsEmailServiceEnabled) - if err != nil || !isEmailServiceEnabled { - log.Debug("Email service not enabled: ", err) - } - - isMFADisabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyDisableMultiFactorAuthentication) - if err != nil || !isEmailServiceEnabled { - log.Debug("MFA service not enabled: ", err) - } - - // If email service is not enabled continue the process in any way - if refs.BoolValue(user.IsMultiFactorAuthEnabled) && isEmailServiceEnabled && !isMFADisabled { - otp := utils.GenerateOTP() - otpData, err := db.Provider.UpsertOTP(ctx, &models.OTP{ - Email: user.Email, - Otp: otp, - ExpiresAt: time.Now().Add(1 * time.Minute).Unix(), - }) - if err != nil { - log.Debug("Failed to add otp: ", err) - return nil, err - } - - go func() { - // exec it as go routine so that we can reduce the api latency - go email.SendEmail([]string{params.PhoneNumber}, constants.VerificationTypeOTP, map[string]interface{}{ - "user": user.ToMap(), - "organization": utils.GetOrganization(), - "otp": otpData.Otp, - }) - if err != nil { - log.Debug("Failed to send otp email: ", err) - } - }() - - return &model.AuthResponse{ - Message: "Please check the OTP in your inbox", - ShouldShowOtpScreen: refs.NewBoolRef(true), - }, nil - } - */ - code := "" codeChallenge := "" nonce := "" @@ -165,7 +178,7 @@ func MobileLoginResolver(ctx context.Context, params model.MobileLoginInput) (*m nonce = uuid.New().String() } - authToken, err := token.CreateAuthToken(gc, *user, roles, scope, constants.AuthRecipeMethodMobileBasicAuth, nonce, code) + authToken, err := token.CreateAuthToken(gc, user, roles, scope, constants.AuthRecipeMethodMobileBasicAuth, nonce, code) if err != nil { log.Debug("Failed to create auth token", err) return res, err @@ -204,8 +217,8 @@ func MobileLoginResolver(ctx context.Context, params model.MobileLoginInput) (*m } go func() { - utils.RegisterEvent(ctx, constants.UserLoginWebhookEvent, constants.AuthRecipeMethodMobileBasicAuth, *user) - db.Provider.AddSession(ctx, models.Session{ + utils.RegisterEvent(ctx, constants.UserLoginWebhookEvent, constants.AuthRecipeMethodMobileBasicAuth, user) + db.Provider.AddSession(ctx, &models.Session{ UserID: user.ID, UserAgent: utils.GetUserAgent(gc.Request), IP: utils.GetIP(gc.Request), diff --git a/server/resolvers/mobile_signup.go b/server/resolvers/mobile_signup.go index 9aee0a6..60b3a88 100644 --- a/server/resolvers/mobile_signup.go +++ b/server/resolvers/mobile_signup.go @@ -8,7 +8,7 @@ import ( "github.com/google/uuid" log "github.com/sirupsen/logrus" - + "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/cookie" "github.com/authorizerdev/authorizer/server/crypto" @@ -17,9 +17,9 @@ import ( "github.com/authorizerdev/authorizer/server/graph/model" "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/refs" + "github.com/authorizerdev/authorizer/server/smsproviders" "github.com/authorizerdev/authorizer/server/token" "github.com/authorizerdev/authorizer/server/utils" - "github.com/authorizerdev/authorizer/server/smsproviders" "github.com/authorizerdev/authorizer/server/validators" ) @@ -92,7 +92,6 @@ func MobileSignupResolver(ctx context.Context, params *model.MobileSignUpInput) if err != nil { log.Debug("Failed to get user by email: ", err) } - if existingUser != nil { if existingUser.PhoneNumberVerifiedAt != nil { // email is verified @@ -105,7 +104,6 @@ func MobileSignupResolver(ctx context.Context, params *model.MobileSignUpInput) } inputRoles := []string{} - if len(params.Roles) > 0 { // check if roles exists rolesString, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyRoles) @@ -132,9 +130,9 @@ func MobileSignupResolver(ctx context.Context, params *model.MobileSignUpInput) } } - user := models.User{ - Email: emailInput, - PhoneNumber: &mobile, + user := &models.User{ + Email: emailInput, + PhoneNumber: &mobile, } user.Roles = strings.Join(inputRoles, ",") @@ -179,7 +177,7 @@ func MobileSignupResolver(ctx context.Context, params *model.MobileSignUpInput) log.Debug("MFA service not enabled: ", err) isMFAEnforced = false } - + if isMFAEnforced { user.IsMultiFactorAuthEnabled = refs.NewBoolRef(true) } @@ -189,6 +187,10 @@ func MobileSignupResolver(ctx context.Context, params *model.MobileSignUpInput) now := time.Now().Unix() user.PhoneNumberVerifiedAt = &now } + isSMSServiceEnabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyIsSMSServiceEnabled) + if err != nil || !isSMSServiceEnabled { + log.Debug("SMS service not enabled: ", err) + } user.SignupMethods = constants.AuthRecipeMethodMobileBasicAuth user, err = db.Provider.AddUser(ctx, user) @@ -197,11 +199,10 @@ func MobileSignupResolver(ctx context.Context, params *model.MobileSignUpInput) log.Debug("Failed to add user: ", err) return res, err } - - if !disablePhoneVerification { + if !disablePhoneVerification && isSMSServiceEnabled { duration, _ := time.ParseDuration("10m") smsCode := utils.GenerateOTP() - + smsBody := strings.Builder{} smsBody.WriteString("Your verification code is: ") smsBody.WriteString(smsCode) @@ -211,15 +212,23 @@ func MobileSignupResolver(ctx context.Context, params *model.MobileSignUpInput) log.Debug("error while upserting user: ", err.Error()) return nil, err } - + _, err = db.Provider.UpsertOTP(ctx, &models.OTP{ + PhoneNumber: mobile, + Otp: smsCode, + ExpiresAt: time.Now().Add(duration).Unix(), + }) + if err != nil { + log.Debug("error while upserting OTP: ", err.Error()) + return nil, err + } go func() { - db.Provider.UpsertSMSRequest(ctx, &models.SMSVerificationRequest{ - PhoneNumber: mobile, - Code: smsCode, - CodeExpiresAt: time.Now().Add(duration).Unix(), - }) smsproviders.SendSMS(mobile, smsBody.String()) + utils.RegisterEvent(ctx, constants.UserCreatedWebhookEvent, constants.AuthRecipeMethodBasicAuth, user) }() + return &model.AuthResponse{ + Message: "Please check the OTP in your inbox", + ShouldShowMobileOtpScreen: refs.NewBoolRef(true), + }, nil } roles := strings.Split(user.Roles, ",") @@ -290,7 +299,9 @@ func MobileSignupResolver(ctx context.Context, params *model.MobileSignUpInput) go func() { utils.RegisterEvent(ctx, constants.UserSignUpWebhookEvent, constants.AuthRecipeMethodMobileBasicAuth, user) - db.Provider.AddSession(ctx, models.Session{ + // User is also logged in with signup + utils.RegisterEvent(ctx, constants.UserLoginWebhookEvent, constants.AuthRecipeMethodMobileBasicAuth, user) + db.Provider.AddSession(ctx, &models.Session{ UserID: user.ID, UserAgent: utils.GetUserAgent(gc.Request), IP: utils.GetIP(gc.Request), diff --git a/server/resolvers/resend_otp.go b/server/resolvers/resend_otp.go index 65d9cf1..74da143 100644 --- a/server/resolvers/resend_otp.go +++ b/server/resolvers/resend_otp.go @@ -12,23 +12,46 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/db/models" - "github.com/authorizerdev/authorizer/server/email" + emailHelper "github.com/authorizerdev/authorizer/server/email" "github.com/authorizerdev/authorizer/server/graph/model" "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/refs" + "github.com/authorizerdev/authorizer/server/smsproviders" "github.com/authorizerdev/authorizer/server/utils" ) // ResendOTPResolver is a resolver for resend otp mutation func ResendOTPResolver(ctx context.Context, params model.ResendOTPRequest) (*model.Response, error) { + email := strings.ToLower(strings.Trim(refs.StringValue(params.Email), " ")) + phoneNumber := strings.Trim(refs.StringValue(params.PhoneNumber), " ") log := log.WithFields(log.Fields{ - "email": params.Email, + "email": email, + "phone_number": phoneNumber, }) - params.Email = strings.ToLower(params.Email) - user, err := db.Provider.GetUserByEmail(ctx, params.Email) + if email == "" && phoneNumber == "" { + log.Debug("Email or phone number is required") + return nil, errors.New("email or phone number is required") + } + var user *models.User + var err error + if email != "" { + isEmailServiceEnabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyIsEmailServiceEnabled) + if err != nil || !isEmailServiceEnabled { + log.Debug("Email service not enabled: ", err) + return nil, errors.New("email service not enabled") + } + user, err = db.Provider.GetUserByEmail(ctx, email) + } else { + isSMSServiceEnabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyIsEmailServiceEnabled) + if err != nil || !isSMSServiceEnabled { + log.Debug("Email service not enabled: ", err) + return nil, errors.New("email service not enabled") + } + user, err = db.Provider.GetUserByPhoneNumber(ctx, phoneNumber) + } if err != nil { log.Debug("Failed to get user by email: ", err) - return nil, fmt.Errorf(`user with this email not found`) + return nil, fmt.Errorf(`user with this email/phone not found`) } if user.RevokedTimestamp != nil { @@ -36,35 +59,38 @@ func ResendOTPResolver(ctx context.Context, params model.ResendOTPRequest) (*mod return nil, fmt.Errorf(`user access has been revoked`) } - if user.EmailVerifiedAt == nil { + if email != "" && user.EmailVerifiedAt == nil { log.Debug("User email is not verified") return nil, fmt.Errorf(`email not verified`) } + if phoneNumber != "" && user.PhoneNumberVerifiedAt == nil { + log.Debug("User phone number is not verified") + return nil, fmt.Errorf(`phone number not verified`) + } + if !refs.BoolValue(user.IsMultiFactorAuthEnabled) { log.Debug("User multi factor authentication is not enabled") return nil, fmt.Errorf(`multi factor authentication not enabled`) } - isEmailServiceEnabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyIsEmailServiceEnabled) - if err != nil || !isEmailServiceEnabled { - log.Debug("Email service not enabled: ", err) - return nil, errors.New("email service not enabled") - } - isMFADisabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyDisableMultiFactorAuthentication) if err != nil || isMFADisabled { log.Debug("MFA service not enabled: ", err) return nil, errors.New("multi factor authentication is disabled for this instance") } - // get otp by email - otpData, err := db.Provider.GetOTPByEmail(ctx, params.Email) + // get otp by email or phone number + var otpData *models.OTP + if email != "" { + otpData, err = db.Provider.GetOTPByEmail(ctx, refs.StringValue(params.Email)) + } else { + otpData, err = db.Provider.GetOTPByPhoneNumber(ctx, refs.StringValue(params.PhoneNumber)) + } if err != nil { log.Debug("Failed to get otp for given email: ", err) return nil, err } - if otpData == nil { log.Debug("No otp found for given email: ", params.Email) return &model.Response{ @@ -73,28 +99,30 @@ func ResendOTPResolver(ctx context.Context, params model.ResendOTPRequest) (*mod } otp := utils.GenerateOTP() - otpData, err = db.Provider.UpsertOTP(ctx, &models.OTP{ + if _, err := db.Provider.UpsertOTP(ctx, &models.OTP{ Email: user.Email, Otp: otp, ExpiresAt: time.Now().Add(1 * time.Minute).Unix(), - }) - if err != nil { - log.Debug("Error generating new otp: ", err) + }); err != nil { + log.Debug("Error upserting otp: ", err) return nil, err } - go func() { + if email != "" { // exec it as go routine so that we can reduce the api latency - go email.SendEmail([]string{params.Email}, constants.VerificationTypeOTP, map[string]interface{}{ + go emailHelper.SendEmail([]string{email}, constants.VerificationTypeOTP, map[string]interface{}{ "user": user.ToMap(), "organization": utils.GetOrganization(), "otp": otp, }) - if err != nil { - log.Debug("Error sending otp email: ", otp) - } - }() - + } else { + smsBody := strings.Builder{} + smsBody.WriteString("Your verification code is: ") + smsBody.WriteString(otp) + // exec it as go routine so that we can reduce the api latency + go smsproviders.SendSMS(phoneNumber, smsBody.String()) + } + log.Info("OTP has been resent") return &model.Response{ Message: `OTP has been sent. Please check your inbox`, }, nil diff --git a/server/resolvers/resend_verify_email.go b/server/resolvers/resend_verify_email.go index 6c94024..b5a789f 100644 --- a/server/resolvers/resend_verify_email.go +++ b/server/resolvers/resend_verify_email.go @@ -67,7 +67,7 @@ func ResendVerifyEmailResolver(ctx context.Context, params model.ResendVerifyEma if err != nil { log.Debug("Failed to create verification token: ", err) } - _, err = db.Provider.AddVerificationRequest(ctx, models.VerificationRequest{ + _, err = db.Provider.AddVerificationRequest(ctx, &models.VerificationRequest{ Token: verificationToken, Identifier: params.Identifier, ExpiresAt: time.Now().Add(time.Minute * 30).Unix(), diff --git a/server/resolvers/signup.go b/server/resolvers/signup.go index 757efeb..cefcfb2 100644 --- a/server/resolvers/signup.go +++ b/server/resolvers/signup.go @@ -2,6 +2,8 @@ package resolvers import ( "context" + "encoding/json" + "errors" "fmt" "strings" "time" @@ -81,13 +83,15 @@ func SignupResolver(ctx context.Context, params model.SignUpInput) (*model.AuthR log.Debug("Failed to get user by email: ", err) } - if existingUser.EmailVerifiedAt != nil { - // email is verified - log.Debug("Email is already verified and signed up.") - return res, fmt.Errorf(`%s has already signed up`, params.Email) - } else if existingUser.ID != "" && existingUser.EmailVerifiedAt == nil { - log.Debug("Email is already signed up. Verification pending...") - return res, fmt.Errorf("%s has already signed up. please complete the email verification process or reset the password", params.Email) + if existingUser != nil { + if existingUser.EmailVerifiedAt != nil { + // email is verified + log.Debug("Email is already verified and signed up.") + return res, fmt.Errorf(`%s has already signed up`, params.Email) + } else if existingUser.ID != "" && existingUser.EmailVerifiedAt == nil { + log.Debug("Email is already signed up. Verification pending...") + return res, fmt.Errorf("%s has already signed up. please complete the email verification process or reset the password", params.Email) + } } inputRoles := []string{} @@ -116,13 +120,10 @@ func SignupResolver(ctx context.Context, params model.SignUpInput) (*model.AuthR inputRoles = strings.Split(inputRolesString, ",") } } - - user := models.User{ + user := &models.User{ Email: params.Email, } - user.Roles = strings.Join(inputRoles, ",") - password, _ := crypto.EncryptPassword(params.Password) user.Password = &password @@ -172,6 +173,17 @@ func SignupResolver(ctx context.Context, params model.SignUpInput) (*model.AuthR user.IsMultiFactorAuthEnabled = refs.NewBoolRef(true) } + if params.AppData != nil { + appDataString := "" + appDataBytes, err := json.Marshal(params.AppData) + if err != nil { + log.Debug("failed to marshall source app_data: ", err) + return nil, errors.New("malformed app_data") + } + appDataString = string(appDataBytes) + user.AppData = &appDataString + } + user.SignupMethods = constants.AuthRecipeMethodBasicAuth isEmailVerificationDisabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyDisableEmailVerification) if err != nil { @@ -208,7 +220,7 @@ func SignupResolver(ctx context.Context, params model.SignUpInput) (*model.AuthR log.Debug("Failed to create verification token: ", err) return res, err } - _, err = db.Provider.AddVerificationRequest(ctx, models.VerificationRequest{ + _, err = db.Provider.AddVerificationRequest(ctx, &models.VerificationRequest{ Token: verificationToken, Identifier: verificationType, ExpiresAt: time.Now().Add(time.Minute * 30).Unix(), @@ -302,7 +314,9 @@ func SignupResolver(ctx context.Context, params model.SignUpInput) (*model.AuthR go func() { utils.RegisterEvent(ctx, constants.UserSignUpWebhookEvent, constants.AuthRecipeMethodBasicAuth, user) - db.Provider.AddSession(ctx, models.Session{ + // User is also logged in with signup + utils.RegisterEvent(ctx, constants.UserLoginWebhookEvent, constants.AuthRecipeMethodBasicAuth, user) + db.Provider.AddSession(ctx, &models.Session{ UserID: user.ID, UserAgent: utils.GetUserAgent(gc.Request), IP: utils.GetIP(gc.Request), diff --git a/server/resolvers/update_email_template.go b/server/resolvers/update_email_template.go index cf4e948..c08a49f 100644 --- a/server/resolvers/update_email_template.go +++ b/server/resolvers/update_email_template.go @@ -34,7 +34,7 @@ func UpdateEmailTemplateResolver(ctx context.Context, params model.UpdateEmailTe return nil, err } - emailTemplateDetails := models.EmailTemplate{ + emailTemplateDetails := &models.EmailTemplate{ ID: emailTemplate.ID, Key: emailTemplate.ID, EventName: emailTemplate.EventName, diff --git a/server/resolvers/update_env.go b/server/resolvers/update_env.go index d9d6881..96388aa 100644 --- a/server/resolvers/update_env.go +++ b/server/resolvers/update_env.go @@ -33,7 +33,7 @@ func clearSessionIfRequired(currentData, updatedData map[string]interface{}) { isCurrentGithubLoginEnabled := currentData[constants.EnvKeyGithubClientID] != nil && currentData[constants.EnvKeyGithubClientSecret] != nil && currentData[constants.EnvKeyGithubClientID].(string) != "" && currentData[constants.EnvKeyGithubClientSecret].(string) != "" isCurrentLinkedInLoginEnabled := currentData[constants.EnvKeyLinkedInClientID] != nil && currentData[constants.EnvKeyLinkedInClientSecret] != nil && currentData[constants.EnvKeyLinkedInClientID].(string) != "" && currentData[constants.EnvKeyLinkedInClientSecret].(string) != "" isCurrentTwitterLoginEnabled := currentData[constants.EnvKeyTwitterClientID] != nil && currentData[constants.EnvKeyTwitterClientSecret] != nil && currentData[constants.EnvKeyTwitterClientID].(string) != "" && currentData[constants.EnvKeyTwitterClientSecret].(string) != "" - isCurrentMicrosoftLoginEnabled := currentData[constants.EnvKeyMicrosoftClientID] != nil && currentData[constants.EnvKeyMicrosoftClientSecret] != nil && currentData[constants.EnvKeyMicrosoftActiveDirectoryTenantID] != nil && currentData[constants.EnvKeyMicrosoftClientID].(string) != "" && currentData[constants.EnvKeyMicrosoftClientSecret].(string) != "" && currentData[constants.EnvKeyMicrosoftActiveDirectoryTenantID].(string) != "" + isCurrentMicrosoftLoginEnabled := currentData[constants.EnvKeyMicrosoftClientID] != nil && currentData[constants.EnvKeyMicrosoftClientSecret] != nil && currentData[constants.EnvKeyMicrosoftClientID].(string) != "" && currentData[constants.EnvKeyMicrosoftClientSecret].(string) != "" isUpdatedBasicAuthEnabled := !updatedData[constants.EnvKeyDisableBasicAuthentication].(bool) isUpdatedMobileBasicAuthEnabled := !updatedData[constants.EnvKeyDisableMobileBasicAuthentication].(bool) @@ -44,7 +44,7 @@ func clearSessionIfRequired(currentData, updatedData map[string]interface{}) { isUpdatedGithubLoginEnabled := updatedData[constants.EnvKeyGithubClientID] != nil && updatedData[constants.EnvKeyGithubClientSecret] != nil && updatedData[constants.EnvKeyGithubClientID].(string) != "" && updatedData[constants.EnvKeyGithubClientSecret].(string) != "" isUpdatedLinkedInLoginEnabled := updatedData[constants.EnvKeyLinkedInClientID] != nil && updatedData[constants.EnvKeyLinkedInClientSecret] != nil && updatedData[constants.EnvKeyLinkedInClientID].(string) != "" && updatedData[constants.EnvKeyLinkedInClientSecret].(string) != "" isUpdatedTwitterLoginEnabled := updatedData[constants.EnvKeyTwitterClientID] != nil && updatedData[constants.EnvKeyTwitterClientSecret] != nil && updatedData[constants.EnvKeyTwitterClientID].(string) != "" && updatedData[constants.EnvKeyTwitterClientSecret].(string) != "" - isUpdatedMicrosoftLoginEnabled := updatedData[constants.EnvKeyMicrosoftClientID] != nil && updatedData[constants.EnvKeyMicrosoftClientSecret] != nil && updatedData[constants.EnvKeyMicrosoftActiveDirectoryTenantID] != nil && updatedData[constants.EnvKeyMicrosoftClientID].(string) != "" && updatedData[constants.EnvKeyMicrosoftClientSecret].(string) != "" && updatedData[constants.EnvKeyMicrosoftActiveDirectoryTenantID].(string) != "" + isUpdatedMicrosoftLoginEnabled := updatedData[constants.EnvKeyMicrosoftClientID] != nil && updatedData[constants.EnvKeyMicrosoftClientSecret] != nil && updatedData[constants.EnvKeyMicrosoftClientID].(string) != "" && updatedData[constants.EnvKeyMicrosoftClientSecret].(string) != "" if isCurrentBasicAuthEnabled && !isUpdatedBasicAuthEnabled { memorystore.Provider.DeleteSessionForNamespace(constants.AuthRecipeMethodBasicAuth) @@ -267,6 +267,13 @@ func UpdateEnvResolver(ctx context.Context, params model.UpdateEnvInput) (*model updatedData[constants.EnvKeyIsEmailServiceEnabled] = true } + if updatedData[constants.EnvKeyTwilioAPIKey] == "" || updatedData[constants.EnvKeyTwilioAPISecret] == "" || updatedData[constants.EnvKeyTwilioAccountSID] == "" || updatedData[constants.EnvKeyTwilioSender] == "" { + updatedData[constants.EnvKeyIsSMSServiceEnabled] = false + if !updatedData[constants.EnvKeyIsSMSServiceEnabled].(bool) { + updatedData[constants.EnvKeyDisablePhoneVerification] = true + } + } + if !currentData[constants.EnvKeyEnforceMultiFactorAuthentication].(bool) && updatedData[constants.EnvKeyEnforceMultiFactorAuthentication].(bool) && !updatedData[constants.EnvKeyDisableMultiFactorAuthentication].(bool) { go db.Provider.UpdateUsers(ctx, map[string]interface{}{ "is_multi_factor_auth_enabled": true, diff --git a/server/resolvers/update_profile.go b/server/resolvers/update_profile.go index da74258..c478182 100644 --- a/server/resolvers/update_profile.go +++ b/server/resolvers/update_profile.go @@ -2,6 +2,7 @@ package resolvers import ( "context" + "encoding/json" "errors" "fmt" "strings" @@ -47,7 +48,7 @@ func UpdateProfileResolver(ctx context.Context, params model.UpdateProfileInput) } // validate if all params are not empty - if params.GivenName == nil && params.FamilyName == nil && params.Picture == nil && params.MiddleName == nil && params.Nickname == nil && params.OldPassword == nil && params.Email == nil && params.Birthdate == nil && params.Gender == nil && params.PhoneNumber == nil && params.NewPassword == nil && params.ConfirmNewPassword == nil && params.IsMultiFactorAuthEnabled == nil { + if params.GivenName == nil && params.FamilyName == nil && params.Picture == nil && params.MiddleName == nil && params.Nickname == nil && params.OldPassword == nil && params.Email == nil && params.Birthdate == nil && params.Gender == nil && params.PhoneNumber == nil && params.NewPassword == nil && params.ConfirmNewPassword == nil && params.IsMultiFactorAuthEnabled == nil && params.AppData == nil { log.Debug("All params are empty") return res, fmt.Errorf("please enter at least one param to update") } @@ -56,7 +57,6 @@ func UpdateProfileResolver(ctx context.Context, params model.UpdateProfileInput) log := log.WithFields(log.Fields{ "user_id": userID, }) - user, err := db.Provider.GetUserByID(ctx, userID) if err != nil { log.Debug("Failed to get user by id: ", err) @@ -99,7 +99,16 @@ func UpdateProfileResolver(ctx context.Context, params model.UpdateProfileInput) if params.Picture != nil && refs.StringValue(user.Picture) != refs.StringValue(params.Picture) { user.Picture = params.Picture } - + if params.AppData != nil { + appDataString := "" + appDataBytes, err := json.Marshal(params.AppData) + if err != nil { + log.Debug("failed to marshall source app_data: ", err) + return nil, errors.New("malformed app_data") + } + appDataString = string(appDataBytes) + user.AppData = &appDataString + } if params.IsMultiFactorAuthEnabled != nil && refs.BoolValue(user.IsMultiFactorAuthEnabled) != refs.BoolValue(params.IsMultiFactorAuthEnabled) { if refs.BoolValue(params.IsMultiFactorAuthEnabled) { isEnvServiceEnabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyIsEmailServiceEnabled) @@ -242,7 +251,7 @@ func UpdateProfileResolver(ctx context.Context, params model.UpdateProfileInput) log.Debug("Failed to create verification token: ", err) return res, err } - _, err = db.Provider.AddVerificationRequest(ctx, models.VerificationRequest{ + _, err = db.Provider.AddVerificationRequest(ctx, &models.VerificationRequest{ Token: verificationToken, Identifier: verificationType, ExpiresAt: time.Now().Add(time.Minute * 30).Unix(), diff --git a/server/resolvers/update_user.go b/server/resolvers/update_user.go index bd61428..ee751ab 100644 --- a/server/resolvers/update_user.go +++ b/server/resolvers/update_user.go @@ -2,6 +2,7 @@ package resolvers import ( "context" + "encoding/json" "errors" "fmt" "strings" @@ -95,6 +96,17 @@ func UpdateUserResolver(ctx context.Context, params model.UpdateUserInput) (*mod user.Picture = params.Picture } + if params.AppData != nil { + appDataString := "" + appDataBytes, err := json.Marshal(params.AppData) + if err != nil { + log.Debug("failed to marshall source app_data: ", err) + return nil, errors.New("malformed app_data") + } + appDataString = string(appDataBytes) + user.AppData = &appDataString + } + if params.IsMultiFactorAuthEnabled != nil && refs.BoolValue(user.IsMultiFactorAuthEnabled) != refs.BoolValue(params.IsMultiFactorAuthEnabled) { user.IsMultiFactorAuthEnabled = params.IsMultiFactorAuthEnabled if refs.BoolValue(params.IsMultiFactorAuthEnabled) { @@ -147,7 +159,7 @@ func UpdateUserResolver(ctx context.Context, params model.UpdateUserInput) (*mod if err != nil { log.Debug("Failed to create verification token: ", err) } - _, err = db.Provider.AddVerificationRequest(ctx, models.VerificationRequest{ + _, err = db.Provider.AddVerificationRequest(ctx, &models.VerificationRequest{ Token: verificationToken, Identifier: verificationType, ExpiresAt: time.Now().Add(time.Minute * 30).Unix(), diff --git a/server/resolvers/update_webhook.go b/server/resolvers/update_webhook.go index 5783984..3d09568 100644 --- a/server/resolvers/update_webhook.go +++ b/server/resolvers/update_webhook.go @@ -41,7 +41,7 @@ func UpdateWebhookResolver(ctx context.Context, params model.UpdateWebhookReques } headersString = string(headerBytes) } - webhookDetails := models.Webhook{ + webhookDetails := &models.Webhook{ ID: webhook.ID, Key: webhook.ID, EventName: refs.StringValue(webhook.EventName), diff --git a/server/resolvers/validate_session.go b/server/resolvers/validate_session.go new file mode 100644 index 0000000..39adb1f --- /dev/null +++ b/server/resolvers/validate_session.go @@ -0,0 +1,61 @@ +package resolvers + +import ( + "context" + "errors" + "fmt" + + "github.com/authorizerdev/authorizer/server/cookie" + "github.com/authorizerdev/authorizer/server/db" + "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/token" + "github.com/authorizerdev/authorizer/server/utils" + log "github.com/sirupsen/logrus" +) + +// ValidateSessionResolver is used to validate a cookie session without its rotation +func ValidateSessionResolver(ctx context.Context, params *model.ValidateSessionInput) (*model.ValidateSessionResponse, error) { + gc, err := utils.GinContextFromContext(ctx) + if err != nil { + log.Debug("Failed to get GinContext: ", err) + return nil, err + } + sessionToken := params.Cookie + if sessionToken == "" { + sessionToken, err = cookie.GetSession(gc) + if err != nil { + log.Debug("Failed to get session token: ", err) + return nil, errors.New("unauthorized") + } + } + claims, err := token.ValidateBrowserSession(gc, sessionToken) + if err != nil { + log.Debug("Failed to validate session token", err) + return nil, errors.New("unauthorized") + } + userID := claims.Subject + log := log.WithFields(log.Fields{ + "user_id": userID, + }) + user, err := db.Provider.GetUserByID(ctx, userID) + if err != nil { + log.Debug("Failed to get user: ", err) + return nil, err + } + // refresh token has "roles" as claim + claimRoleInterface := claims.Roles + claimRoles := []string{} + claimRoles = append(claimRoles, claimRoleInterface...) + if params != nil && params.Roles != nil && len(params.Roles) > 0 { + for _, v := range params.Roles { + if !utils.StringSliceContains(claimRoles, v) { + log.Debug("User does not have required role: ", claimRoles, v) + return nil, fmt.Errorf(`unauthorized`) + } + } + } + return &model.ValidateSessionResponse{ + IsValid: true, + User: user.AsAPIUser(), + }, nil +} diff --git a/server/resolvers/verification_requests.go b/server/resolvers/verification_requests.go index 4a629de..9a55be7 100644 --- a/server/resolvers/verification_requests.go +++ b/server/resolvers/verification_requests.go @@ -27,7 +27,6 @@ func VerificationRequestsResolver(ctx context.Context, params *model.PaginatedIn } pagination := utils.GetPagination(params) - res, err := db.Provider.ListVerificationRequests(ctx, pagination) if err != nil { log.Debug("Failed to get verification requests: ", err) diff --git a/server/resolvers/verify_email.go b/server/resolvers/verify_email.go index d1fd81d..d263629 100644 --- a/server/resolvers/verify_email.go +++ b/server/resolvers/verify_email.go @@ -125,11 +125,13 @@ func VerifyEmailResolver(ctx context.Context, params model.VerifyEmailInput) (*m go func() { if isSignUp { utils.RegisterEvent(ctx, constants.UserSignUpWebhookEvent, loginMethod, user) + // User is also logged in with signup + utils.RegisterEvent(ctx, constants.UserLoginWebhookEvent, loginMethod, user) } else { utils.RegisterEvent(ctx, constants.UserLoginWebhookEvent, loginMethod, user) } - db.Provider.AddSession(ctx, models.Session{ + db.Provider.AddSession(ctx, &models.Session{ UserID: user.ID, UserAgent: utils.GetUserAgent(gc.Request), IP: utils.GetIP(gc.Request), diff --git a/server/resolvers/verify_mobile.go b/server/resolvers/verify_mobile.go deleted file mode 100644 index 4e077d7..0000000 --- a/server/resolvers/verify_mobile.go +++ /dev/null @@ -1,62 +0,0 @@ -package resolvers - -import ( - "fmt" - "context" - "time" - - "github.com/authorizerdev/authorizer/server/graph/model" - "github.com/authorizerdev/authorizer/server/utils" - "github.com/authorizerdev/authorizer/server/db" - log "github.com/sirupsen/logrus" -) - -func VerifyMobileResolver(ctx context.Context, params model.VerifyMobileRequest) (*model.AuthResponse, error) { - var res *model.AuthResponse - - _, err := utils.GinContextFromContext(ctx) - if err != nil { - log.Debug("Failed to get GinContext: ", err) - return res, err - } - - smsVerificationRequest, err := db.Provider.GetCodeByPhone(ctx, params.PhoneNumber) - if err != nil { - log.Debug("Failed to get sms request by phone: ", err) - return res, err - } - - if smsVerificationRequest.Code != params.Code { - log.Debug("Failed to verify request: bad credentials") - return res, fmt.Errorf(`bad credentials`) - } - - expiresIn := smsVerificationRequest.CodeExpiresAt - time.Now().Unix() - if expiresIn < 0 { - log.Debug("Failed to verify sms request: Timeout") - return res, fmt.Errorf("time expired") - } - - res = &model.AuthResponse{ - Message: "successful", - } - - user, err := db.Provider.GetUserByPhoneNumber(ctx, params.PhoneNumber) - if user.PhoneNumberVerifiedAt == nil { - now := time.Now().Unix() - user.PhoneNumberVerifiedAt = &now - } - - _, err = db.Provider.UpdateUser(ctx, *user) - if err != nil { - log.Debug("Failed to update user: ", err) - return res, err - } - - err = db.Provider.DeleteSMSRequest(ctx, smsVerificationRequest) - if err != nil { - log.Debug("Failed to delete sms request: ", err.Error()) - } - - return res, err -} diff --git a/server/resolvers/verify_otp.go b/server/resolvers/verify_otp.go index 80080d9..bab3323 100644 --- a/server/resolvers/verify_otp.go +++ b/server/resolvers/verify_otp.go @@ -28,35 +28,62 @@ func VerifyOtpResolver(ctx context.Context, params model.VerifyOTPRequest) (*mod return res, err } - otp, err := db.Provider.GetOTPByEmail(ctx, params.Email) + mfaSession, err := cookie.GetMfaSession(gc) if err != nil { log.Debug("Failed to get otp request by email: ", err) - return res, fmt.Errorf(`invalid email: %s`, err.Error()) + return res, fmt.Errorf(`invalid session: %s`, err.Error()) } + if refs.StringValue(params.Email) == "" && refs.StringValue(params.PhoneNumber) == "" { + log.Debug("Email or phone number is required") + return res, fmt.Errorf(`email or phone_number is required`) + } + + currentField := models.FieldNameEmail + if refs.StringValue(params.Email) == "" { + currentField = models.FieldNamePhoneNumber + } + var otp *models.OTP + if currentField == models.FieldNameEmail { + otp, err = db.Provider.GetOTPByEmail(ctx, refs.StringValue(params.Email)) + } else { + otp, err = db.Provider.GetOTPByPhoneNumber(ctx, refs.StringValue(params.PhoneNumber)) + } + if otp == nil && err != nil { + log.Debugf("Failed to get otp request for %s: %s", currentField, err.Error()) + return res, fmt.Errorf(`invalid %s: %s`, currentField, err.Error()) + } if params.Otp != otp.Otp { log.Debug("Failed to verify otp request: Incorrect value") return res, fmt.Errorf(`invalid otp`) } - expiresIn := otp.ExpiresAt - time.Now().Unix() - if expiresIn < 0 { log.Debug("Failed to verify otp request: Timeout") return res, fmt.Errorf("otp expired") } - - user, err := db.Provider.GetUserByEmail(ctx, params.Email) - if err != nil { - log.Debug("Failed to get user by email: ", err) + var user *models.User + if currentField == models.FieldNameEmail { + user, err = db.Provider.GetUserByEmail(ctx, refs.StringValue(params.Email)) + } else { + user, err = db.Provider.GetUserByPhoneNumber(ctx, refs.StringValue(params.PhoneNumber)) + } + if user == nil || err != nil { + log.Debug("Failed to get user by email or phone number: ", err) return res, err } - isSignUp := user.EmailVerifiedAt == nil + if _, err := memorystore.Provider.GetMfaSession(user.ID, mfaSession); err != nil { + log.Debug("Failed to get mfa session: ", err) + return res, fmt.Errorf(`invalid session: %s`, err.Error()) + } + isSignUp := user.EmailVerifiedAt == nil && user.PhoneNumberVerifiedAt == nil // TODO - Add Login method in DB when we introduce OTP for social media login loginMethod := constants.AuthRecipeMethodBasicAuth - + if currentField == models.FieldNamePhoneNumber { + loginMethod = constants.AuthRecipeMethodMobileOTP + } roles := strings.Split(user.Roles, ",") scope := []string{"openid", "email", "profile"} code := "" @@ -97,11 +124,13 @@ func VerifyOtpResolver(ctx context.Context, params model.VerifyOTPRequest) (*mod db.Provider.DeleteOTP(gc, otp) if isSignUp { utils.RegisterEvent(ctx, constants.UserSignUpWebhookEvent, loginMethod, user) + // User is also logged in with signup + utils.RegisterEvent(ctx, constants.UserLoginWebhookEvent, loginMethod, user) } else { utils.RegisterEvent(ctx, constants.UserLoginWebhookEvent, loginMethod, user) } - db.Provider.AddSession(ctx, models.Session{ + db.Provider.AddSession(ctx, &models.Session{ UserID: user.ID, UserAgent: utils.GetUserAgent(gc.Request), IP: utils.GetIP(gc.Request), diff --git a/server/resolvers/webhook_logs.go b/server/resolvers/webhook_logs.go index 7c9cc7b..cd7b62d 100644 --- a/server/resolvers/webhook_logs.go +++ b/server/resolvers/webhook_logs.go @@ -25,7 +25,7 @@ func WebhookLogsResolver(ctx context.Context, params *model.ListWebhookLogReques return nil, fmt.Errorf("unauthorized") } - var pagination model.Pagination + var pagination *model.Pagination var webhookID string if params != nil { @@ -37,7 +37,7 @@ func WebhookLogsResolver(ctx context.Context, params *model.ListWebhookLogReques pagination = utils.GetPagination(nil) webhookID = "" } - + // TODO fix webhookLogs, err := db.Provider.ListWebhookLogs(ctx, pagination, webhookID) if err != nil { log.Debug("failed to get webhook logs: ", err) diff --git a/server/resolvers/webhooks.go b/server/resolvers/webhooks.go index 5a6ccbb..733df82 100644 --- a/server/resolvers/webhooks.go +++ b/server/resolvers/webhooks.go @@ -25,7 +25,6 @@ func WebhooksResolver(ctx context.Context, params *model.PaginatedInput) (*model } pagination := utils.GetPagination(params) - webhooks, err := db.Provider.ListWebhook(ctx, pagination) if err != nil { log.Debug("failed to get webhooks: ", err) diff --git a/server/smsproviders/twilio.go b/server/smsproviders/twilio.go index 093e924..f4f1acb 100644 --- a/server/smsproviders/twilio.go +++ b/server/smsproviders/twilio.go @@ -1,43 +1,38 @@ package smsproviders import ( - twilio "github.com/twilio/twilio-go" - api "github.com/twilio/twilio-go/rest/api/v2010" "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/memorystore" log "github.com/sirupsen/logrus" + twilio "github.com/twilio/twilio-go" + api "github.com/twilio/twilio-go/rest/api/v2010" ) +// SendSMS util to send sms // TODO: Should be restructured to interface when another provider is added func SendSMS(sendTo, messageBody string) error { - twilioAPISecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyTwilioAPISecret) - if err != nil || twilioAPISecret == ""{ - log.Errorf("Failed to get api secret: ", err) + if err != nil || twilioAPISecret == "" { + log.Debug("Failed to get api secret: ", err) return err } - twilioAPIKey, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyTwilioAPIKey) - if err != nil || twilioAPIKey == ""{ - log.Errorf("Failed to get api key: ", err) + if err != nil || twilioAPIKey == "" { + log.Debug("Failed to get api key: ", err) return err } - - twilioSenderFrom, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyTwilioSenderFrom) + twilioSenderFrom, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyTwilioSender) if err != nil || twilioSenderFrom == "" { - log.Errorf("Failed to get sender: ", err) + log.Debug("Failed to get sender: ", err) return err } - // accountSID is not a must to send sms on twilio twilioAccountSID, _ := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyTwilioAccountSID) - client := twilio.NewRestClientWithParams(twilio.ClientParams{ Username: twilioAPIKey, Password: twilioAPISecret, AccountSid: twilioAccountSID, }) - message := &api.CreateMessageParams{} message.SetBody(messageBody) message.SetFrom(twilioSenderFrom) diff --git a/server/test/admin_signup_test.go b/server/test/admin_signup_test.go index 9596f4d..9e42553 100644 --- a/server/test/admin_signup_test.go +++ b/server/test/admin_signup_test.go @@ -17,11 +17,10 @@ func adminSignupTests(t *testing.T, s TestSetup) { _, err := resolvers.AdminSignupResolver(ctx, model.AdminSignupInput{ AdminSecret: "admin", }) - assert.NotNil(t, err) // reset env for test to pass - memorystore.Provider.UpdateEnvVariable(constants.EnvKeyAdminSecret, "") - + err = memorystore.Provider.UpdateEnvVariable(constants.EnvKeyAdminSecret, "") + assert.Nil(t, err) _, err = resolvers.AdminSignupResolver(ctx, model.AdminSignupInput{ AdminSecret: "admin123", }) diff --git a/server/test/delete_email_template_test.go b/server/test/delete_email_template_test.go index ef79db9..c32b6d5 100644 --- a/server/test/delete_email_template_test.go +++ b/server/test/delete_email_template_test.go @@ -24,7 +24,7 @@ func deleteEmailTemplateTest(t *testing.T, s TestSetup) { req.Header.Set("Cookie", fmt.Sprintf("%s=%s", constants.AdminCookieName, h)) // get all email templates - emailTemplates, err := db.Provider.ListEmailTemplate(ctx, model.Pagination{ + emailTemplates, err := db.Provider.ListEmailTemplate(ctx, &model.Pagination{ Limit: 10, Page: 1, Offset: 0, @@ -41,7 +41,7 @@ func deleteEmailTemplateTest(t *testing.T, s TestSetup) { assert.NotEmpty(t, res.Message) } - emailTemplates, err = db.Provider.ListEmailTemplate(ctx, model.Pagination{ + emailTemplates, err = db.Provider.ListEmailTemplate(ctx, &model.Pagination{ Limit: 10, Page: 1, Offset: 0, diff --git a/server/test/delete_webhook_test.go b/server/test/delete_webhook_test.go index ab9b9f2..3404a42 100644 --- a/server/test/delete_webhook_test.go +++ b/server/test/delete_webhook_test.go @@ -24,7 +24,7 @@ func deleteWebhookTest(t *testing.T, s TestSetup) { req.Header.Set("Cookie", fmt.Sprintf("%s=%s", constants.AdminCookieName, h)) // get all webhooks - webhooks, err := db.Provider.ListWebhook(ctx, model.Pagination{ + webhooks, err := db.Provider.ListWebhook(ctx, &model.Pagination{ Limit: 20, Page: 1, Offset: 0, @@ -41,14 +41,14 @@ func deleteWebhookTest(t *testing.T, s TestSetup) { assert.NotEmpty(t, res.Message) } - webhooks, err = db.Provider.ListWebhook(ctx, model.Pagination{ + webhooks, err = db.Provider.ListWebhook(ctx, &model.Pagination{ Limit: 20, Page: 1, Offset: 0, }) assert.NoError(t, err) assert.Len(t, webhooks.Webhooks, 0) - webhookLogs, err := db.Provider.ListWebhookLogs(ctx, model.Pagination{ + webhookLogs, err := db.Provider.ListWebhookLogs(ctx, &model.Pagination{ Limit: 100, Page: 1, Offset: 0, diff --git a/server/test/resolvers_test.go b/server/test/integration_test.go similarity index 89% rename from server/test/resolvers_test.go rename to server/test/integration_test.go index 4c83bf3..3d4bb3d 100644 --- a/server/test/resolvers_test.go +++ b/server/test/integration_test.go @@ -9,6 +9,7 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db" + "github.com/authorizerdev/authorizer/server/db/models" "github.com/authorizerdev/authorizer/server/env" "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/utils" @@ -46,7 +47,6 @@ func TestResolvers(t *testing.T) { for dbType, dbURL := range databases { ctx := context.Background() - memorystore.Provider.UpdateEnvVariable(constants.EnvKeyDatabaseURL, dbURL) memorystore.Provider.UpdateEnvVariable(constants.EnvKeyDatabaseType, dbType) memorystore.Provider.UpdateEnvVariable(constants.EnvKeyDatabaseName, testDb) @@ -57,6 +57,11 @@ func TestResolvers(t *testing.T) { if dbType == constants.DbTypeDynamoDB { memorystore.Provider.UpdateEnvVariable(constants.EnvAwsRegion, "ap-south-1") os.Setenv(constants.EnvAwsRegion, "ap-south-1") + os.Unsetenv(constants.EnvAwsAccessKeyID) + os.Unsetenv(constants.EnvAwsSecretAccessKey) + // Remove aws credentials from env, so that local dynamodb can be used + memorystore.Provider.UpdateEnvVariable(constants.EnvAwsAccessKeyID, "") + memorystore.Provider.UpdateEnvVariable(constants.EnvAwsSecretAccessKey, "") } if dbType == constants.DbTypeCouchbaseDB { memorystore.Provider.UpdateEnvVariable(constants.EnvKeyDatabaseUsername, "Administrator") @@ -74,8 +79,10 @@ func TestResolvers(t *testing.T) { // clean the persisted config for test to use fresh config envData, err := db.Provider.GetEnv(ctx) - if err == nil && envData.ID != "" { - envData.EnvData = "" + if err == nil && envData == nil { + envData = &models.Env{ + EnvData: "", + } _, err = db.Provider.UpdateEnv(ctx, envData) if err != nil { t.Logf("Error updating env: %s", err.Error()) @@ -135,7 +142,7 @@ func TestResolvers(t *testing.T) { validateJwtTokenTest(t, s) verifyOTPTest(t, s) resendOTPTest(t, s) - verifyMobileTest(t, s) + validateSessionTests(t, s) updateAllUsersTest(t, s) webhookLogsTest(t, s) // get logs after above resolver tests are done diff --git a/server/test/mobile_login_test.go b/server/test/mobile_login_test.go index 48b7690..6f0823c 100644 --- a/server/test/mobile_login_test.go +++ b/server/test/mobile_login_test.go @@ -1,14 +1,18 @@ package test import ( + "fmt" "strings" "testing" + "time" "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/graph/model" "github.com/authorizerdev/authorizer/server/db" + "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/refs" "github.com/authorizerdev/authorizer/server/resolvers" + "github.com/google/uuid" "github.com/stretchr/testify/assert" ) @@ -26,11 +30,6 @@ func mobileLoginTests(t *testing.T, s TestSetup) { }) assert.NoError(t, err) assert.NotNil(t, signUpRes) - assert.Equal(t, email, signUpRes.User.Email) - assert.Equal(t, phoneNumber, refs.StringValue(signUpRes.User.PhoneNumber)) - assert.True(t, strings.Contains(signUpRes.User.SignupMethods, constants.AuthRecipeMethodMobileBasicAuth)) - assert.Len(t, strings.Split(signUpRes.User.SignupMethods, ","), 1) - res, err := resolvers.MobileLoginResolver(ctx, model.MobileLoginInput{ PhoneNumber: phoneNumber, Password: "random_test", @@ -45,7 +44,6 @@ func mobileLoginTests(t *testing.T, s TestSetup) { }) assert.Error(t, err) assert.Nil(t, res) - // should fail because phone is not verified res, err = resolvers.MobileLoginResolver(ctx, model.MobileLoginInput{ PhoneNumber: phoneNumber, @@ -53,26 +51,28 @@ func mobileLoginTests(t *testing.T, s TestSetup) { }) assert.NotNil(t, err, "should fail because phone is not verified") assert.Nil(t, res) - - smsRequest, err := db.Provider.GetCodeByPhone(ctx, phoneNumber) + smsRequest, err := db.Provider.GetOTPByPhoneNumber(ctx, phoneNumber) assert.NoError(t, err) - assert.NotEmpty(t, smsRequest.Code) - - verifySMSRequest, err := resolvers.VerifyMobileResolver(ctx, model.VerifyMobileRequest{ - PhoneNumber: phoneNumber, - Code: smsRequest.Code, + assert.NotEmpty(t, smsRequest.Otp) + // Get user by phone number + user, err := db.Provider.GetUserByPhoneNumber(ctx, phoneNumber) + assert.NoError(t, err) + assert.NotNil(t, user) + // Set mfa cookie session + mfaSession := uuid.NewString() + memorystore.Provider.SetMfaSession(user.ID, mfaSession, time.Now().Add(1*time.Minute).Unix()) + cookie := fmt.Sprintf("%s=%s;", constants.MfaCookieName+"_session", mfaSession) + cookie = strings.TrimSuffix(cookie, ";") + req, ctx := createContext(s) + req.Header.Set("Cookie", cookie) + verifySMSRequest, err := resolvers.VerifyOtpResolver(ctx, model.VerifyOTPRequest{ + PhoneNumber: &phoneNumber, + Otp: smsRequest.Otp, }) assert.Nil(t, err) assert.NotEqual(t, verifySMSRequest.Message, "", "message should not be empty") - - res, err = resolvers.MobileLoginResolver(ctx, model.MobileLoginInput{ - PhoneNumber: phoneNumber, - Password: s.TestInfo.Password, - }) - assert.NoError(t, err) - assert.NotEmpty(t, res.AccessToken) - assert.NotEmpty(t, res.IDToken) - + assert.NotEmpty(t, verifySMSRequest.AccessToken) + assert.NotEmpty(t, verifySMSRequest.IDToken) cleanData(email) }) } diff --git a/server/test/mobile_signup_test.go b/server/test/mobile_signup_test.go index 11deccc..e0982e1 100644 --- a/server/test/mobile_signup_test.go +++ b/server/test/mobile_signup_test.go @@ -1,13 +1,18 @@ package test import ( + "fmt" + "strings" "testing" + "time" "github.com/authorizerdev/authorizer/server/constants" + "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/graph/model" "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/refs" "github.com/authorizerdev/authorizer/server/resolvers" + "github.com/google/uuid" "github.com/stretchr/testify/assert" ) @@ -65,16 +70,36 @@ func mobileSingupTest(t *testing.T, s TestSetup) { }) assert.Error(t, err) assert.Nil(t, res) - + phoneNumber := "1234567890" res, err = resolvers.MobileSignupResolver(ctx, &model.MobileSignUpInput{ - PhoneNumber: "1234567890", + PhoneNumber: phoneNumber, Password: s.TestInfo.Password, ConfirmPassword: s.TestInfo.Password, }) assert.NoError(t, err) - assert.NotEmpty(t, res.AccessToken) - assert.Equal(t, "1234567890@authorizer.dev", res.User.Email) - + assert.NotNil(t, res) + assert.True(t, *res.ShouldShowMobileOtpScreen) + // Verify with otp + otp, err := db.Provider.GetOTPByPhoneNumber(ctx, phoneNumber) + assert.Nil(t, err) + assert.NotEmpty(t, otp.Otp) + // Get user by phone number + user, err := db.Provider.GetUserByPhoneNumber(ctx, phoneNumber) + assert.NoError(t, err) + assert.NotNil(t, user) + // Set mfa cookie session + mfaSession := uuid.NewString() + memorystore.Provider.SetMfaSession(user.ID, mfaSession, time.Now().Add(1*time.Minute).Unix()) + cookie := fmt.Sprintf("%s=%s;", constants.MfaCookieName+"_session", mfaSession) + cookie = strings.TrimSuffix(cookie, ";") + req, ctx := createContext(s) + req.Header.Set("Cookie", cookie) + otpRes, err := resolvers.VerifyOtpResolver(ctx, model.VerifyOTPRequest{ + PhoneNumber: &phoneNumber, + Otp: otp.Otp, + }) + assert.Nil(t, err) + assert.NotEmpty(t, otpRes.Message) res, err = resolvers.MobileSignupResolver(ctx, &model.MobileSignUpInput{ PhoneNumber: "1234567890", Password: s.TestInfo.Password, diff --git a/server/test/resend_otp_test.go b/server/test/resend_otp_test.go index 73e715d..3f1e738 100644 --- a/server/test/resend_otp_test.go +++ b/server/test/resend_otp_test.go @@ -2,13 +2,18 @@ package test import ( "context" + "fmt" + "strings" "testing" + "time" "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/refs" "github.com/authorizerdev/authorizer/server/resolvers" + "github.com/google/uuid" "github.com/stretchr/testify/assert" ) @@ -51,7 +56,7 @@ func resendOTPTest(t *testing.T, s TestSetup) { assert.NotNil(t, updateRes) // Resend otp should return error as no initial opt is being sent resendOtpRes, err := resolvers.ResendOTPResolver(ctx, model.ResendOTPRequest{ - Email: email, + Email: refs.NewStringRef(email), }) assert.Error(t, err) assert.Nil(t, resendOtpRes) @@ -72,7 +77,7 @@ func resendOTPTest(t *testing.T, s TestSetup) { // resend otp resendOtpRes, err = resolvers.ResendOTPResolver(ctx, model.ResendOTPRequest{ - Email: email, + Email: refs.NewStringRef(email), }) assert.NoError(t, err) assert.NotEmpty(t, resendOtpRes.Message) @@ -84,13 +89,23 @@ func resendOTPTest(t *testing.T, s TestSetup) { // Should return error for older otp verifyOtpRes, err := resolvers.VerifyOtpResolver(ctx, model.VerifyOTPRequest{ - Email: email, + Email: &email, Otp: otp.Otp, }) assert.Error(t, err) assert.Nil(t, verifyOtpRes) + // Get user by email + user, err := db.Provider.GetUserByEmail(ctx, email) + assert.NoError(t, err) + assert.NotNil(t, user) + // Set mfa cookie session + mfaSession := uuid.NewString() + memorystore.Provider.SetMfaSession(user.ID, mfaSession, time.Now().Add(1*time.Minute).Unix()) + cookie := fmt.Sprintf("%s=%s;", constants.MfaCookieName+"_session", mfaSession) + cookie = strings.TrimSuffix(cookie, ";") + req.Header.Set("Cookie", cookie) verifyOtpRes, err = resolvers.VerifyOtpResolver(ctx, model.VerifyOTPRequest{ - Email: email, + Email: &email, Otp: newOtp.Otp, }) assert.NoError(t, err) diff --git a/server/test/test.go b/server/test/test.go index b2727ea..217ad39 100644 --- a/server/test/test.go +++ b/server/test/test.go @@ -126,6 +126,10 @@ func testSetup() TestSetup { memorystore.Provider.UpdateEnvVariable(constants.EnvKeySmtpPassword, "test") memorystore.Provider.UpdateEnvVariable(constants.EnvKeySenderEmail, "info@yopmail.com") memorystore.Provider.UpdateEnvVariable(constants.EnvKeyProtectedRoles, "admin") + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyTwilioAPIKey, "test") + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyTwilioAPISecret, "test") + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyTwilioAccountSID, "test") + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyTwilioSender, "1234567890") err = db.InitDB() if err != nil { diff --git a/server/test/update_all_users_tests.go b/server/test/update_all_users_tests.go index 375158f..b2e507f 100644 --- a/server/test/update_all_users_tests.go +++ b/server/test/update_all_users_tests.go @@ -18,7 +18,7 @@ func updateAllUsersTest(t *testing.T, s TestSetup) { t.Run("Should update all users", func(t *testing.T) { _, ctx := createContext(s) for i := 0; i < 10; i++ { - user := models.User{ + user := &models.User{ Email: fmt.Sprintf("update_all_user_%d_%s", i, s.TestInfo.Email), SignupMethods: constants.AuthRecipeMethodBasicAuth, Roles: "user", @@ -33,7 +33,7 @@ func updateAllUsersTest(t *testing.T, s TestSetup) { }, nil) assert.NoError(t, err) - listUsers, err := db.Provider.ListUsers(ctx, model.Pagination{ + listUsers, err := db.Provider.ListUsers(ctx, &model.Pagination{ Limit: 20, Offset: 0, }) @@ -49,7 +49,7 @@ func updateAllUsersTest(t *testing.T, s TestSetup) { }, updateIds) assert.NoError(t, err) - listUsers, err = db.Provider.ListUsers(ctx, model.Pagination{ + listUsers, err = db.Provider.ListUsers(ctx, &model.Pagination{ Limit: 20, Offset: 0, }) diff --git a/server/test/update_webhook_test.go b/server/test/update_webhook_test.go index 14ccb94..6e2d023 100644 --- a/server/test/update_webhook_test.go +++ b/server/test/update_webhook_test.go @@ -27,7 +27,7 @@ func updateWebhookTest(t *testing.T, s TestSetup) { webhooks, err := db.Provider.GetWebhookByEventName(ctx, constants.UserDeletedWebhookEvent) assert.NoError(t, err) assert.NotNil(t, webhooks) - assert.Equal(t, 2, len(webhooks)) + assert.GreaterOrEqual(t, len(webhooks), 2) for _, webhook := range webhooks { // it should completely replace headers webhook.Headers = map[string]interface{}{ @@ -58,7 +58,7 @@ func updateWebhookTest(t *testing.T, s TestSetup) { // Check if webhooks with new name is as per expected len accessWebhooks, err := db.Provider.GetWebhookByEventName(ctx, constants.UserAccessEnabledWebhookEvent) assert.NoError(t, err) - assert.Equal(t, 3, len(accessWebhooks)) + assert.GreaterOrEqual(t, len(accessWebhooks), 3) // Revert name change res, err = resolvers.UpdateWebhookResolver(ctx, model.UpdateWebhookRequest{ ID: w.ID, @@ -69,7 +69,7 @@ func updateWebhookTest(t *testing.T, s TestSetup) { updatedWebhooks, err := db.Provider.GetWebhookByEventName(ctx, constants.UserDeletedWebhookEvent) assert.NoError(t, err) assert.NotNil(t, updatedWebhooks) - assert.Equal(t, 2, len(updatedWebhooks)) + assert.GreaterOrEqual(t, len(updatedWebhooks), 2) for _, updatedWebhook := range updatedWebhooks { assert.Contains(t, refs.StringValue(updatedWebhook.EventName), constants.UserDeletedWebhookEvent) assert.Len(t, updatedWebhook.Headers, 1) diff --git a/server/test/validate_jwt_token_test.go b/server/test/validate_jwt_token_test.go index d2ab257..f9a108a 100644 --- a/server/test/validate_jwt_token_test.go +++ b/server/test/validate_jwt_token_test.go @@ -39,7 +39,7 @@ func validateJwtTokenTest(t *testing.T, s TestSetup) { }) scope := []string{"openid", "email", "profile", "offline_access"} - user := models.User{ + user := &models.User{ ID: uuid.New().String(), Email: "jwt_test_" + s.TestInfo.Email, Roles: "user", diff --git a/server/test/validate_session_test.go b/server/test/validate_session_test.go new file mode 100644 index 0000000..b9573cb --- /dev/null +++ b/server/test/validate_session_test.go @@ -0,0 +1,62 @@ +package test + +import ( + "fmt" + "strings" + "testing" + + "github.com/authorizerdev/authorizer/server/constants" + "github.com/authorizerdev/authorizer/server/db" + "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" + "github.com/authorizerdev/authorizer/server/resolvers" + "github.com/authorizerdev/authorizer/server/token" + "github.com/stretchr/testify/assert" +) + +// ValidateSessionTests tests all the validate session resolvers +func validateSessionTests(t *testing.T, s TestSetup) { + t.Helper() + t.Run(`should validate session`, func(t *testing.T) { + req, ctx := createContext(s) + email := "validate_session." + s.TestInfo.Email + + resolvers.SignupResolver(ctx, model.SignUpInput{ + Email: email, + Password: s.TestInfo.Password, + ConfirmPassword: s.TestInfo.Password, + }) + _, err := resolvers.ValidateSessionResolver(ctx, &model.ValidateSessionInput{}) + assert.NotNil(t, err, "unauthorized") + verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeBasicAuthSignup) + assert.NoError(t, err) + assert.NotNil(t, verificationRequest) + verifyRes, err := resolvers.VerifyEmailResolver(ctx, model.VerifyEmailInput{ + Token: verificationRequest.Token, + }) + assert.NoError(t, err) + assert.NotNil(t, verifyRes) + accessToken := *verifyRes.AccessToken + assert.NotEmpty(t, accessToken) + claims, err := token.ParseJWTToken(accessToken) + assert.NoError(t, err) + assert.NotEmpty(t, claims) + sessionKey := constants.AuthRecipeMethodBasicAuth + ":" + verifyRes.User.ID + sessionToken, err := memorystore.Provider.GetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+claims["nonce"].(string)) + assert.NoError(t, err) + assert.NotEmpty(t, sessionToken) + cookie := fmt.Sprintf("%s=%s;", constants.AppCookieName+"_session", sessionToken) + cookie = strings.TrimSuffix(cookie, ";") + res, err := resolvers.ValidateSessionResolver(ctx, &model.ValidateSessionInput{ + Cookie: sessionToken, + }) + assert.Nil(t, err) + assert.True(t, res.IsValid) + req.Header.Set("Cookie", cookie) + res, err = resolvers.ValidateSessionResolver(ctx, &model.ValidateSessionInput{}) + assert.Nil(t, err) + assert.True(t, res.IsValid) + assert.Equal(t, res.User.ID, verifyRes.User.ID) + cleanData(email) + }) +} diff --git a/server/test/verification_requests_test.go b/server/test/verification_requests_test.go index 0d0ce65..e5d5d73 100644 --- a/server/test/verification_requests_test.go +++ b/server/test/verification_requests_test.go @@ -44,9 +44,7 @@ func verificationRequestsTest(t *testing.T, s TestSetup) { assert.Nil(t, err) req.Header.Set("Cookie", fmt.Sprintf("%s=%s", constants.AdminCookieName, h)) requests, err = resolvers.VerificationRequestsResolver(ctx, pagination) - assert.Nil(t, err) - rLen := len(requests.VerificationRequests) assert.GreaterOrEqual(t, rLen, 1) diff --git a/server/test/verify_mobile_test.go b/server/test/verify_mobile_test.go deleted file mode 100644 index b4daa5e..0000000 --- a/server/test/verify_mobile_test.go +++ /dev/null @@ -1,79 +0,0 @@ -package test - -import ( - "strings" - "testing" - - "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/graph/model" - "github.com/authorizerdev/authorizer/server/db" - "github.com/authorizerdev/authorizer/server/refs" - "github.com/authorizerdev/authorizer/server/resolvers" - "github.com/stretchr/testify/assert" -) - -func verifyMobileTest(t *testing.T, s TestSetup) { - t.Helper() - t.Run(`should verify mobile`, func(t *testing.T) { - _, ctx := createContext(s) - email := "mobile_verification." + s.TestInfo.Email - phoneNumber := "2234567890" - signUpRes, err := resolvers.MobileSignupResolver(ctx, &model.MobileSignUpInput{ - Email: refs.NewStringRef(email), - PhoneNumber: phoneNumber, - Password: s.TestInfo.Password, - ConfirmPassword: s.TestInfo.Password, - }) - assert.NoError(t, err) - assert.NotNil(t, signUpRes) - assert.Equal(t, email, signUpRes.User.Email) - assert.Equal(t, phoneNumber, refs.StringValue(signUpRes.User.PhoneNumber)) - assert.True(t, strings.Contains(signUpRes.User.SignupMethods, constants.AuthRecipeMethodMobileBasicAuth)) - assert.Len(t, strings.Split(signUpRes.User.SignupMethods, ","), 1) - - res, err := resolvers.MobileLoginResolver(ctx, model.MobileLoginInput{ - PhoneNumber: phoneNumber, - Password: "random_test", - }) - assert.Error(t, err) - assert.Nil(t, res) - - // should fail because phone is not verified - res, err = resolvers.MobileLoginResolver(ctx, model.MobileLoginInput{ - PhoneNumber: phoneNumber, - Password: s.TestInfo.Password, - }) - assert.NotNil(t, err, "should fail because phone is not verified") - assert.Nil(t, res) - - // get code from db - smsRequest, err := db.Provider.GetCodeByPhone(ctx, phoneNumber) - assert.NoError(t, err) - assert.NotEmpty(t, smsRequest.Code) - - // throw an error if the code is not correct - verifySMSRequest, err := resolvers.VerifyMobileResolver(ctx, model.VerifyMobileRequest{ - PhoneNumber: phoneNumber, - Code: "rand_12@1", - }) - assert.NotNil(t, err, "should fail because of bad credentials") - assert.Nil(t, verifySMSRequest) - - verifySMSRequest, err = resolvers.VerifyMobileResolver(ctx, model.VerifyMobileRequest{ - PhoneNumber: phoneNumber, - Code: smsRequest.Code, - }) - assert.Nil(t, err) - assert.NotEqual(t, verifySMSRequest.Message, "", "message should not be empty") - - res, err = resolvers.MobileLoginResolver(ctx, model.MobileLoginInput{ - PhoneNumber: phoneNumber, - Password: s.TestInfo.Password, - }) - assert.NoError(t, err) - assert.NotEmpty(t, res.AccessToken) - assert.NotEmpty(t, res.IDToken) - - cleanData(email) - }) -} diff --git a/server/test/verify_otp_test.go b/server/test/verify_otp_test.go index 9e074cd..455ac12 100644 --- a/server/test/verify_otp_test.go +++ b/server/test/verify_otp_test.go @@ -2,13 +2,18 @@ package test import ( "context" + "fmt" + "strings" "testing" + "time" "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/refs" "github.com/authorizerdev/authorizer/server/resolvers" + "github.com/google/uuid" "github.com/stretchr/testify/assert" ) @@ -63,9 +68,18 @@ func verifyOTPTest(t *testing.T, s TestSetup) { otp, err := db.Provider.GetOTPByEmail(ctx, email) assert.NoError(t, err) assert.NotEmpty(t, otp.Otp) - + // Get user by email + user, err := db.Provider.GetUserByEmail(ctx, email) + assert.NoError(t, err) + assert.NotNil(t, user) + // Set mfa cookie session + mfaSession := uuid.NewString() + memorystore.Provider.SetMfaSession(user.ID, mfaSession, time.Now().Add(1*time.Minute).Unix()) + cookie := fmt.Sprintf("%s=%s;", constants.MfaCookieName+"_session", mfaSession) + cookie = strings.TrimSuffix(cookie, ";") + req.Header.Set("Cookie", cookie) verifyOtpRes, err := resolvers.VerifyOtpResolver(ctx, model.VerifyOTPRequest{ - Email: email, + Email: &email, Otp: otp.Otp, }) assert.Nil(t, err) diff --git a/server/test/webhook_test.go b/server/test/webhook_test.go index 0fb789f..a556f9d 100644 --- a/server/test/webhook_test.go +++ b/server/test/webhook_test.go @@ -28,7 +28,7 @@ func webhookTest(t *testing.T, s TestSetup) { webhooks, err := db.Provider.GetWebhookByEventName(ctx, constants.UserCreatedWebhookEvent) assert.NoError(t, err) assert.NotNil(t, webhooks) - assert.Equal(t, 2, len(webhooks)) + assert.GreaterOrEqual(t, len(webhooks), 2) for _, webhook := range webhooks { res, err := resolvers.WebhookResolver(ctx, model.WebhookRequest{ ID: webhook.ID, diff --git a/server/test/webhooks_test.go b/server/test/webhooks_test.go index 6ed1bb2..74cad74 100644 --- a/server/test/webhooks_test.go +++ b/server/test/webhooks_test.go @@ -30,6 +30,6 @@ func webhooksTest(t *testing.T, s TestSetup) { }) assert.NoError(t, err) assert.NotEmpty(t, webhooks) - assert.Len(t, webhooks.Webhooks, len(s.TestInfo.TestWebhookEventTypes)*2) + assert.GreaterOrEqual(t, len(webhooks.Webhooks), len(s.TestInfo.TestWebhookEventTypes)*2) }) } diff --git a/server/token/auth_token.go b/server/token/auth_token.go index 6d2c942..f482db8 100644 --- a/server/token/auth_token.go +++ b/server/token/auth_token.go @@ -51,7 +51,7 @@ type SessionData struct { } // CreateAuthToken creates a new auth token when userlogs in -func CreateAuthToken(gc *gin.Context, user models.User, roles, scope []string, loginMethod, nonce string, code string) (*Token, error) { +func CreateAuthToken(gc *gin.Context, user *models.User, roles, scope []string, loginMethod, nonce string, code string) (*Token, error) { hostname := parsers.GetHost(gc) _, fingerPrintHash, sessionTokenExpiresAt, err := CreateSessionToken(user, nonce, roles, scope, loginMethod) if err != nil { @@ -104,7 +104,7 @@ func CreateAuthToken(gc *gin.Context, user models.User, roles, scope []string, l } // CreateSessionToken creates a new session token -func CreateSessionToken(user models.User, nonce string, roles, scope []string, loginMethod string) (*SessionData, string, int64, error) { +func CreateSessionToken(user *models.User, nonce string, roles, scope []string, loginMethod string) (*SessionData, string, int64, error) { expiresAt := time.Now().AddDate(1, 0, 0).Unix() fingerPrintMap := &SessionData{ Nonce: nonce, @@ -125,7 +125,7 @@ func CreateSessionToken(user models.User, nonce string, roles, scope []string, l } // CreateRefreshToken util to create JWT token -func CreateRefreshToken(user models.User, roles, scopes []string, hostname, nonce, loginMethod string) (string, int64, error) { +func CreateRefreshToken(user *models.User, roles, scopes []string, hostname, nonce, loginMethod string) (string, int64, error) { // expires in 1 year expiryBound := time.Hour * 8760 expiresAt := time.Now().Add(expiryBound).Unix() @@ -157,7 +157,7 @@ func CreateRefreshToken(user models.User, roles, scopes []string, hostname, nonc // CreateAccessToken util to create JWT token, based on // user information, roles config and CUSTOM_ACCESS_TOKEN_SCRIPT -func CreateAccessToken(user models.User, roles, scopes []string, hostName, nonce, loginMethod string) (string, int64, error) { +func CreateAccessToken(user *models.User, roles, scopes []string, hostName, nonce, loginMethod string) (string, int64, error) { expireTime, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAccessTokenExpiryTime) if err != nil { return "", 0, err @@ -372,7 +372,7 @@ func ValidateBrowserSession(gc *gin.Context, encryptedSession string) (*SessionD // user information, roles config and CUSTOM_ACCESS_TOKEN_SCRIPT // For response_type (code) / authorization_code grant nonce should be empty // for implicit flow it should be present to verify with actual state -func CreateIDToken(user models.User, roles []string, hostname, nonce, atHash, cHash, loginMethod string) (string, int64, error) { +func CreateIDToken(user *models.User, roles []string, hostname, nonce, atHash, cHash, loginMethod string) (string, int64, error) { expireTime, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAccessTokenExpiryTime) if err != nil { return "", 0, err diff --git a/server/utils/pagination.go b/server/utils/pagination.go index e27eb53..384eebd 100644 --- a/server/utils/pagination.go +++ b/server/utils/pagination.go @@ -7,10 +7,9 @@ import ( // GetPagination helps getting pagination data from paginated input // also returns default limit and offset if pagination data is not present -func GetPagination(paginatedInput *model.PaginatedInput) model.Pagination { +func GetPagination(paginatedInput *model.PaginatedInput) *model.Pagination { limit := int64(constants.DefaultLimit) page := int64(1) - if paginatedInput != nil && paginatedInput.Pagination != nil { if paginatedInput.Pagination.Limit != nil { limit = *paginatedInput.Pagination.Limit @@ -21,7 +20,7 @@ func GetPagination(paginatedInput *model.PaginatedInput) model.Pagination { } } - return model.Pagination{ + return &model.Pagination{ Limit: limit, Offset: (page - 1) * limit, Page: page, diff --git a/server/utils/webhook.go b/server/utils/webhook.go index 705c571..2571cf5 100644 --- a/server/utils/webhook.go +++ b/server/utils/webhook.go @@ -16,10 +16,12 @@ import ( log "github.com/sirupsen/logrus" ) -func RegisterEvent(ctx context.Context, eventName string, authRecipe string, user models.User) error { +// RegisterEvent util to register event +// TODO change user to user ref +func RegisterEvent(ctx context.Context, eventName string, authRecipe string, user *models.User) error { webhooks, err := db.Provider.GetWebhookByEventName(ctx, eventName) if err != nil { - log.Debug("Error getting webhook: %v", err) + log.Debug("error getting webhook: %v", err) return err } for _, webhook := range webhooks { @@ -61,7 +63,7 @@ func RegisterEvent(ctx context.Context, eventName string, authRecipe string, use continue } if envKey == constants.TestEnv { - _, err := db.Provider.AddWebhookLog(ctx, models.WebhookLog{ + _, err := db.Provider.AddWebhookLog(ctx, &models.WebhookLog{ HttpStatus: 200, Request: string(requestBody), Response: string(`{"message": "test"}`), @@ -102,7 +104,7 @@ func RegisterEvent(ctx context.Context, eventName string, authRecipe string, use } statusCode := int64(resp.StatusCode) - _, err = db.Provider.AddWebhookLog(ctx, models.WebhookLog{ + _, err = db.Provider.AddWebhookLog(ctx, &models.WebhookLog{ HttpStatus: statusCode, Request: string(requestBody), Response: string(responseBytes),