diff --git a/.env.sample b/.env.sample index c3ca441..6f9874b 100644 --- a/.env.sample +++ b/.env.sample @@ -1,3 +1,4 @@ +ENV=production DATABASE_URL=data.db DATABASE_TYPE=sqlite CUSTOM_ACCESS_TOKEN_SCRIPT="function(user,tokenPayload){var data = tokenPayload;data.extra = {'x-extra-id': user.id};return data;}" \ No newline at end of file diff --git a/.env.test b/.env.test new file mode 100644 index 0000000..0df0238 --- /dev/null +++ b/.env.test @@ -0,0 +1,9 @@ +ENV=test +DATABASE_URL=test.db +DATABASE_TYPE=sqlite +CUSTOM_ACCESS_TOKEN_SCRIPT="function(user,tokenPayload){var data = tokenPayload;data.extra = {'x-extra-id': user.id};return data;}" +SMTP_HOST=smtp.mailtrap.io +SMTP_PORT=2525 +SMTP_USERNAME=test +SMTP_PASSWORD=test +SENDER_EMAIL="info@authorizer.dev" \ No newline at end of file diff --git a/.gitignore b/.gitignore index 8b70df4..7fdb338 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,7 @@ dashboard/build build .env data.db +test.db .DS_Store .env.local *.tar.gz diff --git a/Makefile b/Makefile index 635edec..c883a9c 100644 --- a/Makefile +++ b/Makefile @@ -10,7 +10,7 @@ build-dashboard: clean: rm -rf build test: - cd server && go clean --testcache && go test -v ./test + rm -rf server/test/test.db && rm -rf test.db && cd server && go clean --testcache && go test -p 1 -v ./test generate: cd server && go get github.com/99designs/gqlgen/cmd@v0.14.0 && go run github.com/99designs/gqlgen generate \ No newline at end of file diff --git a/dashboard/src/components/EnvComponents/DatabaseCredentials.tsx b/dashboard/src/components/EnvComponents/DatabaseCredentials.tsx index 12c85e1..510b43c 100644 --- a/dashboard/src/components/EnvComponents/DatabaseCredentials.tsx +++ b/dashboard/src/components/EnvComponents/DatabaseCredentials.tsx @@ -1,88 +1,89 @@ -import React from "react"; -import { Flex, Stack, Center, Text, useMediaQuery } from "@chakra-ui/react"; +import React from 'react'; +import { Flex, Stack, Center, Text, useMediaQuery } from '@chakra-ui/react'; -import InputField from "../../components/InputField"; -import { TextInputType } from "../../constants"; +import InputField from '../../components/InputField'; +import { TextInputType } from '../../constants'; const DatabaseCredentials = ({ variables, setVariables }: any) => { - const [isNotSmallerScreen] = useMediaQuery("(min-width:600px)"); - return ( -
- {" "} - - Database Credentials - - - - Note: Database related environment variables cannot be updated from - dashboard :( - - - - DataBase Name: - -
- -
-
- - - DataBase Type: - -
- -
-
- - - DataBase URL: - -
- -
-
-
-
- ); + const [isNotSmallerScreen] = useMediaQuery('(min-width:600px)'); + return ( +
+ {' '} + + Database Credentials + + + + Note: Database related environment variables cannot be updated from + dashboard. Please use .env file or OS environment variables to update + it. + + + + DataBase Name: + +
+ +
+
+ + + DataBase Type: + +
+ +
+
+ + + DataBase URL: + +
+ +
+
+
+
+ ); }; -export default DatabaseCredentials; \ No newline at end of file +export default DatabaseCredentials; diff --git a/dashboard/src/components/EnvComponents/UICustomization.tsx b/dashboard/src/components/EnvComponents/Features.tsx similarity index 95% rename from dashboard/src/components/EnvComponents/UICustomization.tsx rename to dashboard/src/components/EnvComponents/Features.tsx index 756c739..aa5ed33 100644 --- a/dashboard/src/components/EnvComponents/UICustomization.tsx +++ b/dashboard/src/components/EnvComponents/Features.tsx @@ -3,7 +3,7 @@ import { Flex, Stack, Text } from '@chakra-ui/react'; import InputField from '../InputField'; import { SwitchInputType } from '../../constants'; -const UICustomization = ({ variables, setVariables }: any) => { +const Features = ({ variables, setVariables }: any) => { return (
{' '} @@ -76,4 +76,4 @@ const UICustomization = ({ variables, setVariables }: any) => { ); }; -export default UICustomization; +export default Features; diff --git a/dashboard/src/components/EnvComponents/SessionStorage.tsx b/dashboard/src/components/EnvComponents/SessionStorage.tsx index 8570f47..1aee4e2 100644 --- a/dashboard/src/components/EnvComponents/SessionStorage.tsx +++ b/dashboard/src/components/EnvComponents/SessionStorage.tsx @@ -1,36 +1,42 @@ -import React from "react"; -import { Flex, Stack, Center, Text, useMediaQuery } from "@chakra-ui/react"; -import InputField from "../InputField"; +import React from 'react'; +import { Flex, Stack, Center, Text, useMediaQuery } from '@chakra-ui/react'; +import InputField from '../InputField'; const SessionStorage = ({ variables, setVariables, RedisURL }: any) => { - const [isNotSmallerScreen] = useMediaQuery("(min-width:600px)"); - return ( -
- {" "} - - Session Storage - - - - - Redis URL: - -
- -
-
-
-
- ); + const [isNotSmallerScreen] = useMediaQuery('(min-width:600px)'); + return ( +
+ {' '} + + Session Storage + + + Note: Redis related environment variables cannot be updated from + dashboard. Please use .env file or OS environment variables to update + it. + + + + + Redis URL: + +
+ +
+
+
+
+ ); }; -export default SessionStorage; \ No newline at end of file +export default SessionStorage; diff --git a/dashboard/src/components/InviteMembersModal.tsx b/dashboard/src/components/InviteMembersModal.tsx index 2878722..9ede01b 100644 --- a/dashboard/src/components/InviteMembersModal.tsx +++ b/dashboard/src/components/InviteMembersModal.tsx @@ -22,7 +22,7 @@ import { InputRightElement, Text, Link, - Tooltip + Tooltip, } from '@chakra-ui/react'; import { useClient } from 'urql'; import { FaUserPlus, FaMinusCircle, FaPlus, FaUpload } from 'react-icons/fa'; @@ -187,22 +187,22 @@ const InviteMembersModal = ({ isDisabled={disabled} size="sm" > -
- {disabled ? ( - - Invite Members - - ) : ( - "Invite Members" - )} -
{" "} +
+ {disabled ? ( + + Invite Members + + ) : ( + 'Invite Members' + )} +
{' '} diff --git a/dashboard/src/components/Menu.tsx b/dashboard/src/components/Menu.tsx index 7822062..8593bb0 100644 --- a/dashboard/src/components/Menu.tsx +++ b/dashboard/src/components/Menu.tsx @@ -98,9 +98,9 @@ const LinkItems: Array = [ }, { name: 'Access Token', icon: SiOpenaccess, route: '/access-token' }, { - name: 'UI Customization', + name: 'Features', icon: BiCustomize, - route: '/ui-customization', + route: '/features', }, { name: 'Database', icon: RiDatabase2Line, route: '/db-cred' }, { diff --git a/dashboard/src/constants.ts b/dashboard/src/constants.ts index 3b9986f..e0f9d3d 100644 --- a/dashboard/src/constants.ts +++ b/dashboard/src/constants.ts @@ -62,6 +62,7 @@ export const SwitchInputType = { DISABLE_EMAIL_VERIFICATION: 'DISABLE_EMAIL_VERIFICATION', DISABLE_BASIC_AUTHENTICATION: 'DISABLE_BASIC_AUTHENTICATION', DISABLE_SIGN_UP: 'DISABLE_SIGN_UP', + DISABLE_REDIS_FOR_ENV: 'DISABLE_REDIS_FOR_ENV', }; export const DateInputType = { @@ -138,7 +139,7 @@ export const envSubViews = { WHITELIST_VARIABLES: 'whitelist-variables', ORGANIZATION_INFO: 'organization-info', ACCESS_TOKEN: 'access-token', - UI_CUSTOMIZATION: 'ui-customization', + FEATURES: 'features', ADMIN_SECRET: 'admin-secret', DB_CRED: 'db-cred', }; diff --git a/dashboard/src/graphql/queries/index.ts b/dashboard/src/graphql/queries/index.ts index 1adf02c..cd55475 100644 --- a/dashboard/src/graphql/queries/index.ts +++ b/dashboard/src/graphql/queries/index.ts @@ -49,6 +49,7 @@ export const EnvVariablesQuery = ` DISABLE_EMAIL_VERIFICATION, DISABLE_BASIC_AUTHENTICATION, DISABLE_SIGN_UP, + DISABLE_REDIS_FOR_ENV, CUSTOM_ACCESS_TOKEN_SCRIPT, DATABASE_NAME, DATABASE_TYPE, diff --git a/dashboard/src/pages/Environment.tsx b/dashboard/src/pages/Environment.tsx index 78f79c4..169c62f 100644 --- a/dashboard/src/pages/Environment.tsx +++ b/dashboard/src/pages/Environment.tsx @@ -25,7 +25,7 @@ import EmailConfigurations from '../components/EnvComponents/EmailConfiguration' import DomainWhiteListing from '../components/EnvComponents/DomainWhitelisting'; import OrganizationInfo from '../components/EnvComponents/OrganizationInfo'; import AccessToken from '../components/EnvComponents/AccessToken'; -import UICustomization from '../components/EnvComponents/UICustomization'; +import Features from '../components/EnvComponents/Features'; import SecurityAdminSecret from '../components/EnvComponents/SecurityAdminSecret'; import DatabaseCredentials from '../components/EnvComponents/DatabaseCredentials'; @@ -259,12 +259,9 @@ const Environment = () => { setVariables={setEnvVariables} /> ); - case envSubViews.UI_CUSTOMIZATION: + case envSubViews.FEATURES: return ( - + ); case envSubViews.ADMIN_SECRET: return ( diff --git a/server/cli/cli.go b/server/cli/cli.go new file mode 100644 index 0000000..391d632 --- /dev/null +++ b/server/cli/cli.go @@ -0,0 +1,14 @@ +package cli + +var ( + // ARG_DB_URL is the cli arg variable for the database url + ARG_DB_URL *string + // ARG_DB_TYPE is the cli arg variable for the database type + ARG_DB_TYPE *string + // ARG_ENV_FILE is the cli arg variable for the env file + ARG_ENV_FILE *string + // ARG_LOG_LEVEL is the cli arg variable for the log level + ARG_LOG_LEVEL *string + // ARG_REDIS_URL is the cli arg variable for the redis url + ARG_REDIS_URL *string +) diff --git a/server/constants/cookie.go b/server/constants/cookie.go new file mode 100644 index 0000000..71320a9 --- /dev/null +++ b/server/constants/cookie.go @@ -0,0 +1,8 @@ +package constants + +const ( + // AppCookieName is the name of the cookie that is used to store the application token + AppCookieName = "cookie" + // AdminCookieName is the name of the cookie that is used to store the admin token + AdminCookieName = "authorizer-admin" +) diff --git a/server/constants/env.go b/server/constants/env.go index b73048b..4f02e64 100644 --- a/server/constants/env.go +++ b/server/constants/env.go @@ -5,11 +5,11 @@ var VERSION = "0.0.1" const ( // Envstore identifier // StringStore string store identifier - StringStoreIdentifier = "stringStore" - // BoolStore bool store identifier - BoolStoreIdentifier = "boolStore" - // SliceStore slice store identifier - SliceStoreIdentifier = "sliceStore" + // StringStoreIdentifier = "stringStore" + // // BoolStore bool store identifier + // BoolStoreIdentifier = "boolStore" + // // SliceStore slice store identifier + // SliceStoreIdentifier = "sliceStore" // EnvKeyEnv key for env variable ENV EnvKeyEnv = "ENV" @@ -19,7 +19,6 @@ const ( EnvKeyAuthorizerURL = "AUTHORIZER_URL" // EnvKeyPort key for env variable PORT EnvKeyPort = "PORT" - // EnvKeyAccessTokenExpiryTime key for env variable ACCESS_TOKEN_EXPIRY_TIME EnvKeyAccessTokenExpiryTime = "ACCESS_TOKEN_EXPIRY_TIME" // EnvKeyAdminSecret key for env variable ADMIN_SECRET @@ -62,34 +61,12 @@ const ( EnvKeyJwtPrivateKey = "JWT_PRIVATE_KEY" // EnvKeyJwtPublicKey key for env variable JWT_PUBLIC_KEY EnvKeyJwtPublicKey = "JWT_PUBLIC_KEY" - // EnvKeyAllowedOrigins key for env variable ALLOWED_ORIGINS - EnvKeyAllowedOrigins = "ALLOWED_ORIGINS" // EnvKeyAppURL key for env variable APP_URL EnvKeyAppURL = "APP_URL" // EnvKeyRedisURL key for env variable REDIS_URL EnvKeyRedisURL = "REDIS_URL" - // EnvKeyCookieName key for env variable COOKIE_NAME - EnvKeyCookieName = "COOKIE_NAME" - // EnvKeyAdminCookieName key for env variable ADMIN_COOKIE_NAME - EnvKeyAdminCookieName = "ADMIN_COOKIE_NAME" // EnvKeyResetPasswordURL key for env variable RESET_PASSWORD_URL EnvKeyResetPasswordURL = "RESET_PASSWORD_URL" - // EnvKeyDisableEmailVerification key for env variable DISABLE_EMAIL_VERIFICATION - EnvKeyDisableEmailVerification = "DISABLE_EMAIL_VERIFICATION" - // EnvKeyDisableBasicAuthentication key for env variable DISABLE_BASIC_AUTH - EnvKeyDisableBasicAuthentication = "DISABLE_BASIC_AUTHENTICATION" - // EnvKeyDisableMagicLinkLogin key for env variable DISABLE_MAGIC_LINK_LOGIN - EnvKeyDisableMagicLinkLogin = "DISABLE_MAGIC_LINK_LOGIN" - // EnvKeyDisableLoginPage key for env variable DISABLE_LOGIN_PAGE - EnvKeyDisableLoginPage = "DISABLE_LOGIN_PAGE" - // EnvKeyDisableSignUp key for env variable DISABLE_SIGN_UP - EnvKeyDisableSignUp = "DISABLE_SIGN_UP" - // EnvKeyRoles key for env variable ROLES - EnvKeyRoles = "ROLES" - // EnvKeyProtectedRoles key for env variable PROTECTED_ROLES - EnvKeyProtectedRoles = "PROTECTED_ROLES" - // EnvKeyDefaultRoles key for env variable DEFAULT_ROLES - EnvKeyDefaultRoles = "DEFAULT_ROLES" // EnvKeyJwtRoleClaim key for env variable JWT_ROLE_CLAIM EnvKeyJwtRoleClaim = "JWT_ROLE_CLAIM" // EnvKeyGoogleClientID key for env variable GOOGLE_CLIENT_ID @@ -120,6 +97,30 @@ const ( EnvKeyEncryptionKey = "ENCRYPTION_KEY" // EnvKeyJWK key for env variable JWK EnvKeyJWK = "JWK" + + // Boolean variables // EnvKeyIsProd key for env variable IS_PROD EnvKeyIsProd = "IS_PROD" + // EnvKeyDisableEmailVerification key for env variable DISABLE_EMAIL_VERIFICATION + EnvKeyDisableEmailVerification = "DISABLE_EMAIL_VERIFICATION" + // EnvKeyDisableBasicAuthentication key for env variable DISABLE_BASIC_AUTH + EnvKeyDisableBasicAuthentication = "DISABLE_BASIC_AUTHENTICATION" + // EnvKeyDisableMagicLinkLogin key for env variable DISABLE_MAGIC_LINK_LOGIN + EnvKeyDisableMagicLinkLogin = "DISABLE_MAGIC_LINK_LOGIN" + // EnvKeyDisableLoginPage key for env variable DISABLE_LOGIN_PAGE + EnvKeyDisableLoginPage = "DISABLE_LOGIN_PAGE" + // EnvKeyDisableSignUp key for env variable DISABLE_SIGN_UP + EnvKeyDisableSignUp = "DISABLE_SIGN_UP" + // EnvKeyDisableRedisForEnv key for env variable DISABLE_REDIS_FOR_ENV + EnvKeyDisableRedisForEnv = "DISABLE_REDIS_FOR_ENV" + + // Slice variables + // EnvKeyRoles key for env variable ROLES + EnvKeyRoles = "ROLES" + // EnvKeyProtectedRoles key for env variable PROTECTED_ROLES + EnvKeyProtectedRoles = "PROTECTED_ROLES" + // EnvKeyDefaultRoles key for env variable DEFAULT_ROLES + EnvKeyDefaultRoles = "DEFAULT_ROLES" + // EnvKeyAllowedOrigins key for env variable ALLOWED_ORIGINS + EnvKeyAllowedOrigins = "ALLOWED_ORIGINS" ) diff --git a/server/cookie/admin_cookie.go b/server/cookie/admin_cookie.go index 58f2c56..6b64767 100644 --- a/server/cookie/admin_cookie.go +++ b/server/cookie/admin_cookie.go @@ -4,8 +4,7 @@ import ( "net/url" "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/envstore" - "github.com/authorizerdev/authorizer/server/utils" + "github.com/authorizerdev/authorizer/server/parsers" "github.com/gin-gonic/gin" ) @@ -13,15 +12,14 @@ import ( func SetAdminCookie(gc *gin.Context, token string) { secure := true httpOnly := true - hostname := utils.GetHost(gc) - host, _ := utils.GetHostParts(hostname) - - gc.SetCookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminCookieName), token, 3600, "/", host, secure, httpOnly) + hostname := parsers.GetHost(gc) + host, _ := parsers.GetHostParts(hostname) + gc.SetCookie(constants.AdminCookieName, token, 3600, "/", host, secure, httpOnly) } // GetAdminCookie gets the admin cookie from the request func GetAdminCookie(gc *gin.Context) (string, error) { - cookie, err := gc.Request.Cookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminCookieName)) + cookie, err := gc.Request.Cookie(constants.AdminCookieName) if err != nil { return "", err } @@ -39,8 +37,7 @@ func GetAdminCookie(gc *gin.Context) (string, error) { func DeleteAdminCookie(gc *gin.Context) { secure := true httpOnly := true - hostname := utils.GetHost(gc) - host, _ := utils.GetHostParts(hostname) - - gc.SetCookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminCookieName), "", -1, "/", host, secure, httpOnly) + hostname := parsers.GetHost(gc) + host, _ := parsers.GetHostParts(hostname) + gc.SetCookie(constants.AdminCookieName, "", -1, "/", host, secure, httpOnly) } diff --git a/server/cookie/cookie.go b/server/cookie/cookie.go index 54600af..73c60ea 100644 --- a/server/cookie/cookie.go +++ b/server/cookie/cookie.go @@ -5,8 +5,7 @@ import ( "net/url" "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/envstore" - "github.com/authorizerdev/authorizer/server/utils" + "github.com/authorizerdev/authorizer/server/parsers" "github.com/gin-gonic/gin" ) @@ -14,9 +13,9 @@ import ( func SetSession(gc *gin.Context, sessionID string) { secure := true httpOnly := true - hostname := utils.GetHost(gc) - host, _ := utils.GetHostParts(hostname) - domain := utils.GetDomainName(hostname) + hostname := parsers.GetHost(gc) + host, _ := parsers.GetHostParts(hostname) + domain := parsers.GetDomainName(hostname) if domain != "localhost" { domain = "." + domain } @@ -25,33 +24,33 @@ func SetSession(gc *gin.Context, sessionID string) { year := 60 * 60 * 24 * 365 gc.SetSameSite(http.SameSiteNoneMode) - gc.SetCookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName)+"_session", sessionID, year, "/", host, secure, httpOnly) - gc.SetCookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName)+"_session_domain", sessionID, year, "/", domain, secure, httpOnly) + gc.SetCookie(constants.AppCookieName+"_session", sessionID, year, "/", host, secure, httpOnly) + gc.SetCookie(constants.AppCookieName+"_session_domain", sessionID, year, "/", domain, secure, httpOnly) } // DeleteSession sets session cookies to expire func DeleteSession(gc *gin.Context) { secure := true httpOnly := true - hostname := utils.GetHost(gc) - host, _ := utils.GetHostParts(hostname) - domain := utils.GetDomainName(hostname) + hostname := parsers.GetHost(gc) + host, _ := parsers.GetHostParts(hostname) + domain := parsers.GetDomainName(hostname) if domain != "localhost" { domain = "." + domain } gc.SetSameSite(http.SameSiteNoneMode) - gc.SetCookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName)+"_session", "", -1, "/", host, secure, httpOnly) - gc.SetCookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName)+"_session_domain", "", -1, "/", domain, secure, httpOnly) + gc.SetCookie(constants.AppCookieName+"_session", "", -1, "/", host, secure, httpOnly) + gc.SetCookie(constants.AppCookieName+"_session_domain", "", -1, "/", domain, secure, httpOnly) } // GetSession gets the session cookie from context func GetSession(gc *gin.Context) (string, error) { var cookie *http.Cookie var err error - cookie, err = gc.Request.Cookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName) + "_session") + cookie, err = gc.Request.Cookie(constants.AppCookieName + "_session") if err != nil { - cookie, err = gc.Request.Cookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName) + "_session_domain") + cookie, err = gc.Request.Cookie(constants.AppCookieName + "_session_domain") if err != nil { return "", err } diff --git a/server/crypto/aes.go b/server/crypto/aes.go index 8d06ffb..422f694 100644 --- a/server/crypto/aes.go +++ b/server/crypto/aes.go @@ -7,14 +7,18 @@ import ( "io" "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" ) var bytes = []byte{35, 46, 57, 24, 85, 35, 24, 74, 87, 35, 88, 98, 66, 32, 14, 0o5} // EncryptAES method is to encrypt or hide any classified text func EncryptAES(text string) (string, error) { - key := []byte(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyEncryptionKey)) + k, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyEncryptionKey) + if err != nil { + return "", err + } + key := []byte(k) block, err := aes.NewCipher(key) if err != nil { return "", err @@ -28,7 +32,11 @@ func EncryptAES(text string) (string, error) { // DecryptAES method is to extract back the encrypted text func DecryptAES(text string) (string, error) { - key := []byte(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyEncryptionKey)) + k, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyEncryptionKey) + if err != nil { + return "", err + } + key := []byte(k) block, err := aes.NewCipher(key) if err != nil { return "", err @@ -46,9 +54,13 @@ func DecryptAES(text string) (string, error) { // EncryptAESEnv encrypts data using AES algorithm // kept for the backward compatibility of env data encryption func EncryptAESEnv(text []byte) ([]byte, error) { - key := []byte(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyEncryptionKey)) - c, err := aes.NewCipher(key) var res []byte + k, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyEncryptionKey) + if err != nil { + return res, err + } + key := []byte(k) + c, err := aes.NewCipher(key) if err != nil { return res, err } @@ -81,9 +93,13 @@ func EncryptAESEnv(text []byte) ([]byte, error) { // DecryptAES decrypts data using AES algorithm // Kept for the backward compatibility of env data decryption func DecryptAESEnv(ciphertext []byte) ([]byte, error) { - key := []byte(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyEncryptionKey)) - c, err := aes.NewCipher(key) var res []byte + k, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyEncryptionKey) + if err != nil { + return res, err + } + key := []byte(k) + c, err := aes.NewCipher(key) if err != nil { return res, err } diff --git a/server/crypto/common.go b/server/crypto/common.go index 35af515..91aed06 100644 --- a/server/crypto/common.go +++ b/server/crypto/common.go @@ -5,7 +5,7 @@ import ( "encoding/json" "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" "golang.org/x/crypto/bcrypt" "gopkg.in/square/go-jose.v2" ) @@ -37,20 +37,35 @@ func GetPubJWK(algo, keyID string, publicKey interface{}) (string, error) { // this is called while initializing app / when env is updated func GenerateJWKBasedOnEnv() (string, error) { jwk := "" - algo := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtType) - clientID := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyClientID) + algo, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyJwtType) + if err != nil { + return jwk, err + } + clientID, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyClientID) + if err != nil { + return jwk, err + } + + jwtSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyJwtSecret) + if err != nil { + return jwk, err + } - var err error // check if jwt secret is provided if IsHMACA(algo) { - jwk, err = GetPubJWK(algo, clientID, []byte(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtSecret))) + jwk, err = GetPubJWK(algo, clientID, []byte(jwtSecret)) if err != nil { return "", err } } + jwtPublicKey, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyJwtPublicKey) + if err != nil { + return jwk, err + } + if IsRSA(algo) { - publicKeyInstance, err := ParseRsaPublicKeyFromPemStr(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtPublicKey)) + publicKeyInstance, err := ParseRsaPublicKeyFromPemStr(jwtPublicKey) if err != nil { return "", err } @@ -62,7 +77,11 @@ func GenerateJWKBasedOnEnv() (string, error) { } if IsECDSA(algo) { - publicKeyInstance, err := ParseEcdsaPublicKeyFromPemStr(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtPublicKey)) + jwtPublicKey, err = memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyJwtPublicKey) + if err != nil { + return jwk, err + } + publicKeyInstance, err := ParseEcdsaPublicKeyFromPemStr(jwtPublicKey) if err != nil { return "", err } @@ -77,13 +96,16 @@ func GenerateJWKBasedOnEnv() (string, error) { } // EncryptEnvData is used to encrypt the env data -func EncryptEnvData(data envstore.Store) (string, error) { +func EncryptEnvData(data map[string]interface{}) (string, error) { jsonBytes, err := json.Marshal(data) if err != nil { return "", err } - storeData := envstore.EnvStoreObj.GetEnvStoreClone() + storeData, err := memorystore.Provider.GetEnvStore() + if err != nil { + return "", err + } err = json.Unmarshal(jsonBytes, &storeData) if err != nil { diff --git a/server/db/db.go b/server/db/db.go index a93cc01..d41469f 100644 --- a/server/db/db.go +++ b/server/db/db.go @@ -9,7 +9,7 @@ import ( "github.com/authorizerdev/authorizer/server/db/providers/cassandradb" "github.com/authorizerdev/authorizer/server/db/providers/mongodb" "github.com/authorizerdev/authorizer/server/db/providers/sql" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" ) // Provider returns the current database provider @@ -18,13 +18,15 @@ var Provider providers.Provider func InitDB() error { var err error - isSQL := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseType) != constants.DbTypeArangodb && envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseType) != constants.DbTypeMongodb && envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseType) != constants.DbTypeCassandraDB - isArangoDB := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseType) == constants.DbTypeArangodb - isMongoDB := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseType) == constants.DbTypeMongodb - isCassandra := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseType) == constants.DbTypeCassandraDB + envs := memorystore.RequiredEnvStoreObj.GetRequiredEnv() + + isSQL := envs.DatabaseType != constants.DbTypeArangodb && envs.DatabaseType != constants.DbTypeMongodb && envs.DatabaseType != constants.DbTypeCassandraDB + isArangoDB := envs.DatabaseType == constants.DbTypeArangodb + isMongoDB := envs.DatabaseType == constants.DbTypeMongodb + isCassandra := envs.DatabaseType == constants.DbTypeCassandraDB if isSQL { - log.Info("Initializing SQL Driver for: ", envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseType)) + log.Info("Initializing SQL Driver for: ", envs.DatabaseType) Provider, err = sql.NewProvider() if err != nil { log.Fatal("Failed to initialize SQL driver: ", err) diff --git a/server/db/providers/arangodb/provider.go b/server/db/providers/arangodb/provider.go index 92c007c..a9a8432 100644 --- a/server/db/providers/arangodb/provider.go +++ b/server/db/providers/arangodb/provider.go @@ -6,9 +6,8 @@ import ( "github.com/arangodb/go-driver" arangoDriver "github.com/arangodb/go-driver" "github.com/arangodb/go-driver/http" - "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db/models" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" ) type provider struct { @@ -22,8 +21,9 @@ type provider struct { // NewProvider to initialize arangodb connection func NewProvider() (*provider, error) { ctx := context.Background() + dbURL := memorystore.RequiredEnvStoreObj.GetRequiredEnv().DatabaseURL conn, err := http.NewConnection(http.ConnectionConfig{ - Endpoints: []string{envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseURL)}, + Endpoints: []string{dbURL}, }) if err != nil { return nil, err @@ -37,16 +37,16 @@ func NewProvider() (*provider, error) { } var arangodb driver.Database - - arangodb_exists, err := arangoClient.DatabaseExists(nil, envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseName)) + dbName := memorystore.RequiredEnvStoreObj.GetRequiredEnv().DatabaseName + arangodb_exists, err := arangoClient.DatabaseExists(nil, dbName) if arangodb_exists { - arangodb, err = arangoClient.Database(nil, envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseName)) + arangodb, err = arangoClient.Database(nil, dbName) if err != nil { return nil, err } } else { - arangodb, err = arangoClient.CreateDatabase(nil, envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseName), nil) + arangodb, err = arangoClient.CreateDatabase(nil, dbName, nil) if err != nil { return nil, err } diff --git a/server/db/providers/arangodb/user.go b/server/db/providers/arangodb/user.go index fc466a4..315a827 100644 --- a/server/db/providers/arangodb/user.go +++ b/server/db/providers/arangodb/user.go @@ -3,15 +3,14 @@ package arangodb import ( "context" "fmt" - "strings" "time" "github.com/arangodb/go-driver" arangoDriver "github.com/arangodb/go-driver" "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db/models" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/google/uuid" ) @@ -22,7 +21,11 @@ func (p *provider) AddUser(user models.User) (models.User, error) { } if user.Roles == "" { - user.Roles = strings.Join(envstore.EnvStoreObj.GetSliceStoreEnvVariable(constants.EnvKeyDefaultRoles), ",") + defaultRoles, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyDefaultRoles) + if err != nil { + return user, err + } + user.Roles = defaultRoles } user.CreatedAt = time.Now().Unix() diff --git a/server/db/providers/cassandradb/provider.go b/server/db/providers/cassandradb/provider.go index e7bf3b0..0a5f4a5 100644 --- a/server/db/providers/cassandradb/provider.go +++ b/server/db/providers/cassandradb/provider.go @@ -9,7 +9,7 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/crypto" "github.com/authorizerdev/authorizer/server/db/models" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/gocql/gocql" cansandraDriver "github.com/gocql/gocql" ) @@ -23,15 +23,19 @@ var KeySpace string // NewProvider to initialize arangodb connection func NewProvider() (*provider, error) { - dbURL := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseURL) + dbURL := memorystore.RequiredEnvStoreObj.GetRequiredEnv().DatabaseURL if dbURL == "" { - dbURL = envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseHost) - if envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabasePort) != "" { - dbURL = fmt.Sprintf("%s:%s", dbURL, envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabasePort)) + dbHost := memorystore.RequiredEnvStoreObj.GetRequiredEnv().DatabaseHost + dbPort := memorystore.RequiredEnvStoreObj.GetRequiredEnv().DatabasePort + if dbPort != "" && dbHost != "" { + dbURL = fmt.Sprintf("%s:%s", dbHost, dbPort) } } - KeySpace = envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseName) + KeySpace = memorystore.RequiredEnvStoreObj.GetRequiredEnv().DatabaseName + if KeySpace == "" { + KeySpace = constants.EnvKeyDatabaseName + } clusterURL := []string{} if strings.Contains(dbURL, ",") { clusterURL = strings.Split(dbURL, ",") @@ -39,25 +43,31 @@ func NewProvider() (*provider, error) { clusterURL = append(clusterURL, dbURL) } cassandraClient := cansandraDriver.NewCluster(clusterURL...) - if envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseUsername) != "" && envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabasePassword) != "" { + dbUsername := memorystore.RequiredEnvStoreObj.GetRequiredEnv().DatabaseUsername + dbPassword := memorystore.RequiredEnvStoreObj.GetRequiredEnv().DatabasePassword + + if dbUsername != "" && dbPassword != "" { cassandraClient.Authenticator = &cansandraDriver.PasswordAuthenticator{ - Username: envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseUsername), - Password: envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabasePassword), + Username: dbUsername, + Password: dbPassword, } } - if envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseCert) != "" && envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseCACert) != "" && envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseCertKey) != "" { - certString, err := crypto.DecryptB64(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseCert)) + dbCert := memorystore.RequiredEnvStoreObj.GetRequiredEnv().DatabaseCert + dbCACert := memorystore.RequiredEnvStoreObj.GetRequiredEnv().DatabaseCACert + dbCertKey := memorystore.RequiredEnvStoreObj.GetRequiredEnv().DatabaseCertKey + if dbCert != "" && dbCACert != "" && dbCertKey != "" { + certString, err := crypto.DecryptB64(dbCert) if err != nil { return nil, err } - keyString, err := crypto.DecryptB64(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseCertKey)) + keyString, err := crypto.DecryptB64(dbCertKey) if err != nil { return nil, err } - caString, err := crypto.DecryptB64(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseCACert)) + caString, err := crypto.DecryptB64(dbCACert) if err != nil { return nil, err } diff --git a/server/db/providers/cassandradb/user.go b/server/db/providers/cassandradb/user.go index 09b7476..d1305da 100644 --- a/server/db/providers/cassandradb/user.go +++ b/server/db/providers/cassandradb/user.go @@ -9,8 +9,8 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db/models" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/gocql/gocql" "github.com/google/uuid" ) @@ -22,7 +22,11 @@ func (p *provider) AddUser(user models.User) (models.User, error) { } if user.Roles == "" { - user.Roles = strings.Join(envstore.EnvStoreObj.GetSliceStoreEnvVariable(constants.EnvKeyDefaultRoles), ",") + defaultRoles, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyDefaultRoles) + if err != nil { + return user, err + } + user.Roles = defaultRoles } user.CreatedAt = time.Now().Unix() diff --git a/server/db/providers/mongodb/provider.go b/server/db/providers/mongodb/provider.go index d29fca1..8909406 100644 --- a/server/db/providers/mongodb/provider.go +++ b/server/db/providers/mongodb/provider.go @@ -4,9 +4,8 @@ import ( "context" "time" - "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db/models" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" @@ -19,7 +18,8 @@ type provider struct { // NewProvider to initialize mongodb connection func NewProvider() (*provider, error) { - mongodbOptions := options.Client().ApplyURI(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseURL)) + dbURL := memorystore.RequiredEnvStoreObj.GetRequiredEnv().DatabaseURL + mongodbOptions := options.Client().ApplyURI(dbURL) maxWait := time.Duration(5 * time.Second) mongodbOptions.ConnectTimeout = &maxWait mongoClient, err := mongo.NewClient(mongodbOptions) @@ -37,18 +37,19 @@ func NewProvider() (*provider, error) { return nil, err } - mongodb := mongoClient.Database(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseName), options.Database()) + dbName := memorystore.RequiredEnvStoreObj.GetRequiredEnv().DatabaseName + mongodb := mongoClient.Database(dbName, options.Database()) mongodb.CreateCollection(ctx, models.Collections.User, options.CreateCollection()) userCollection := mongodb.Collection(models.Collections.User, options.Collection()) userCollection.Indexes().CreateMany(ctx, []mongo.IndexModel{ - mongo.IndexModel{ + { Keys: bson.M{"email": 1}, Options: options.Index().SetUnique(true).SetSparse(true), }, }, options.CreateIndexes()) userCollection.Indexes().CreateMany(ctx, []mongo.IndexModel{ - mongo.IndexModel{ + { Keys: bson.M{"phone_number": 1}, Options: options.Index().SetUnique(true).SetSparse(true).SetPartialFilterExpression(map[string]interface{}{ "phone_number": map[string]string{"$type": "string"}, @@ -59,13 +60,13 @@ func NewProvider() (*provider, error) { mongodb.CreateCollection(ctx, models.Collections.VerificationRequest, options.CreateCollection()) verificationRequestCollection := mongodb.Collection(models.Collections.VerificationRequest, options.Collection()) verificationRequestCollection.Indexes().CreateMany(ctx, []mongo.IndexModel{ - mongo.IndexModel{ + { Keys: bson.M{"email": 1, "identifier": 1}, Options: options.Index().SetUnique(true).SetSparse(true), }, }, options.CreateIndexes()) verificationRequestCollection.Indexes().CreateMany(ctx, []mongo.IndexModel{ - mongo.IndexModel{ + { Keys: bson.M{"token": 1}, Options: options.Index().SetSparse(true), }, @@ -74,7 +75,7 @@ func NewProvider() (*provider, error) { mongodb.CreateCollection(ctx, models.Collections.Session, options.CreateCollection()) sessionCollection := mongodb.Collection(models.Collections.Session, options.Collection()) sessionCollection.Indexes().CreateMany(ctx, []mongo.IndexModel{ - mongo.IndexModel{ + { Keys: bson.M{"user_id": 1}, Options: options.Index().SetSparse(true), }, diff --git a/server/db/providers/mongodb/user.go b/server/db/providers/mongodb/user.go index af6c799..4f60349 100644 --- a/server/db/providers/mongodb/user.go +++ b/server/db/providers/mongodb/user.go @@ -1,13 +1,12 @@ package mongodb import ( - "strings" "time" "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db/models" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/google/uuid" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo/options" @@ -20,7 +19,11 @@ func (p *provider) AddUser(user models.User) (models.User, error) { } if user.Roles == "" { - user.Roles = strings.Join(envstore.EnvStoreObj.GetSliceStoreEnvVariable(constants.EnvKeyDefaultRoles), ",") + defaultRoles, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyDefaultRoles) + if err != nil { + return user, err + } + user.Roles = defaultRoles } user.CreatedAt = time.Now().Unix() user.UpdatedAt = time.Now().Unix() diff --git a/server/db/providers/provider_template/user.go b/server/db/providers/provider_template/user.go index 07f6a06..cb1069f 100644 --- a/server/db/providers/provider_template/user.go +++ b/server/db/providers/provider_template/user.go @@ -1,13 +1,12 @@ package provider_template import ( - "strings" "time" "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db/models" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/google/uuid" ) @@ -18,7 +17,11 @@ func (p *provider) AddUser(user models.User) (models.User, error) { } if user.Roles == "" { - user.Roles = strings.Join(envstore.EnvStoreObj.GetSliceStoreEnvVariable(constants.EnvKeyDefaultRoles), ",") + defaultRoles, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyDefaultRoles) + if err != nil { + return user, err + } + user.Roles = defaultRoles } user.CreatedAt = time.Now().Unix() diff --git a/server/db/providers/sql/provider.go b/server/db/providers/sql/provider.go index 279b707..68910e5 100644 --- a/server/db/providers/sql/provider.go +++ b/server/db/providers/sql/provider.go @@ -7,7 +7,7 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db/models" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" "gorm.io/driver/mysql" "gorm.io/driver/postgres" "gorm.io/driver/sqlite" @@ -41,15 +41,19 @@ func NewProvider() (*provider, error) { TablePrefix: models.Prefix, }, } - switch envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseType) { + + dbType := memorystore.RequiredEnvStoreObj.GetRequiredEnv().DatabaseType + dbURL := memorystore.RequiredEnvStoreObj.GetRequiredEnv().DatabaseURL + + switch dbType { case constants.DbTypePostgres, constants.DbTypeYugabyte: - sqlDB, err = gorm.Open(postgres.Open(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseURL)), ormConfig) + sqlDB, err = gorm.Open(postgres.Open(dbURL), ormConfig) case constants.DbTypeSqlite: - sqlDB, err = gorm.Open(sqlite.Open(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseURL)), ormConfig) + sqlDB, err = gorm.Open(sqlite.Open(dbURL), ormConfig) case constants.DbTypeMysql, constants.DbTypeMariaDB: - sqlDB, err = gorm.Open(mysql.Open(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseURL)), ormConfig) + sqlDB, err = gorm.Open(mysql.Open(dbURL), ormConfig) case constants.DbTypeSqlserver: - sqlDB, err = gorm.Open(sqlserver.Open(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseURL)), ormConfig) + sqlDB, err = gorm.Open(sqlserver.Open(dbURL), ormConfig) } if err != nil { diff --git a/server/db/providers/sql/user.go b/server/db/providers/sql/user.go index ef295c6..e7e999e 100644 --- a/server/db/providers/sql/user.go +++ b/server/db/providers/sql/user.go @@ -1,13 +1,12 @@ package sql import ( - "strings" "time" "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db/models" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/google/uuid" "gorm.io/gorm/clause" ) @@ -19,7 +18,11 @@ func (p *provider) AddUser(user models.User) (models.User, error) { } if user.Roles == "" { - user.Roles = strings.Join(envstore.EnvStoreObj.GetSliceStoreEnvVariable(constants.EnvKeyDefaultRoles), ",") + defaultRoles, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyDefaultRoles) + if err != nil { + return user, err + } + user.Roles = defaultRoles } user.CreatedAt = time.Now().Unix() diff --git a/server/email/email.go b/server/email/email.go index b8e6d80..4234eff 100644 --- a/server/email/email.go +++ b/server/email/email.go @@ -11,7 +11,7 @@ import ( gomail "gopkg.in/mail.v2" "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" ) // addEmailTemplate is used to add html template in email body @@ -33,17 +33,57 @@ func addEmailTemplate(a string, b map[string]interface{}, templateName string) s // SendMail function to send mail func SendMail(to []string, Subject, bodyMessage string) error { // dont trigger email sending in case of test - if envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyEnv) == "test" { + envKey, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyEnv) + if err != nil { + return err + } + if envKey == "test" { return nil } m := gomail.NewMessage() - m.SetHeader("From", envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeySenderEmail)) + senderEmail, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeySenderEmail) + if err != nil { + log.Errorf("Error while getting sender email from env variable: %v", err) + return err + } + + smtpPort, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeySmtpPort) + if err != nil { + log.Errorf("Error while getting smtp port from env variable: %v", err) + return err + } + + smtpHost, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeySmtpHost) + if err != nil { + log.Errorf("Error while getting smtp host from env variable: %v", err) + return err + } + + smtpUsername, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeySmtpUsername) + if err != nil { + log.Errorf("Error while getting smtp username from env variable: %v", err) + return err + } + + smtpPassword, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeySmtpPassword) + if err != nil { + log.Errorf("Error while getting smtp password from env variable: %v", err) + return err + } + + isProd, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyIsProd) + if err != nil { + log.Errorf("Error while getting env variable: %v", err) + return err + } + + m.SetHeader("From", senderEmail) m.SetHeader("To", to...) m.SetHeader("Subject", Subject) m.SetBody("text/html", bodyMessage) - port, _ := strconv.Atoi(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeySmtpPort)) - d := gomail.NewDialer(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeySmtpHost), port, envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeySmtpUsername), envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeySmtpPassword)) - if envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyEnv) == "development" { + port, _ := strconv.Atoi(smtpPort) + d := gomail.NewDialer(smtpHost, port, smtpUsername, smtpPassword) + if !isProd { d.TLSConfig = &tls.Config{InsecureSkipVerify: true} } if err := d.DialAndSend(m); err != nil { diff --git a/server/email/forgot_password_email.go b/server/email/forgot_password_email.go index 1e06437..aabd6a9 100644 --- a/server/email/forgot_password_email.go +++ b/server/email/forgot_password_email.go @@ -2,14 +2,19 @@ package email import ( "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" ) // SendForgotPasswordMail to send forgot password email func SendForgotPasswordMail(toEmail, token, hostname string) error { - resetPasswordUrl := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyResetPasswordURL) + resetPasswordUrl, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyResetPasswordURL) + if err != nil { + return err + } if resetPasswordUrl == "" { - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyResetPasswordURL, hostname+"/app/reset-password") + if err := memorystore.Provider.UpdateEnvVariable(constants.EnvKeyResetPasswordURL, hostname+"/app/reset-password"); err != nil { + return err + } } // The receiver needs to be in slice as the receive supports multiple receiver @@ -103,8 +108,14 @@ func SendForgotPasswordMail(toEmail, token, hostname string) error { ` data := make(map[string]interface{}, 3) - data["org_logo"] = envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyOrganizationLogo) - data["org_name"] = envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyOrganizationName) + data["org_logo"], err = memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyOrganizationLogo) + if err != nil { + return err + } + data["org_name"], err = memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyOrganizationName) + if err != nil { + return err + } data["verification_url"] = resetPasswordUrl + "?token=" + token message = addEmailTemplate(message, data, "reset_password_email.tmpl") diff --git a/server/email/invite_email.go b/server/email/invite_email.go index 8689353..ef561a6 100644 --- a/server/email/invite_email.go +++ b/server/email/invite_email.go @@ -4,7 +4,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" ) // InviteEmail to send invite email @@ -99,13 +99,20 @@ func InviteEmail(toEmail, token, verificationURL, redirectURI string) error { ` data := make(map[string]interface{}, 3) - data["org_logo"] = envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyOrganizationLogo) - data["org_name"] = envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyOrganizationName) + var err error + data["org_logo"], err = memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyOrganizationLogo) + if err != nil { + return err + } + data["org_name"], err = memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyOrganizationName) + if err != nil { + return err + } data["verification_url"] = verificationURL + "?token=" + token + "&redirect_uri=" + redirectURI message = addEmailTemplate(message, data, "invite_email.tmpl") // bodyMessage := sender.WriteHTMLEmail(Receiver, Subject, message) - err := SendMail(Receiver, Subject, message) + err = SendMail(Receiver, Subject, message) if err != nil { log.Warn("error sending email: ", err) } diff --git a/server/email/verification_email.go b/server/email/verification_email.go index dd73657..dded5ef 100644 --- a/server/email/verification_email.go +++ b/server/email/verification_email.go @@ -4,7 +4,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" ) // SendVerificationMail to send verification email @@ -99,13 +99,20 @@ func SendVerificationMail(toEmail, token, hostname string) error { ` data := make(map[string]interface{}, 3) - data["org_logo"] = envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyOrganizationLogo) - data["org_name"] = envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyOrganizationName) + var err error + data["org_logo"], err = memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyOrganizationLogo) + if err != nil { + return err + } + data["org_name"], err = memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyOrganizationName) + if err != nil { + return err + } data["verification_url"] = hostname + "/verify_email?token=" + token message = addEmailTemplate(message, data, "verify_email.tmpl") // bodyMessage := sender.WriteHTMLEmail(Receiver, Subject, message) - err := SendMail(Receiver, Subject, message) + err = SendMail(Receiver, Subject, message) if err != nil { log.Warn("error sending email: ", err) } diff --git a/server/env/env.go b/server/env/env.go index abf9b53..e80fbe4 100644 --- a/server/env/env.go +++ b/server/env/env.go @@ -2,212 +2,242 @@ package env import ( "errors" + "fmt" "os" + "strconv" "strings" "github.com/google/uuid" - "github.com/joho/godotenv" log "github.com/sirupsen/logrus" "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/crypto" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/utils" ) -// InitRequiredEnv to initialize EnvData and through error if required env are not present -func InitRequiredEnv() error { - envPath := os.Getenv(constants.EnvKeyEnvPath) - - if envPath == "" { - envPath = envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyEnvPath) - if envPath == "" { - envPath = `.env` - } - } - - if envstore.ARG_ENV_FILE != nil && *envstore.ARG_ENV_FILE != "" { - envPath = *envstore.ARG_ENV_FILE - } - log.Info("env path: ", envPath) - - err := godotenv.Load(envPath) - if err != nil { - log.Info("using OS env instead of %s file", envPath) - } - - dbURL := os.Getenv(constants.EnvKeyDatabaseURL) - dbType := os.Getenv(constants.EnvKeyDatabaseType) - dbName := os.Getenv(constants.EnvKeyDatabaseName) - dbPort := os.Getenv(constants.EnvKeyDatabasePort) - dbHost := os.Getenv(constants.EnvKeyDatabaseHost) - dbUsername := os.Getenv(constants.EnvKeyDatabaseUsername) - dbPassword := os.Getenv(constants.EnvKeyDatabasePassword) - dbCert := os.Getenv(constants.EnvKeyDatabaseCert) - dbCertKey := os.Getenv(constants.EnvKeyDatabaseCertKey) - dbCACert := os.Getenv(constants.EnvKeyDatabaseCACert) - - if strings.TrimSpace(dbType) == "" { - if envstore.ARG_DB_TYPE != nil && *envstore.ARG_DB_TYPE != "" { - dbType = strings.TrimSpace(*envstore.ARG_DB_TYPE) - } - - if dbType == "" { - log.Debug("DATABASE_TYPE is not set") - return errors.New("invalid database type. DATABASE_TYPE is empty") - } - } - - if strings.TrimSpace(dbURL) == "" && envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseURL) == "" { - if envstore.ARG_DB_URL != nil && *envstore.ARG_DB_URL != "" { - dbURL = strings.TrimSpace(*envstore.ARG_DB_URL) - } - - if dbURL == "" && dbPort == "" && dbHost == "" && dbUsername == "" && dbPassword == "" { - log.Debug("DATABASE_URL is not set") - return errors.New("invalid database url. DATABASE_URL is required") - } - } - - if dbName == "" { - if dbName == "" { - dbName = "authorizer" - } - } - - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyEnvPath, envPath) - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyDatabaseURL, dbURL) - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyDatabaseType, dbType) - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyDatabaseName, dbName) - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyDatabaseHost, dbHost) - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyDatabasePort, dbPort) - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyDatabaseUsername, dbUsername) - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyDatabasePassword, dbPassword) - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyDatabaseCert, dbCert) - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyDatabaseCertKey, dbCertKey) - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyDatabaseCACert, dbCACert) - - return nil -} - // InitEnv to initialize EnvData and through error if required env are not present func InitAllEnv() error { envData, err := GetEnvData() if err != nil { log.Info("No env data found in db, using local clone of env data") // get clone of current store - envData = envstore.EnvStoreObj.GetEnvStoreClone() + envData, err = memorystore.Provider.GetEnvStore() + if err != nil { + log.Debug("Error while getting env data from memorystore: ", err) + return err + } } - clientID := envData.StringEnv[constants.EnvKeyClientID] // unique client id for each instance - if clientID == "" { + cid, ok := envData[constants.EnvKeyClientID] + clientID := "" + if !ok || cid == "" { clientID = uuid.New().String() - envData.StringEnv[constants.EnvKeyClientID] = clientID + envData[constants.EnvKeyClientID] = clientID + } else { + clientID = cid.(string) } - clientSecret := envData.StringEnv[constants.EnvKeyClientSecret] - // unique client id for each instance - if clientSecret == "" { - clientSecret = uuid.New().String() - envData.StringEnv[constants.EnvKeyClientSecret] = clientSecret + // unique client secret for each instance + if val, ok := envData[constants.EnvKeyClientSecret]; !ok || val != "" { + envData[constants.EnvKeyClientSecret] = uuid.New().String() } - if envData.StringEnv[constants.EnvKeyEnv] == "" { - envData.StringEnv[constants.EnvKeyEnv] = os.Getenv(constants.EnvKeyEnv) - if envData.StringEnv[constants.EnvKeyEnv] == "" { - envData.StringEnv[constants.EnvKeyEnv] = "production" + // os string envs + osEnv := os.Getenv(constants.EnvKeyEnv) + osAppURL := os.Getenv(constants.EnvKeyAppURL) + osAuthorizerURL := os.Getenv(constants.EnvKeyAuthorizerURL) + osPort := os.Getenv(constants.EnvKeyPort) + osAccessTokenExpiryTime := os.Getenv(constants.EnvKeyAccessTokenExpiryTime) + osAdminSecret := os.Getenv(constants.EnvKeyAdminSecret) + osSmtpHost := os.Getenv(constants.EnvKeySmtpHost) + osSmtpPort := os.Getenv(constants.EnvKeySmtpPort) + osSmtpUsername := os.Getenv(constants.EnvKeySmtpUsername) + osSmtpPassword := os.Getenv(constants.EnvKeySmtpPassword) + osSenderEmail := os.Getenv(constants.EnvKeySenderEmail) + osJwtType := os.Getenv(constants.EnvKeyJwtType) + osJwtSecret := os.Getenv(constants.EnvKeyJwtSecret) + osJwtPrivateKey := os.Getenv(constants.EnvKeyJwtPrivateKey) + osJwtPublicKey := os.Getenv(constants.EnvKeyJwtPublicKey) + osJwtRoleClaim := os.Getenv(constants.EnvKeyJwtRoleClaim) + osCustomAccessTokenScript := os.Getenv(constants.EnvKeyCustomAccessTokenScript) + osGoogleClientID := os.Getenv(constants.EnvKeyGoogleClientID) + osGoogleClientSecret := os.Getenv(constants.EnvKeyGoogleClientSecret) + osGithubClientID := os.Getenv(constants.EnvKeyGithubClientID) + osGithubClientSecret := os.Getenv(constants.EnvKeyGithubClientSecret) + osFacebookClientID := os.Getenv(constants.EnvKeyFacebookClientID) + osFacebookClientSecret := os.Getenv(constants.EnvKeyFacebookClientSecret) + osResetPasswordURL := os.Getenv(constants.EnvKeyResetPasswordURL) + osOrganizationName := os.Getenv(constants.EnvKeyOrganizationName) + osOrganizationLogo := os.Getenv(constants.EnvKeyOrganizationLogo) + + // os bool vars + osDisableBasicAuthentication := os.Getenv(constants.EnvKeyDisableBasicAuthentication) + osDisableEmailVerification := os.Getenv(constants.EnvKeyDisableEmailVerification) + osDisableMagicLinkLogin := os.Getenv(constants.EnvKeyDisableMagicLinkLogin) + osDisableLoginPage := os.Getenv(constants.EnvKeyDisableLoginPage) + osDisableSignUp := os.Getenv(constants.EnvKeyDisableSignUp) + osDisableRedisForEnv := os.Getenv(constants.EnvKeyDisableRedisForEnv) + + // os slice vars + osAllowedOrigins := os.Getenv(constants.EnvKeyAllowedOrigins) + osRoles := os.Getenv(constants.EnvKeyRoles) + osDefaultRoles := os.Getenv(constants.EnvKeyDefaultRoles) + osProtectedRoles := os.Getenv(constants.EnvKeyProtectedRoles) + + ienv, ok := envData[constants.EnvKeyEnv] + if !ok || ienv == "" { + envData[constants.EnvKeyEnv] = osEnv + if envData[constants.EnvKeyEnv] == "" { + envData[constants.EnvKeyEnv] = "production" } - if envData.StringEnv[constants.EnvKeyEnv] == "production" { - envData.BoolEnv[constants.EnvKeyIsProd] = true + if envData[constants.EnvKeyEnv] == "production" { + envData[constants.EnvKeyIsProd] = true } else { - envData.BoolEnv[constants.EnvKeyIsProd] = false + envData[constants.EnvKeyIsProd] = false } } - - if envData.StringEnv[constants.EnvKeyAppURL] == "" { - envData.StringEnv[constants.EnvKeyAppURL] = os.Getenv(constants.EnvKeyAppURL) - } - - if envData.StringEnv[constants.EnvKeyAuthorizerURL] == "" { - envData.StringEnv[constants.EnvKeyAuthorizerURL] = os.Getenv(constants.EnvKeyAuthorizerURL) - } - - if envData.StringEnv[constants.EnvKeyPort] == "" { - envData.StringEnv[constants.EnvKeyPort] = os.Getenv(constants.EnvKeyPort) - if envData.StringEnv[constants.EnvKeyPort] == "" { - envData.StringEnv[constants.EnvKeyPort] = "8080" - } - } - - if envData.StringEnv[constants.EnvKeyAccessTokenExpiryTime] == "" { - envData.StringEnv[constants.EnvKeyAccessTokenExpiryTime] = os.Getenv(constants.EnvKeyAccessTokenExpiryTime) - if envData.StringEnv[constants.EnvKeyAccessTokenExpiryTime] == "" { - envData.StringEnv[constants.EnvKeyAccessTokenExpiryTime] = "30m" - } - } - - if envData.StringEnv[constants.EnvKeyAdminSecret] == "" { - envData.StringEnv[constants.EnvKeyAdminSecret] = os.Getenv(constants.EnvKeyAdminSecret) - } - - if envData.StringEnv[constants.EnvKeySmtpHost] == "" { - envData.StringEnv[constants.EnvKeySmtpHost] = os.Getenv(constants.EnvKeySmtpHost) - } - - if envData.StringEnv[constants.EnvKeySmtpPort] == "" { - envData.StringEnv[constants.EnvKeySmtpPort] = os.Getenv(constants.EnvKeySmtpPort) - } - - if envData.StringEnv[constants.EnvKeySmtpUsername] == "" { - envData.StringEnv[constants.EnvKeySmtpUsername] = os.Getenv(constants.EnvKeySmtpUsername) - } - - if envData.StringEnv[constants.EnvKeySmtpPassword] == "" { - envData.StringEnv[constants.EnvKeySmtpPassword] = os.Getenv(constants.EnvKeySmtpPassword) - } - - if envData.StringEnv[constants.EnvKeySenderEmail] == "" { - envData.StringEnv[constants.EnvKeySenderEmail] = os.Getenv(constants.EnvKeySenderEmail) - } - - algo := envData.StringEnv[constants.EnvKeyJwtType] - if algo == "" { - envData.StringEnv[constants.EnvKeyJwtType] = os.Getenv(constants.EnvKeyJwtType) - if envData.StringEnv[constants.EnvKeyJwtType] == "" { - envData.StringEnv[constants.EnvKeyJwtType] = "RS256" - algo = envData.StringEnv[constants.EnvKeyJwtType] + if osEnv != "" && osEnv != envData[constants.EnvKeyEnv] { + envData[constants.EnvKeyEnv] = osEnv + if envData[constants.EnvKeyEnv] == "production" { + envData[constants.EnvKeyIsProd] = true } else { - algo = envData.StringEnv[constants.EnvKeyJwtType] - if !crypto.IsHMACA(algo) && !crypto.IsRSA(algo) && !crypto.IsECDSA(algo) { - log.Debug("Invalid JWT Algorithm") - return errors.New("invalid JWT_TYPE") - } + envData[constants.EnvKeyIsProd] = false } } + if val, ok := envData[constants.EnvKeyAppURL]; !ok || val == "" { + envData[constants.EnvKeyAppURL] = osAppURL + } + if osAppURL != "" && envData[constants.EnvKeyAppURL] != osAppURL { + envData[constants.EnvKeyAppURL] = osAppURL + } + + if val, ok := envData[constants.EnvKeyAuthorizerURL]; !ok || val == "" { + envData[constants.EnvKeyAuthorizerURL] = osAuthorizerURL + } + if osAuthorizerURL != "" && envData[constants.EnvKeyAuthorizerURL] != osAuthorizerURL { + envData[constants.EnvKeyAuthorizerURL] = osAuthorizerURL + } + + if val, ok := envData[constants.EnvKeyPort]; !ok || val == "" { + envData[constants.EnvKeyPort] = osPort + if envData[constants.EnvKeyPort] == "" { + envData[constants.EnvKeyPort] = "8080" + } + } + if osPort != "" && envData[constants.EnvKeyPort] != osPort { + envData[constants.EnvKeyPort] = osPort + } + + if val, ok := envData[constants.EnvKeyAccessTokenExpiryTime]; !ok || val == "" { + envData[constants.EnvKeyAccessTokenExpiryTime] = osAccessTokenExpiryTime + if envData[constants.EnvKeyAccessTokenExpiryTime] == "" { + envData[constants.EnvKeyAccessTokenExpiryTime] = "30m" + } + } + if osAccessTokenExpiryTime != "" && envData[constants.EnvKeyAccessTokenExpiryTime] != osAccessTokenExpiryTime { + envData[constants.EnvKeyAccessTokenExpiryTime] = osAccessTokenExpiryTime + } + + if val, ok := envData[constants.EnvKeyAdminSecret]; !ok || val == "" { + envData[constants.EnvKeyAdminSecret] = osAdminSecret + } + if osAdminSecret != "" && envData[constants.EnvKeyAdminSecret] != osAdminSecret { + envData[constants.EnvKeyAdminSecret] = osAdminSecret + } + + if val, ok := envData[constants.EnvKeySmtpHost]; !ok || val == "" { + envData[constants.EnvKeySmtpHost] = osSmtpHost + } + if osSmtpHost != "" && envData[constants.EnvKeySmtpHost] != osSmtpHost { + envData[constants.EnvKeySmtpHost] = osSmtpHost + } + + if val, ok := envData[constants.EnvKeySmtpPort]; !ok || val == "" { + envData[constants.EnvKeySmtpPort] = osSmtpPort + } + if osSmtpPort != "" && envData[constants.EnvKeySmtpPort] != osSmtpPort { + envData[constants.EnvKeySmtpPort] = osSmtpPort + } + + if val, ok := envData[constants.EnvKeySmtpUsername]; !ok || val == "" { + envData[constants.EnvKeySmtpUsername] = osSmtpUsername + } + if osSmtpUsername != "" && envData[constants.EnvKeySmtpUsername] != osSmtpUsername { + envData[constants.EnvKeySmtpUsername] = osSmtpUsername + } + + if val, ok := envData[constants.EnvKeySmtpPassword]; !ok || val == "" { + envData[constants.EnvKeySmtpPassword] = osSmtpPassword + } + if osSmtpPassword != "" && envData[constants.EnvKeySmtpPassword] != osSmtpPassword { + envData[constants.EnvKeySmtpPassword] = osSmtpPassword + } + + if val, ok := envData[constants.EnvKeySenderEmail]; !ok || val == "" { + envData[constants.EnvKeySenderEmail] = osSenderEmail + } + if osSenderEmail != "" && envData[constants.EnvKeySenderEmail] != osSenderEmail { + envData[constants.EnvKeySenderEmail] = osSenderEmail + } + + algoVal, ok := envData[constants.EnvKeyJwtType] + algo := "" + if !ok || algoVal == "" { + envData[constants.EnvKeyJwtType] = osJwtType + if envData[constants.EnvKeyJwtType] == "" { + envData[constants.EnvKeyJwtType] = "RS256" + algo = envData[constants.EnvKeyJwtType].(string) + } + } else { + algo = algoVal.(string) + if !crypto.IsHMACA(algo) && !crypto.IsRSA(algo) && !crypto.IsECDSA(algo) { + log.Debug("Invalid JWT Algorithm") + return errors.New("invalid JWT_TYPE") + } + } + if osJwtType != "" && osJwtType != algo { + if !crypto.IsHMACA(osJwtType) && !crypto.IsRSA(osJwtType) && !crypto.IsECDSA(osJwtType) { + log.Debug("Invalid JWT Algorithm") + return errors.New("invalid JWT_TYPE") + } + algo = osJwtType + envData[constants.EnvKeyJwtType] = osJwtType + } + if crypto.IsHMACA(algo) { - if envData.StringEnv[constants.EnvKeyJwtSecret] == "" { - envData.StringEnv[constants.EnvKeyJwtSecret] = os.Getenv(constants.EnvKeyJwtSecret) - if envData.StringEnv[constants.EnvKeyJwtSecret] == "" { - envData.StringEnv[constants.EnvKeyJwtSecret], _, err = crypto.NewHMACKey(algo, clientID) + if val, ok := envData[constants.EnvKeyJwtSecret]; !ok || val == "" { + envData[constants.EnvKeyJwtSecret] = osJwtSecret + if envData[constants.EnvKeyJwtSecret] == "" { + envData[constants.EnvKeyJwtSecret], _, err = crypto.NewHMACKey(algo, clientID) if err != nil { return err } } } + if osJwtSecret != "" && envData[constants.EnvKeyJwtSecret] != osJwtSecret { + envData[constants.EnvKeyJwtSecret] = osJwtSecret + } } if crypto.IsRSA(algo) || crypto.IsECDSA(algo) { privateKey, publicKey := "", "" - if envData.StringEnv[constants.EnvKeyJwtPrivateKey] == "" { - privateKey = os.Getenv(constants.EnvKeyJwtPrivateKey) + if val, ok := envData[constants.EnvKeyJwtPrivateKey]; !ok || val == "" { + privateKey = osJwtPrivateKey + } + if osJwtPrivateKey != "" && privateKey != osJwtPrivateKey { + privateKey = osJwtPrivateKey } - if envData.StringEnv[constants.EnvKeyJwtPublicKey] == "" { - publicKey = os.Getenv(constants.EnvKeyJwtPublicKey) + if val, ok := envData[constants.EnvKeyJwtPublicKey]; !ok || val == "" { + publicKey = osJwtPublicKey + } + if osJwtPublicKey != "" && publicKey != osJwtPublicKey { + publicKey = osJwtPublicKey } // if algo is RSA / ECDSA, then we need to have both private and public key @@ -250,159 +280,232 @@ func InitAllEnv() error { } } - envData.StringEnv[constants.EnvKeyJwtPrivateKey] = privateKey - envData.StringEnv[constants.EnvKeyJwtPublicKey] = publicKey + envData[constants.EnvKeyJwtPrivateKey] = privateKey + envData[constants.EnvKeyJwtPublicKey] = publicKey } - if envData.StringEnv[constants.EnvKeyJwtRoleClaim] == "" { - envData.StringEnv[constants.EnvKeyJwtRoleClaim] = os.Getenv(constants.EnvKeyJwtRoleClaim) + if val, ok := envData[constants.EnvKeyJwtRoleClaim]; !ok || val == "" { + envData[constants.EnvKeyJwtRoleClaim] = osJwtRoleClaim - if envData.StringEnv[constants.EnvKeyJwtRoleClaim] == "" { - envData.StringEnv[constants.EnvKeyJwtRoleClaim] = "role" + if envData[constants.EnvKeyJwtRoleClaim] == "" { + envData[constants.EnvKeyJwtRoleClaim] = "role" + } + } + if osJwtRoleClaim != "" && envData[constants.EnvKeyJwtRoleClaim] != osJwtRoleClaim { + envData[constants.EnvKeyJwtRoleClaim] = osJwtRoleClaim + } + + if val, ok := envData[constants.EnvKeyCustomAccessTokenScript]; !ok || val == "" { + envData[constants.EnvKeyCustomAccessTokenScript] = osCustomAccessTokenScript + } + if osCustomAccessTokenScript != "" && envData[constants.EnvKeyCustomAccessTokenScript] != osCustomAccessTokenScript { + envData[constants.EnvKeyCustomAccessTokenScript] = osCustomAccessTokenScript + } + + if val, ok := envData[constants.EnvKeyGoogleClientID]; !ok || val == "" { + envData[constants.EnvKeyGoogleClientID] = osGoogleClientID + } + if osGoogleClientID != "" && envData[constants.EnvKeyGoogleClientID] != osGoogleClientID { + envData[constants.EnvKeyGoogleClientID] = osGoogleClientID + } + + if val, ok := envData[constants.EnvKeyGoogleClientSecret]; !ok || val == "" { + envData[constants.EnvKeyGoogleClientSecret] = osGoogleClientSecret + } + if osGoogleClientSecret != "" && envData[constants.EnvKeyGoogleClientSecret] != osGoogleClientSecret { + envData[constants.EnvKeyGoogleClientSecret] = osGoogleClientSecret + } + + if val, ok := envData[constants.EnvKeyGithubClientID]; !ok || val == "" { + envData[constants.EnvKeyGithubClientID] = osGithubClientID + } + if osGithubClientID != "" && envData[constants.EnvKeyGithubClientID] != osGithubClientID { + envData[constants.EnvKeyGithubClientID] = osGithubClientID + } + + if val, ok := envData[constants.EnvKeyGithubClientSecret]; !ok || val == "" { + envData[constants.EnvKeyGithubClientSecret] = osGithubClientSecret + } + if osGithubClientSecret != "" && envData[constants.EnvKeyGithubClientSecret] != osGithubClientSecret { + envData[constants.EnvKeyGithubClientSecret] = osGithubClientSecret + } + + if val, ok := envData[constants.EnvKeyFacebookClientID]; !ok || val == "" { + envData[constants.EnvKeyFacebookClientID] = osFacebookClientID + } + if osFacebookClientID != "" && envData[constants.EnvKeyFacebookClientID] != osFacebookClientID { + envData[constants.EnvKeyFacebookClientID] = osFacebookClientID + } + + if val, ok := envData[constants.EnvKeyFacebookClientSecret]; !ok || val == "" { + envData[constants.EnvKeyFacebookClientSecret] = osFacebookClientSecret + } + if osFacebookClientSecret != "" && envData[constants.EnvKeyFacebookClientSecret] != osFacebookClientSecret { + envData[constants.EnvKeyFacebookClientSecret] = osFacebookClientSecret + } + + if val, ok := envData[constants.EnvKeyResetPasswordURL]; !ok || val == "" { + envData[constants.EnvKeyResetPasswordURL] = strings.TrimPrefix(osResetPasswordURL, "/") + } + if osResetPasswordURL != "" && envData[constants.EnvKeyResetPasswordURL] != osResetPasswordURL { + envData[constants.EnvKeyResetPasswordURL] = osResetPasswordURL + } + + if val, ok := envData[constants.EnvKeyOrganizationName]; !ok || val == "" { + envData[constants.EnvKeyOrganizationName] = osOrganizationName + } + if osOrganizationName != "" && envData[constants.EnvKeyOrganizationName] != osOrganizationName { + envData[constants.EnvKeyOrganizationName] = osOrganizationName + } + + if val, ok := envData[constants.EnvKeyOrganizationLogo]; !ok || val == "" { + envData[constants.EnvKeyOrganizationLogo] = osOrganizationLogo + } + if osOrganizationLogo != "" && envData[constants.EnvKeyOrganizationLogo] != osOrganizationLogo { + envData[constants.EnvKeyOrganizationLogo] = osOrganizationLogo + } + + if _, ok := envData[constants.EnvKeyDisableBasicAuthentication]; !ok { + envData[constants.EnvKeyDisableBasicAuthentication] = osDisableBasicAuthentication == "true" + } + if osDisableBasicAuthentication != "" { + boolValue, err := strconv.ParseBool(osDisableBasicAuthentication) + if err != nil { + return err + } + if boolValue != envData[constants.EnvKeyDisableBasicAuthentication].(bool) { + envData[constants.EnvKeyDisableBasicAuthentication] = boolValue } } - if envData.StringEnv[constants.EnvKeyCustomAccessTokenScript] == "" { - envData.StringEnv[constants.EnvKeyCustomAccessTokenScript] = os.Getenv(constants.EnvKeyCustomAccessTokenScript) + if _, ok := envData[constants.EnvKeyDisableEmailVerification]; !ok { + envData[constants.EnvKeyDisableEmailVerification] = osDisableEmailVerification == "true" } - - if envData.StringEnv[constants.EnvKeyRedisURL] == "" { - envData.StringEnv[constants.EnvKeyRedisURL] = os.Getenv(constants.EnvKeyRedisURL) - } - - if envData.StringEnv[constants.EnvKeyCookieName] == "" { - envData.StringEnv[constants.EnvKeyCookieName] = os.Getenv(constants.EnvKeyCookieName) - if envData.StringEnv[constants.EnvKeyCookieName] == "" { - envData.StringEnv[constants.EnvKeyCookieName] = "authorizer" + if osDisableEmailVerification != "" { + boolValue, err := strconv.ParseBool(osDisableEmailVerification) + if err != nil { + return err + } + if boolValue != envData[constants.EnvKeyDisableEmailVerification].(bool) { + envData[constants.EnvKeyDisableEmailVerification] = boolValue } } - if envData.StringEnv[constants.EnvKeyGoogleClientID] == "" { - envData.StringEnv[constants.EnvKeyGoogleClientID] = os.Getenv(constants.EnvKeyGoogleClientID) + if _, ok := envData[constants.EnvKeyDisableMagicLinkLogin]; !ok { + envData[constants.EnvKeyDisableMagicLinkLogin] = osDisableMagicLinkLogin == "true" + } + if osDisableMagicLinkLogin != "" { + boolValue, err := strconv.ParseBool(osDisableMagicLinkLogin) + if err != nil { + return err + } + if boolValue != envData[constants.EnvKeyDisableMagicLinkLogin].(bool) { + envData[constants.EnvKeyDisableMagicLinkLogin] = boolValue + } } - if envData.StringEnv[constants.EnvKeyGoogleClientSecret] == "" { - envData.StringEnv[constants.EnvKeyGoogleClientSecret] = os.Getenv(constants.EnvKeyGoogleClientSecret) + if _, ok := envData[constants.EnvKeyDisableLoginPage]; !ok { + envData[constants.EnvKeyDisableLoginPage] = osDisableLoginPage == "true" + } + if osDisableLoginPage != "" { + boolValue, err := strconv.ParseBool(osDisableLoginPage) + if err != nil { + return err + } + if boolValue != envData[constants.EnvKeyDisableLoginPage].(bool) { + envData[constants.EnvKeyDisableLoginPage] = boolValue + } } - if envData.StringEnv[constants.EnvKeyGithubClientID] == "" { - envData.StringEnv[constants.EnvKeyGithubClientID] = os.Getenv(constants.EnvKeyGithubClientID) + if _, ok := envData[constants.EnvKeyDisableSignUp]; !ok { + envData[constants.EnvKeyDisableSignUp] = osDisableSignUp == "true" + } + if osDisableSignUp != "" { + boolValue, err := strconv.ParseBool(osDisableSignUp) + if err != nil { + return err + } + if boolValue != envData[constants.EnvKeyDisableSignUp].(bool) { + envData[constants.EnvKeyDisableSignUp] = boolValue + } } - if envData.StringEnv[constants.EnvKeyGithubClientSecret] == "" { - envData.StringEnv[constants.EnvKeyGithubClientSecret] = os.Getenv(constants.EnvKeyGithubClientSecret) + if _, ok := envData[constants.EnvKeyDisableRedisForEnv]; !ok { + envData[constants.EnvKeyDisableRedisForEnv] = osDisableRedisForEnv == "true" } - - if envData.StringEnv[constants.EnvKeyFacebookClientID] == "" { - envData.StringEnv[constants.EnvKeyFacebookClientID] = os.Getenv(constants.EnvKeyFacebookClientID) + if osDisableRedisForEnv != "" { + boolValue, err := strconv.ParseBool(osDisableRedisForEnv) + if err != nil { + return err + } + if boolValue != envData[constants.EnvKeyDisableRedisForEnv].(bool) { + envData[constants.EnvKeyDisableRedisForEnv] = boolValue + } } - if envData.StringEnv[constants.EnvKeyFacebookClientSecret] == "" { - envData.StringEnv[constants.EnvKeyFacebookClientSecret] = os.Getenv(constants.EnvKeyFacebookClientSecret) - } - - if envData.StringEnv[constants.EnvKeyResetPasswordURL] == "" { - envData.StringEnv[constants.EnvKeyResetPasswordURL] = strings.TrimPrefix(os.Getenv(constants.EnvKeyResetPasswordURL), "/") - } - - envData.BoolEnv[constants.EnvKeyDisableBasicAuthentication] = os.Getenv(constants.EnvKeyDisableBasicAuthentication) == "true" - envData.BoolEnv[constants.EnvKeyDisableEmailVerification] = os.Getenv(constants.EnvKeyDisableEmailVerification) == "true" - envData.BoolEnv[constants.EnvKeyDisableMagicLinkLogin] = os.Getenv(constants.EnvKeyDisableMagicLinkLogin) == "true" - envData.BoolEnv[constants.EnvKeyDisableLoginPage] = os.Getenv(constants.EnvKeyDisableLoginPage) == "true" - envData.BoolEnv[constants.EnvKeyDisableSignUp] = os.Getenv(constants.EnvKeyDisableSignUp) == "true" - // no need to add nil check as its already done above - if envData.StringEnv[constants.EnvKeySmtpHost] == "" || envData.StringEnv[constants.EnvKeySmtpUsername] == "" || envData.StringEnv[constants.EnvKeySmtpPassword] == "" || envData.StringEnv[constants.EnvKeySenderEmail] == "" && envData.StringEnv[constants.EnvKeySmtpPort] == "" { - envData.BoolEnv[constants.EnvKeyDisableEmailVerification] = true - envData.BoolEnv[constants.EnvKeyDisableMagicLinkLogin] = true + if envData[constants.EnvKeySmtpHost] == "" || envData[constants.EnvKeySmtpUsername] == "" || envData[constants.EnvKeySmtpPassword] == "" || envData[constants.EnvKeySenderEmail] == "" && envData[constants.EnvKeySmtpPort] == "" { + envData[constants.EnvKeyDisableEmailVerification] = true + envData[constants.EnvKeyDisableMagicLinkLogin] = true } - if envData.BoolEnv[constants.EnvKeyDisableEmailVerification] { - envData.BoolEnv[constants.EnvKeyDisableMagicLinkLogin] = true + if envData[constants.EnvKeyDisableEmailVerification].(bool) { + envData[constants.EnvKeyDisableMagicLinkLogin] = true } - allowedOriginsSplit := strings.Split(os.Getenv(constants.EnvKeyAllowedOrigins), ",") - allowedOrigins := []string{} - hasWildCard := false + if val, ok := envData[constants.EnvKeyAllowedOrigins]; !ok || val == "" { + envData[constants.EnvKeyAllowedOrigins] = osAllowedOrigins + if envData[constants.EnvKeyAllowedOrigins] == "" { + envData[constants.EnvKeyAllowedOrigins] = "*" + } + } + if osAllowedOrigins != "" && envData[constants.EnvKeyAllowedOrigins] != osAllowedOrigins { + envData[constants.EnvKeyAllowedOrigins] = osAllowedOrigins + } - for _, val := range allowedOriginsSplit { - trimVal := strings.TrimSpace(val) - if trimVal != "" { - if trimVal != "*" { - host, port := utils.GetHostParts(trimVal) - allowedOrigins = append(allowedOrigins, host+":"+port) - } else { - hasWildCard = true - allowedOrigins = append(allowedOrigins, trimVal) - break - } + if val, ok := envData[constants.EnvKeyRoles]; !ok || val == "" { + envData[constants.EnvKeyRoles] = osRoles + if envData[constants.EnvKeyRoles] == "" { + envData[constants.EnvKeyRoles] = "user" + } + } + if osRoles != "" && envData[constants.EnvKeyRoles] != osRoles { + envData[constants.EnvKeyRoles] = osRoles + } + roles := strings.Split(envData[constants.EnvKeyRoles].(string), ",") + + if val, ok := envData[constants.EnvKeyDefaultRoles]; !ok || val == "" { + envData[constants.EnvKeyDefaultRoles] = osDefaultRoles + if envData[constants.EnvKeyDefaultRoles] == "" { + envData[constants.EnvKeyDefaultRoles] = "user" + } + } + if osDefaultRoles != "" && envData[constants.EnvKeyDefaultRoles] != osDefaultRoles { + envData[constants.EnvKeyDefaultRoles] = osDefaultRoles + } + defaultRoles := strings.Split(envData[constants.EnvKeyDefaultRoles].(string), ",") + if len(defaultRoles) == 0 { + defaultRoles = []string{roles[0]} + } + + for _, role := range defaultRoles { + if !utils.StringSliceContains(roles, role) { + return fmt.Errorf("Default role %s is not defined in roles", role) } } - if len(allowedOrigins) > 1 && hasWildCard { - allowedOrigins = []string{"*"} + if val, ok := envData[constants.EnvKeyProtectedRoles]; !ok || val == "" { + envData[constants.EnvKeyProtectedRoles] = osProtectedRoles + } + if osProtectedRoles != "" && envData[constants.EnvKeyProtectedRoles] != osProtectedRoles { + envData[constants.EnvKeyProtectedRoles] = osProtectedRoles } - if len(allowedOrigins) == 0 { - allowedOrigins = []string{"*"} + err = memorystore.Provider.UpdateEnvStore(envData) + if err != nil { + log.Debug("Error while updating env store: ", err) + return err } - - envData.SliceEnv[constants.EnvKeyAllowedOrigins] = allowedOrigins - - rolesEnv := strings.TrimSpace(os.Getenv(constants.EnvKeyRoles)) - rolesSplit := strings.Split(rolesEnv, ",") - roles := []string{} - if len(rolesEnv) == 0 { - roles = []string{"user"} - } - - defaultRolesEnv := strings.TrimSpace(os.Getenv(constants.EnvKeyDefaultRoles)) - defaultRoleSplit := strings.Split(defaultRolesEnv, ",") - defaultRoles := []string{} - - if len(defaultRolesEnv) == 0 { - defaultRoles = []string{"user"} - } - - protectedRolesEnv := strings.TrimSpace(os.Getenv(constants.EnvKeyProtectedRoles)) - protectedRolesSplit := strings.Split(protectedRolesEnv, ",") - protectedRoles := []string{} - - if len(protectedRolesEnv) > 0 { - for _, val := range protectedRolesSplit { - trimVal := strings.TrimSpace(val) - protectedRoles = append(protectedRoles, trimVal) - } - } - - for _, val := range rolesSplit { - trimVal := strings.TrimSpace(val) - if trimVal != "" { - roles = append(roles, trimVal) - if utils.StringSliceContains(defaultRoleSplit, trimVal) { - defaultRoles = append(defaultRoles, trimVal) - } - } - } - - if len(roles) > 0 && len(defaultRoles) == 0 && len(defaultRolesEnv) > 0 { - log.Debug("Default roles not found in roles list. It can be one from ROLES only") - return errors.New(`invalid DEFAULT_ROLE environment variable. It can be one from give ROLES environment variable value`) - } - - envData.SliceEnv[constants.EnvKeyRoles] = roles - envData.SliceEnv[constants.EnvKeyDefaultRoles] = defaultRoles - envData.SliceEnv[constants.EnvKeyProtectedRoles] = protectedRoles - - if os.Getenv(constants.EnvKeyOrganizationName) != "" { - envData.StringEnv[constants.EnvKeyOrganizationName] = os.Getenv(constants.EnvKeyOrganizationName) - } - - if os.Getenv(constants.EnvKeyOrganizationLogo) != "" { - envData.StringEnv[constants.EnvKeyOrganizationLogo] = os.Getenv(constants.EnvKeyOrganizationLogo) - } - - envstore.EnvStoreObj.UpdateEnvStore(envData) return nil } diff --git a/server/env/persist_env.go b/server/env/persist_env.go index 9b3c23e..80646aa 100644 --- a/server/env/persist_env.go +++ b/server/env/persist_env.go @@ -3,6 +3,7 @@ package env import ( "encoding/json" "os" + "reflect" "strconv" "strings" @@ -13,13 +14,50 @@ import ( "github.com/authorizerdev/authorizer/server/crypto" "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/db/models" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/utils" ) +func fixBackwardCompatibility(data map[string]interface{}) (bool, map[string]interface{}) { + result := data + // check if env data is stored in older format + hasOlderFormat := false + if _, ok := result["bool_env"]; ok { + for key, value := range result["bool_env"].(map[string]interface{}) { + result[key] = value + } + hasOlderFormat = true + delete(result, "bool_env") + } + + if _, ok := result["string_env"]; ok { + for key, value := range result["string_env"].(map[string]interface{}) { + result[key] = value + } + hasOlderFormat = true + delete(result, "string_env") + } + + if _, ok := result["slice_env"]; ok { + for key, value := range result["slice_env"].(map[string]interface{}) { + typeOfValue := reflect.TypeOf(value) + if strings.Contains(typeOfValue.String(), "[]string") { + result[key] = strings.Join(value.([]string), ",") + } + if strings.Contains(typeOfValue.String(), "[]interface") { + result[key] = strings.Join(utils.ConvertInterfaceToStringSlice(value), ",") + } + } + hasOlderFormat = true + delete(result, "slice_env") + } + + return hasOlderFormat, result +} + // GetEnvData returns the env data from database -func GetEnvData() (envstore.Store, error) { - var result envstore.Store +func GetEnvData() (map[string]interface{}, error) { + var result map[string]interface{} env, err := db.Provider.GetEnv() // config not found in db if err != nil { @@ -34,7 +72,7 @@ func GetEnvData() (envstore.Store, error) { return result, err } - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyEncryptionKey, decryptedEncryptionKey) + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyEncryptionKey, decryptedEncryptionKey) b64DecryptedConfig, err := crypto.DecryptB64(env.EnvData) if err != nil { @@ -54,6 +92,17 @@ func GetEnvData() (envstore.Store, error) { return result, err } + hasOlderFormat, result := fixBackwardCompatibility(result) + + if hasOlderFormat { + err = memorystore.Provider.UpdateEnvStore(result) + if err != nil { + log.Debug("Error while updating env store: ", err) + return result, err + } + + } + return result, err } @@ -64,10 +113,20 @@ func PersistEnv() error { if err != nil { // AES encryption needs 32 bit key only, so we chop off last 4 characters from 36 bit uuid hash := uuid.New().String()[:36-4] - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyEncryptionKey, hash) + err := memorystore.Provider.UpdateEnvVariable(constants.EnvKeyEncryptionKey, hash) + if err != nil { + log.Debug("Error while updating encryption env variable: ", err) + return err + } encodedHash := crypto.EncryptB64(hash) - encryptedConfig, err := crypto.EncryptEnvData(envstore.EnvStoreObj.GetEnvStoreClone()) + res, err := memorystore.Provider.GetEnvStore() + if err != nil { + log.Debug("Error while getting env store: ", err) + return err + } + + encryptedConfig, err := crypto.EncryptEnvData(res) if err != nil { log.Debug("Error while encrypting env data: ", err) return err @@ -93,7 +152,7 @@ func PersistEnv() error { return err } - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyEncryptionKey, decryptedEncryptionKey) + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyEncryptionKey, decryptedEncryptionKey) b64DecryptedConfig, err := crypto.DecryptB64(env.EnvData) if err != nil { @@ -108,7 +167,7 @@ func PersistEnv() error { } // temp store variable - var storeData envstore.Store + storeData := map[string]interface{}{} err = json.Unmarshal(decryptedConfigs, &storeData) if err != nil { @@ -116,75 +175,73 @@ func PersistEnv() error { return err } + hasOlderFormat, result := fixBackwardCompatibility(storeData) + if hasOlderFormat { + err = memorystore.Provider.UpdateEnvStore(result) + if err != nil { + log.Debug("Error while updating env store: ", err) + return err + } + + } + // if env is changed via env file or OS env // give that higher preference and update db, but we don't recommend it hasChanged := false - - for key, value := range storeData.StringEnv { + for key, value := range storeData { // don't override unexposed envs + // check only for derivative keys + // No need to check for ENCRYPTION_KEY which special key we use for encrypting config data + // as we have removed it from json if key != constants.EnvKeyEncryptionKey { - // check only for derivative keys - // No need to check for ENCRYPTION_KEY which special key we use for encrypting config data - // as we have removed it from json envValue := strings.TrimSpace(os.Getenv(key)) - - // env is not empty if envValue != "" { - if value != envValue { - storeData.StringEnv[key] = envValue - hasChanged = true + switch key { + case constants.EnvKeyIsProd, constants.EnvKeyDisableBasicAuthentication, constants.EnvKeyDisableEmailVerification, constants.EnvKeyDisableLoginPage, constants.EnvKeyDisableMagicLinkLogin, constants.EnvKeyDisableSignUp, constants.EnvKeyDisableRedisForEnv: + if envValueBool, err := strconv.ParseBool(envValue); err == nil { + if value.(bool) != envValueBool { + storeData[key] = envValueBool + hasChanged = true + } + } + default: + if value != nil && value.(string) != envValue { + storeData[key] = envValue + hasChanged = true + } } } } } - for key, value := range storeData.BoolEnv { - envValue := strings.TrimSpace(os.Getenv(key)) - // env is not empty - if envValue != "" { - envValueBool, _ := strconv.ParseBool(envValue) - if value != envValueBool { - storeData.BoolEnv[key] = envValueBool - hasChanged = true - } - } - } - - for key, value := range storeData.SliceEnv { - envValue := strings.TrimSpace(os.Getenv(key)) - // env is not empty - if envValue != "" { - envStringArr := strings.Split(envValue, ",") - if !utils.IsStringArrayEqual(value, envStringArr) { - storeData.SliceEnv[key] = envStringArr - hasChanged = true - } - } - } - // handle derivative cases like disabling email verification & magic login // in case SMTP is off but env is set to true - if storeData.StringEnv[constants.EnvKeySmtpHost] == "" || storeData.StringEnv[constants.EnvKeySmtpUsername] == "" || storeData.StringEnv[constants.EnvKeySmtpPassword] == "" || storeData.StringEnv[constants.EnvKeySenderEmail] == "" && storeData.StringEnv[constants.EnvKeySmtpPort] == "" { - if !storeData.BoolEnv[constants.EnvKeyDisableEmailVerification] { - storeData.BoolEnv[constants.EnvKeyDisableEmailVerification] = true + if storeData[constants.EnvKeySmtpHost] == "" || storeData[constants.EnvKeySmtpUsername] == "" || storeData[constants.EnvKeySmtpPassword] == "" || storeData[constants.EnvKeySenderEmail] == "" && storeData[constants.EnvKeySmtpPort] == "" { + if !storeData[constants.EnvKeyDisableEmailVerification].(bool) { + storeData[constants.EnvKeyDisableEmailVerification] = true hasChanged = true } - if !storeData.BoolEnv[constants.EnvKeyDisableMagicLinkLogin] { - storeData.BoolEnv[constants.EnvKeyDisableMagicLinkLogin] = true + if !storeData[constants.EnvKeyDisableMagicLinkLogin].(bool) { + storeData[constants.EnvKeyDisableMagicLinkLogin] = true hasChanged = true } } - envstore.EnvStoreObj.UpdateEnvStore(storeData) + err = memorystore.Provider.UpdateEnvStore(storeData) + if err != nil { + log.Debug("Error while updating env store: ", err) + return err + } + jwk, err := crypto.GenerateJWKBasedOnEnv() if err != nil { log.Debug("Error while generating JWK: ", err) return err } // updating jwk - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyJWK, jwk) + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyJWK, jwk) if hasChanged { encryptedConfig, err := crypto.EncryptEnvData(storeData) diff --git a/server/envstore/store.go b/server/envstore/store.go deleted file mode 100644 index d2f5487..0000000 --- a/server/envstore/store.go +++ /dev/null @@ -1,122 +0,0 @@ -package envstore - -import ( - "sync" - - "github.com/authorizerdev/authorizer/server/constants" -) - -var ( - // ARG_DB_URL is the cli arg variable for the database url - ARG_DB_URL *string - // ARG_DB_TYPE is the cli arg variable for the database type - ARG_DB_TYPE *string - // ARG_ENV_FILE is the cli arg variable for the env file - ARG_ENV_FILE *string - // ARG_LOG_LEVEL is the cli arg variable for the log level - ARG_LOG_LEVEL *string -) - -// Store data structure -type Store struct { - StringEnv map[string]string `json:"string_env"` - BoolEnv map[string]bool `json:"bool_env"` - SliceEnv map[string][]string `json:"slice_env"` -} - -// EnvStore struct -type EnvStore struct { - mutex sync.Mutex - store *Store -} - -var defaultStore = &EnvStore{ - store: &Store{ - StringEnv: map[string]string{ - constants.EnvKeyAdminCookieName: "authorizer-admin", - constants.EnvKeyJwtRoleClaim: "role", - constants.EnvKeyOrganizationName: "Authorizer", - constants.EnvKeyOrganizationLogo: "https://www.authorizer.dev/images/logo.png", - }, - BoolEnv: map[string]bool{ - constants.EnvKeyDisableBasicAuthentication: false, - constants.EnvKeyDisableMagicLinkLogin: false, - constants.EnvKeyDisableEmailVerification: false, - constants.EnvKeyDisableLoginPage: false, - constants.EnvKeyDisableSignUp: false, - }, - SliceEnv: map[string][]string{}, - }, -} - -// EnvStoreObj.GetBoolStoreEnvVariable global variable for EnvStore -var EnvStoreObj = defaultStore - -// UpdateEnvStore to update the whole env store object -func (e *EnvStore) UpdateEnvStore(store Store) { - e.mutex.Lock() - defer e.mutex.Unlock() - // just override the keys + new keys - - for key, value := range store.StringEnv { - e.store.StringEnv[key] = value - } - - for key, value := range store.BoolEnv { - e.store.BoolEnv[key] = value - } - - for key, value := range store.SliceEnv { - e.store.SliceEnv[key] = value - } -} - -// UpdateEnvVariable to update the particular env variable -func (e *EnvStore) UpdateEnvVariable(storeIdentifier, key string, value interface{}) { - e.mutex.Lock() - defer e.mutex.Unlock() - switch storeIdentifier { - case constants.StringStoreIdentifier: - e.store.StringEnv[key] = value.(string) - case constants.BoolStoreIdentifier: - e.store.BoolEnv[key] = value.(bool) - case constants.SliceStoreIdentifier: - e.store.SliceEnv[key] = value.([]string) - } -} - -// GetStringStoreEnvVariable to get the env variable from string store object -func (e *EnvStore) GetStringStoreEnvVariable(key string) string { - // e.mutex.Lock() - // defer e.mutex.Unlock() - return e.store.StringEnv[key] -} - -// GetBoolStoreEnvVariable to get the env variable from bool store object -func (e *EnvStore) GetBoolStoreEnvVariable(key string) bool { - // e.mutex.Lock() - // defer e.mutex.Unlock() - return e.store.BoolEnv[key] -} - -// GetSliceStoreEnvVariable to get the env variable from slice store object -func (e *EnvStore) GetSliceStoreEnvVariable(key string) []string { - // e.mutex.Lock() - // defer e.mutex.Unlock() - return e.store.SliceEnv[key] -} - -// GetEnvStoreClone to get clone of current env store object -func (e *EnvStore) GetEnvStoreClone() Store { - e.mutex.Lock() - defer e.mutex.Unlock() - - result := *e.store - return result -} - -func (e *EnvStore) ResetStore() { - e.mutex.Lock() - defer e.mutex.Unlock() - e.store = defaultStore.store -} diff --git a/server/go.mod b/server/go.mod index be8365e..13e3e52 100644 --- a/server/go.mod +++ b/server/go.mod @@ -20,7 +20,7 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.1 // indirect github.com/robertkrimen/otto v0.0.0-20211024170158-b87d35c0b86f - github.com/sirupsen/logrus v1.8.1 // indirect + github.com/sirupsen/logrus v1.8.1 github.com/stretchr/testify v1.7.0 github.com/ugorji/go v1.2.6 // indirect github.com/vektah/gqlparser/v2 v2.2.0 diff --git a/server/graph/generated/generated.go b/server/graph/generated/generated.go index c8cb538..258c6fe 100644 --- a/server/graph/generated/generated.go +++ b/server/graph/generated/generated.go @@ -59,7 +59,6 @@ type ComplexityRoot struct { AppURL func(childComplexity int) int ClientID func(childComplexity int) int ClientSecret func(childComplexity int) int - CookieName func(childComplexity int) int CustomAccessTokenScript func(childComplexity int) int DatabaseHost func(childComplexity int) int DatabaseName func(childComplexity int) int @@ -73,6 +72,7 @@ type ComplexityRoot struct { DisableEmailVerification func(childComplexity int) int DisableLoginPage func(childComplexity int) int DisableMagicLinkLogin func(childComplexity int) int + DisableRedisForEnv func(childComplexity int) int DisableSignUp func(childComplexity int) int FacebookClientID func(childComplexity int) int FacebookClientSecret func(childComplexity int) int @@ -346,13 +346,6 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Env.ClientSecret(childComplexity), true - case "Env.COOKIE_NAME": - if e.complexity.Env.CookieName == nil { - break - } - - return e.complexity.Env.CookieName(childComplexity), true - case "Env.CUSTOM_ACCESS_TOKEN_SCRIPT": if e.complexity.Env.CustomAccessTokenScript == nil { break @@ -444,6 +437,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Env.DisableMagicLinkLogin(childComplexity), true + case "Env.DISABLE_REDIS_FOR_ENV": + if e.complexity.Env.DisableRedisForEnv == nil { + break + } + + return e.complexity.Env.DisableRedisForEnv(childComplexity), true + case "Env.DISABLE_SIGN_UP": if e.complexity.Env.DisableSignUp == nil { break @@ -1423,13 +1423,13 @@ type Response { type Env { ACCESS_TOKEN_EXPIRY_TIME: String ADMIN_SECRET: String - DATABASE_NAME: String! - DATABASE_URL: String! - DATABASE_TYPE: String! - DATABASE_USERNAME: String! - DATABASE_PASSWORD: String! - DATABASE_HOST: String! - DATABASE_PORT: String! + DATABASE_NAME: String + DATABASE_URL: String + DATABASE_TYPE: String + DATABASE_USERNAME: String + DATABASE_PASSWORD: String + DATABASE_HOST: String + DATABASE_PORT: String CLIENT_ID: String! CLIENT_SECRET: String! CUSTOM_ACCESS_TOKEN_SCRIPT: String @@ -1445,13 +1445,13 @@ type Env { ALLOWED_ORIGINS: [String!] APP_URL: String REDIS_URL: String - COOKIE_NAME: String RESET_PASSWORD_URL: String - DISABLE_EMAIL_VERIFICATION: Boolean - DISABLE_BASIC_AUTHENTICATION: Boolean - DISABLE_MAGIC_LINK_LOGIN: Boolean - DISABLE_LOGIN_PAGE: Boolean - DISABLE_SIGN_UP: Boolean + DISABLE_EMAIL_VERIFICATION: Boolean! + DISABLE_BASIC_AUTHENTICATION: Boolean! + DISABLE_MAGIC_LINK_LOGIN: Boolean! + DISABLE_LOGIN_PAGE: Boolean! + DISABLE_SIGN_UP: Boolean! + DISABLE_REDIS_FOR_ENV: Boolean! ROLES: [String!] PROTECTED_ROLES: [String!] DEFAULT_ROLES: [String!] @@ -1492,14 +1492,13 @@ input UpdateEnvInput { JWT_PUBLIC_KEY: String ALLOWED_ORIGINS: [String!] APP_URL: String - REDIS_URL: String - COOKIE_NAME: String RESET_PASSWORD_URL: String DISABLE_EMAIL_VERIFICATION: Boolean DISABLE_BASIC_AUTHENTICATION: Boolean DISABLE_MAGIC_LINK_LOGIN: Boolean DISABLE_LOGIN_PAGE: Boolean DISABLE_SIGN_UP: Boolean + DISABLE_REDIS_FOR_ENV: Boolean ROLES: [String!] PROTECTED_ROLES: [String!] DEFAULT_ROLES: [String!] @@ -2356,14 +2355,11 @@ func (ec *executionContext) _Env_DATABASE_NAME(ctx context.Context, field graphq return graphql.Null } if resTmp == nil { - if !graphql.HasFieldError(ctx, fc) { - ec.Errorf(ctx, "must not be null") - } return graphql.Null } - res := resTmp.(string) + res := resTmp.(*string) fc.Result = res - return ec.marshalNString2string(ctx, field.Selections, res) + return ec.marshalOString2áš–string(ctx, field.Selections, res) } func (ec *executionContext) _Env_DATABASE_URL(ctx context.Context, field graphql.CollectedField, obj *model.Env) (ret graphql.Marshaler) { @@ -2391,14 +2387,11 @@ func (ec *executionContext) _Env_DATABASE_URL(ctx context.Context, field graphql return graphql.Null } if resTmp == nil { - if !graphql.HasFieldError(ctx, fc) { - ec.Errorf(ctx, "must not be null") - } return graphql.Null } - res := resTmp.(string) + res := resTmp.(*string) fc.Result = res - return ec.marshalNString2string(ctx, field.Selections, res) + return ec.marshalOString2áš–string(ctx, field.Selections, res) } func (ec *executionContext) _Env_DATABASE_TYPE(ctx context.Context, field graphql.CollectedField, obj *model.Env) (ret graphql.Marshaler) { @@ -2426,14 +2419,11 @@ func (ec *executionContext) _Env_DATABASE_TYPE(ctx context.Context, field graphq return graphql.Null } if resTmp == nil { - if !graphql.HasFieldError(ctx, fc) { - ec.Errorf(ctx, "must not be null") - } return graphql.Null } - res := resTmp.(string) + res := resTmp.(*string) fc.Result = res - return ec.marshalNString2string(ctx, field.Selections, res) + return ec.marshalOString2áš–string(ctx, field.Selections, res) } func (ec *executionContext) _Env_DATABASE_USERNAME(ctx context.Context, field graphql.CollectedField, obj *model.Env) (ret graphql.Marshaler) { @@ -2461,14 +2451,11 @@ func (ec *executionContext) _Env_DATABASE_USERNAME(ctx context.Context, field gr return graphql.Null } if resTmp == nil { - if !graphql.HasFieldError(ctx, fc) { - ec.Errorf(ctx, "must not be null") - } return graphql.Null } - res := resTmp.(string) + res := resTmp.(*string) fc.Result = res - return ec.marshalNString2string(ctx, field.Selections, res) + return ec.marshalOString2áš–string(ctx, field.Selections, res) } func (ec *executionContext) _Env_DATABASE_PASSWORD(ctx context.Context, field graphql.CollectedField, obj *model.Env) (ret graphql.Marshaler) { @@ -2496,14 +2483,11 @@ func (ec *executionContext) _Env_DATABASE_PASSWORD(ctx context.Context, field gr return graphql.Null } if resTmp == nil { - if !graphql.HasFieldError(ctx, fc) { - ec.Errorf(ctx, "must not be null") - } return graphql.Null } - res := resTmp.(string) + res := resTmp.(*string) fc.Result = res - return ec.marshalNString2string(ctx, field.Selections, res) + return ec.marshalOString2áš–string(ctx, field.Selections, res) } func (ec *executionContext) _Env_DATABASE_HOST(ctx context.Context, field graphql.CollectedField, obj *model.Env) (ret graphql.Marshaler) { @@ -2531,14 +2515,11 @@ func (ec *executionContext) _Env_DATABASE_HOST(ctx context.Context, field graphq return graphql.Null } if resTmp == nil { - if !graphql.HasFieldError(ctx, fc) { - ec.Errorf(ctx, "must not be null") - } return graphql.Null } - res := resTmp.(string) + res := resTmp.(*string) fc.Result = res - return ec.marshalNString2string(ctx, field.Selections, res) + return ec.marshalOString2áš–string(ctx, field.Selections, res) } func (ec *executionContext) _Env_DATABASE_PORT(ctx context.Context, field graphql.CollectedField, obj *model.Env) (ret graphql.Marshaler) { @@ -2566,14 +2547,11 @@ func (ec *executionContext) _Env_DATABASE_PORT(ctx context.Context, field graphq return graphql.Null } if resTmp == nil { - if !graphql.HasFieldError(ctx, fc) { - ec.Errorf(ctx, "must not be null") - } return graphql.Null } - res := resTmp.(string) + res := resTmp.(*string) fc.Result = res - return ec.marshalNString2string(ctx, field.Selections, res) + return ec.marshalOString2áš–string(ctx, field.Selections, res) } func (ec *executionContext) _Env_CLIENT_ID(ctx context.Context, field graphql.CollectedField, obj *model.Env) (ret graphql.Marshaler) { @@ -3062,38 +3040,6 @@ func (ec *executionContext) _Env_REDIS_URL(ctx context.Context, field graphql.Co return ec.marshalOString2áš–string(ctx, field.Selections, res) } -func (ec *executionContext) _Env_COOKIE_NAME(ctx context.Context, field graphql.CollectedField, obj *model.Env) (ret graphql.Marshaler) { - defer func() { - if r := recover(); r != nil { - ec.Error(ctx, ec.Recover(ctx, r)) - ret = graphql.Null - } - }() - fc := &graphql.FieldContext{ - Object: "Env", - Field: field, - Args: nil, - IsMethod: false, - IsResolver: false, - } - - ctx = graphql.WithFieldContext(ctx, fc) - resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { - ctx = rctx // use context from middleware stack in children - return obj.CookieName, nil - }) - if err != nil { - ec.Error(ctx, err) - return graphql.Null - } - if resTmp == nil { - return graphql.Null - } - res := resTmp.(*string) - fc.Result = res - return ec.marshalOString2áš–string(ctx, field.Selections, res) -} - func (ec *executionContext) _Env_RESET_PASSWORD_URL(ctx context.Context, field graphql.CollectedField, obj *model.Env) (ret graphql.Marshaler) { defer func() { if r := recover(); r != nil { @@ -3151,11 +3097,14 @@ func (ec *executionContext) _Env_DISABLE_EMAIL_VERIFICATION(ctx context.Context, return graphql.Null } if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } return graphql.Null } - res := resTmp.(*bool) + res := resTmp.(bool) fc.Result = res - return ec.marshalOBoolean2áš–bool(ctx, field.Selections, res) + return ec.marshalNBoolean2bool(ctx, field.Selections, res) } func (ec *executionContext) _Env_DISABLE_BASIC_AUTHENTICATION(ctx context.Context, field graphql.CollectedField, obj *model.Env) (ret graphql.Marshaler) { @@ -3183,11 +3132,14 @@ func (ec *executionContext) _Env_DISABLE_BASIC_AUTHENTICATION(ctx context.Contex return graphql.Null } if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } return graphql.Null } - res := resTmp.(*bool) + res := resTmp.(bool) fc.Result = res - return ec.marshalOBoolean2áš–bool(ctx, field.Selections, res) + return ec.marshalNBoolean2bool(ctx, field.Selections, res) } func (ec *executionContext) _Env_DISABLE_MAGIC_LINK_LOGIN(ctx context.Context, field graphql.CollectedField, obj *model.Env) (ret graphql.Marshaler) { @@ -3215,11 +3167,14 @@ func (ec *executionContext) _Env_DISABLE_MAGIC_LINK_LOGIN(ctx context.Context, f return graphql.Null } if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } return graphql.Null } - res := resTmp.(*bool) + res := resTmp.(bool) fc.Result = res - return ec.marshalOBoolean2áš–bool(ctx, field.Selections, res) + return ec.marshalNBoolean2bool(ctx, field.Selections, res) } func (ec *executionContext) _Env_DISABLE_LOGIN_PAGE(ctx context.Context, field graphql.CollectedField, obj *model.Env) (ret graphql.Marshaler) { @@ -3247,11 +3202,14 @@ func (ec *executionContext) _Env_DISABLE_LOGIN_PAGE(ctx context.Context, field g return graphql.Null } if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } return graphql.Null } - res := resTmp.(*bool) + res := resTmp.(bool) fc.Result = res - return ec.marshalOBoolean2áš–bool(ctx, field.Selections, res) + return ec.marshalNBoolean2bool(ctx, field.Selections, res) } func (ec *executionContext) _Env_DISABLE_SIGN_UP(ctx context.Context, field graphql.CollectedField, obj *model.Env) (ret graphql.Marshaler) { @@ -3279,11 +3237,49 @@ func (ec *executionContext) _Env_DISABLE_SIGN_UP(ctx context.Context, field grap return graphql.Null } if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } return graphql.Null } - res := resTmp.(*bool) + res := resTmp.(bool) fc.Result = res - return ec.marshalOBoolean2áš–bool(ctx, field.Selections, res) + return ec.marshalNBoolean2bool(ctx, field.Selections, res) +} + +func (ec *executionContext) _Env_DISABLE_REDIS_FOR_ENV(ctx context.Context, field graphql.CollectedField, obj *model.Env) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "Env", + Field: field, + Args: nil, + IsMethod: false, + IsResolver: false, + } + + ctx = graphql.WithFieldContext(ctx, fc) + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.DisableRedisForEnv, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(bool) + fc.Result = res + return ec.marshalNBoolean2bool(ctx, field.Selections, res) } func (ec *executionContext) _Env_ROLES(ctx context.Context, field graphql.CollectedField, obj *model.Env) (ret graphql.Marshaler) { @@ -8431,22 +8427,6 @@ func (ec *executionContext) unmarshalInputUpdateEnvInput(ctx context.Context, ob if err != nil { return it, err } - case "REDIS_URL": - var err error - - ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("REDIS_URL")) - it.RedisURL, err = ec.unmarshalOString2áš–string(ctx, v) - if err != nil { - return it, err - } - case "COOKIE_NAME": - var err error - - ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("COOKIE_NAME")) - it.CookieName, err = ec.unmarshalOString2áš–string(ctx, v) - if err != nil { - return it, err - } case "RESET_PASSWORD_URL": var err error @@ -8495,6 +8475,14 @@ func (ec *executionContext) unmarshalInputUpdateEnvInput(ctx context.Context, ob if err != nil { return it, err } + case "DISABLE_REDIS_FOR_ENV": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("DISABLE_REDIS_FOR_ENV")) + it.DisableRedisForEnv, err = ec.unmarshalOBoolean2áš–bool(ctx, v) + if err != nil { + return it, err + } case "ROLES": var err error @@ -8943,39 +8931,18 @@ func (ec *executionContext) _Env(ctx context.Context, sel ast.SelectionSet, obj out.Values[i] = ec._Env_ADMIN_SECRET(ctx, field, obj) case "DATABASE_NAME": out.Values[i] = ec._Env_DATABASE_NAME(ctx, field, obj) - if out.Values[i] == graphql.Null { - invalids++ - } case "DATABASE_URL": out.Values[i] = ec._Env_DATABASE_URL(ctx, field, obj) - if out.Values[i] == graphql.Null { - invalids++ - } case "DATABASE_TYPE": out.Values[i] = ec._Env_DATABASE_TYPE(ctx, field, obj) - if out.Values[i] == graphql.Null { - invalids++ - } case "DATABASE_USERNAME": out.Values[i] = ec._Env_DATABASE_USERNAME(ctx, field, obj) - if out.Values[i] == graphql.Null { - invalids++ - } case "DATABASE_PASSWORD": out.Values[i] = ec._Env_DATABASE_PASSWORD(ctx, field, obj) - if out.Values[i] == graphql.Null { - invalids++ - } case "DATABASE_HOST": out.Values[i] = ec._Env_DATABASE_HOST(ctx, field, obj) - if out.Values[i] == graphql.Null { - invalids++ - } case "DATABASE_PORT": out.Values[i] = ec._Env_DATABASE_PORT(ctx, field, obj) - if out.Values[i] == graphql.Null { - invalids++ - } case "CLIENT_ID": out.Values[i] = ec._Env_CLIENT_ID(ctx, field, obj) if out.Values[i] == graphql.Null { @@ -9012,20 +8979,38 @@ func (ec *executionContext) _Env(ctx context.Context, sel ast.SelectionSet, obj out.Values[i] = ec._Env_APP_URL(ctx, field, obj) case "REDIS_URL": out.Values[i] = ec._Env_REDIS_URL(ctx, field, obj) - case "COOKIE_NAME": - out.Values[i] = ec._Env_COOKIE_NAME(ctx, field, obj) case "RESET_PASSWORD_URL": out.Values[i] = ec._Env_RESET_PASSWORD_URL(ctx, field, obj) case "DISABLE_EMAIL_VERIFICATION": out.Values[i] = ec._Env_DISABLE_EMAIL_VERIFICATION(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalids++ + } case "DISABLE_BASIC_AUTHENTICATION": out.Values[i] = ec._Env_DISABLE_BASIC_AUTHENTICATION(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalids++ + } case "DISABLE_MAGIC_LINK_LOGIN": out.Values[i] = ec._Env_DISABLE_MAGIC_LINK_LOGIN(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalids++ + } case "DISABLE_LOGIN_PAGE": out.Values[i] = ec._Env_DISABLE_LOGIN_PAGE(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalids++ + } case "DISABLE_SIGN_UP": out.Values[i] = ec._Env_DISABLE_SIGN_UP(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalids++ + } + case "DISABLE_REDIS_FOR_ENV": + out.Values[i] = ec._Env_DISABLE_REDIS_FOR_ENV(ctx, field, obj) + if out.Values[i] == graphql.Null { + invalids++ + } case "ROLES": out.Values[i] = ec._Env_ROLES(ctx, field, obj) case "PROTECTED_ROLES": diff --git a/server/graph/model/models_gen.go b/server/graph/model/models_gen.go index 1660c9c..2f24f7c 100644 --- a/server/graph/model/models_gen.go +++ b/server/graph/model/models_gen.go @@ -26,13 +26,13 @@ type DeleteUserInput struct { type Env struct { AccessTokenExpiryTime *string `json:"ACCESS_TOKEN_EXPIRY_TIME"` AdminSecret *string `json:"ADMIN_SECRET"` - DatabaseName string `json:"DATABASE_NAME"` - DatabaseURL string `json:"DATABASE_URL"` - DatabaseType string `json:"DATABASE_TYPE"` - DatabaseUsername string `json:"DATABASE_USERNAME"` - DatabasePassword string `json:"DATABASE_PASSWORD"` - DatabaseHost string `json:"DATABASE_HOST"` - DatabasePort string `json:"DATABASE_PORT"` + DatabaseName *string `json:"DATABASE_NAME"` + DatabaseURL *string `json:"DATABASE_URL"` + DatabaseType *string `json:"DATABASE_TYPE"` + DatabaseUsername *string `json:"DATABASE_USERNAME"` + DatabasePassword *string `json:"DATABASE_PASSWORD"` + DatabaseHost *string `json:"DATABASE_HOST"` + DatabasePort *string `json:"DATABASE_PORT"` ClientID string `json:"CLIENT_ID"` ClientSecret string `json:"CLIENT_SECRET"` CustomAccessTokenScript *string `json:"CUSTOM_ACCESS_TOKEN_SCRIPT"` @@ -48,13 +48,13 @@ type Env struct { AllowedOrigins []string `json:"ALLOWED_ORIGINS"` AppURL *string `json:"APP_URL"` RedisURL *string `json:"REDIS_URL"` - CookieName *string `json:"COOKIE_NAME"` ResetPasswordURL *string `json:"RESET_PASSWORD_URL"` - DisableEmailVerification *bool `json:"DISABLE_EMAIL_VERIFICATION"` - DisableBasicAuthentication *bool `json:"DISABLE_BASIC_AUTHENTICATION"` - DisableMagicLinkLogin *bool `json:"DISABLE_MAGIC_LINK_LOGIN"` - DisableLoginPage *bool `json:"DISABLE_LOGIN_PAGE"` - DisableSignUp *bool `json:"DISABLE_SIGN_UP"` + DisableEmailVerification bool `json:"DISABLE_EMAIL_VERIFICATION"` + DisableBasicAuthentication bool `json:"DISABLE_BASIC_AUTHENTICATION"` + DisableMagicLinkLogin bool `json:"DISABLE_MAGIC_LINK_LOGIN"` + DisableLoginPage bool `json:"DISABLE_LOGIN_PAGE"` + DisableSignUp bool `json:"DISABLE_SIGN_UP"` + DisableRedisForEnv bool `json:"DISABLE_REDIS_FOR_ENV"` Roles []string `json:"ROLES"` ProtectedRoles []string `json:"PROTECTED_ROLES"` DefaultRoles []string `json:"DEFAULT_ROLES"` @@ -199,14 +199,13 @@ type UpdateEnvInput struct { JwtPublicKey *string `json:"JWT_PUBLIC_KEY"` AllowedOrigins []string `json:"ALLOWED_ORIGINS"` AppURL *string `json:"APP_URL"` - RedisURL *string `json:"REDIS_URL"` - CookieName *string `json:"COOKIE_NAME"` ResetPasswordURL *string `json:"RESET_PASSWORD_URL"` DisableEmailVerification *bool `json:"DISABLE_EMAIL_VERIFICATION"` DisableBasicAuthentication *bool `json:"DISABLE_BASIC_AUTHENTICATION"` DisableMagicLinkLogin *bool `json:"DISABLE_MAGIC_LINK_LOGIN"` DisableLoginPage *bool `json:"DISABLE_LOGIN_PAGE"` DisableSignUp *bool `json:"DISABLE_SIGN_UP"` + DisableRedisForEnv *bool `json:"DISABLE_REDIS_FOR_ENV"` Roles []string `json:"ROLES"` ProtectedRoles []string `json:"PROTECTED_ROLES"` DefaultRoles []string `json:"DEFAULT_ROLES"` diff --git a/server/graph/schema.graphqls b/server/graph/schema.graphqls index d3673cd..84797ee 100644 --- a/server/graph/schema.graphqls +++ b/server/graph/schema.graphqls @@ -89,13 +89,13 @@ type Response { type Env { ACCESS_TOKEN_EXPIRY_TIME: String ADMIN_SECRET: String - DATABASE_NAME: String! - DATABASE_URL: String! - DATABASE_TYPE: String! - DATABASE_USERNAME: String! - DATABASE_PASSWORD: String! - DATABASE_HOST: String! - DATABASE_PORT: String! + DATABASE_NAME: String + DATABASE_URL: String + DATABASE_TYPE: String + DATABASE_USERNAME: String + DATABASE_PASSWORD: String + DATABASE_HOST: String + DATABASE_PORT: String CLIENT_ID: String! CLIENT_SECRET: String! CUSTOM_ACCESS_TOKEN_SCRIPT: String @@ -111,13 +111,13 @@ type Env { ALLOWED_ORIGINS: [String!] APP_URL: String REDIS_URL: String - COOKIE_NAME: String RESET_PASSWORD_URL: String - DISABLE_EMAIL_VERIFICATION: Boolean - DISABLE_BASIC_AUTHENTICATION: Boolean - DISABLE_MAGIC_LINK_LOGIN: Boolean - DISABLE_LOGIN_PAGE: Boolean - DISABLE_SIGN_UP: Boolean + DISABLE_EMAIL_VERIFICATION: Boolean! + DISABLE_BASIC_AUTHENTICATION: Boolean! + DISABLE_MAGIC_LINK_LOGIN: Boolean! + DISABLE_LOGIN_PAGE: Boolean! + DISABLE_SIGN_UP: Boolean! + DISABLE_REDIS_FOR_ENV: Boolean! ROLES: [String!] PROTECTED_ROLES: [String!] DEFAULT_ROLES: [String!] @@ -158,14 +158,13 @@ input UpdateEnvInput { JWT_PUBLIC_KEY: String ALLOWED_ORIGINS: [String!] APP_URL: String - REDIS_URL: String - COOKIE_NAME: String RESET_PASSWORD_URL: String DISABLE_EMAIL_VERIFICATION: Boolean DISABLE_BASIC_AUTHENTICATION: Boolean DISABLE_MAGIC_LINK_LOGIN: Boolean DISABLE_LOGIN_PAGE: Boolean DISABLE_SIGN_UP: Boolean + DISABLE_REDIS_FOR_ENV: Boolean ROLES: [String!] PROTECTED_ROLES: [String!] DEFAULT_ROLES: [String!] diff --git a/server/handlers/app.go b/server/handlers/app.go index d855db7..5b34fb6 100644 --- a/server/handlers/app.go +++ b/server/handlers/app.go @@ -8,8 +8,9 @@ import ( log "github.com/sirupsen/logrus" "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/envstore" - "github.com/authorizerdev/authorizer/server/utils" + "github.com/authorizerdev/authorizer/server/memorystore" + "github.com/authorizerdev/authorizer/server/parsers" + "github.com/authorizerdev/authorizer/server/validators" ) // State is the struct that holds authorizer url and redirect url @@ -22,8 +23,8 @@ type State struct { // AppHandler is the handler for the /app route func AppHandler() gin.HandlerFunc { return func(c *gin.Context) { - hostname := utils.GetHost(c) - if envstore.EnvStoreObj.GetBoolStoreEnvVariable(constants.EnvKeyDisableLoginPage) { + hostname := parsers.GetHost(c) + if isLoginPageDisabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyDisableLoginPage); err != nil || isLoginPageDisabled { log.Debug("Login page is disabled") c.JSON(400, gin.H{"error": "login page is not enabled"}) return @@ -44,7 +45,7 @@ func AppHandler() gin.HandlerFunc { redirect_uri = hostname + "/app" } else { // validate redirect url with allowed origins - if !utils.IsValidOrigin(redirect_uri) { + if !validators.IsValidOrigin(redirect_uri) { log.Debug("Invalid redirect_uri") c.JSON(400, gin.H{"error": "invalid redirect url"}) return @@ -58,14 +59,27 @@ func AppHandler() gin.HandlerFunc { log.Debug("Failed to push file path: ", err) } } + + orgName, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyOrganizationName) + if err != nil { + log.Debug("Failed to get organization name") + c.JSON(400, gin.H{"error": "failed to get organization name"}) + return + } + orgLogo, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyOrganizationLogo) + if err != nil { + log.Debug("Failed to get organization logo") + c.JSON(400, gin.H{"error": "failed to get organization logo"}) + return + } c.HTML(http.StatusOK, "app.tmpl", gin.H{ "data": map[string]interface{}{ "authorizerURL": hostname, "redirectURL": redirect_uri, "scope": scope, "state": state, - "organizationName": envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyOrganizationName), - "organizationLogo": envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyOrganizationLogo), + "organizationName": orgName, + "organizationLogo": orgLogo, }, }) } diff --git a/server/handlers/authorize.go b/server/handlers/authorize.go index d8d6016..1fa84cc 100644 --- a/server/handlers/authorize.go +++ b/server/handlers/authorize.go @@ -13,8 +13,7 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/cookie" "github.com/authorizerdev/authorizer/server/db" - "github.com/authorizerdev/authorizer/server/envstore" - "github.com/authorizerdev/authorizer/server/sessionstore" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/token" ) @@ -80,7 +79,7 @@ func AuthorizeHandler() gin.HandlerFunc { return } - if clientID != envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyClientID) { + if client, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyClientID); client != clientID || err != nil { if isQuery { gc.Redirect(http.StatusFound, loginURL) } else { @@ -223,7 +222,10 @@ func AuthorizeHandler() gin.HandlerFunc { // based on the response type, generate the response if isResponseTypeCode { // rollover the session for security - sessionstore.RemoveState(sessionToken) + err = memorystore.Provider.RemoveState(sessionToken) + if err != nil { + log.Debug("Failed to remove state: ", err) + } nonce := uuid.New().String() newSessionTokenData, newSessionToken, err := token.CreateSessionToken(user, nonce, claims.Roles, scope) if err != nil { @@ -244,10 +246,10 @@ func AuthorizeHandler() gin.HandlerFunc { return } - sessionstore.SetState(newSessionToken, newSessionTokenData.Nonce+"@"+user.ID) + memorystore.Provider.SetState(newSessionToken, newSessionTokenData.Nonce+"@"+user.ID) cookie.SetSession(gc, newSessionToken) code := uuid.New().String() - sessionstore.SetState(codeChallenge, code+"@"+newSessionToken) + memorystore.Provider.SetState(codeChallenge, code+"@"+newSessionToken) gc.HTML(http.StatusOK, template, gin.H{ "target_origin": redirectURI, "authorization_response": map[string]interface{}{ @@ -281,9 +283,9 @@ func AuthorizeHandler() gin.HandlerFunc { } return } - sessionstore.RemoveState(sessionToken) - sessionstore.SetState(authToken.FingerPrintHash, authToken.FingerPrint+"@"+user.ID) - sessionstore.SetState(authToken.AccessToken.Token, authToken.FingerPrint+"@"+user.ID) + memorystore.Provider.RemoveState(sessionToken) + memorystore.Provider.SetState(authToken.FingerPrintHash, authToken.FingerPrint+"@"+user.ID) + memorystore.Provider.SetState(authToken.AccessToken.Token, authToken.FingerPrint+"@"+user.ID) cookie.SetSession(gc, authToken.FingerPrintHash) expiresIn := authToken.AccessToken.ExpiresAt - time.Now().Unix() @@ -306,7 +308,7 @@ func AuthorizeHandler() gin.HandlerFunc { if authToken.RefreshToken != nil { res["refresh_token"] = authToken.RefreshToken.Token params += "&refresh_token=" + authToken.RefreshToken.Token - sessionstore.SetState(authToken.RefreshToken.Token, authToken.FingerPrint+"@"+user.ID) + memorystore.Provider.SetState(authToken.RefreshToken.Token, authToken.FingerPrint+"@"+user.ID) } if isQuery { diff --git a/server/handlers/dashboard.go b/server/handlers/dashboard.go index 7eb7dce..55d1534 100644 --- a/server/handlers/dashboard.go +++ b/server/handlers/dashboard.go @@ -4,7 +4,7 @@ import ( "net/http" "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/gin-gonic/gin" ) @@ -12,8 +12,8 @@ import ( func DashboardHandler() gin.HandlerFunc { return func(c *gin.Context) { isOnboardingCompleted := false - - if envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret) != "" { + adminSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret) + if err != nil || adminSecret != "" { isOnboardingCompleted = true } diff --git a/server/handlers/jwks.go b/server/handlers/jwks.go index 2e13dc2..7a2cc54 100644 --- a/server/handlers/jwks.go +++ b/server/handlers/jwks.go @@ -7,14 +7,21 @@ import ( log "github.com/sirupsen/logrus" "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" ) func JWKsHandler() gin.HandlerFunc { return func(c *gin.Context) { var data map[string]string - jwk := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJWK) - err := json.Unmarshal([]byte(jwk), &data) + jwk, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyJWK) + if err != nil { + log.Debug("Error getting JWK from memorystore: ", err) + c.JSON(500, gin.H{ + "error": err.Error(), + }) + return + } + err = json.Unmarshal([]byte(jwk), &data) if err != nil { log.Debug("Failed to parse JWK: ", err) c.JSON(500, gin.H{ diff --git a/server/handlers/logout.go b/server/handlers/logout.go index 66bc498..e207b87 100644 --- a/server/handlers/logout.go +++ b/server/handlers/logout.go @@ -9,7 +9,7 @@ import ( "github.com/authorizerdev/authorizer/server/cookie" "github.com/authorizerdev/authorizer/server/crypto" - "github.com/authorizerdev/authorizer/server/sessionstore" + "github.com/authorizerdev/authorizer/server/memorystore" ) // Handler to logout user @@ -37,7 +37,10 @@ func LogoutHandler() gin.HandlerFunc { fingerPrint := string(decryptedFingerPrint) - sessionstore.RemoveState(fingerPrint) + err = memorystore.Provider.RemoveState(fingerPrint) + if err != nil { + log.Debug("Failed to remove state: ", err) + } cookie.DeleteSession(gc) if redirectURL != "" { diff --git a/server/handlers/oauth_callback.go b/server/handlers/oauth_callback.go index 07347c7..d384040 100644 --- a/server/handlers/oauth_callback.go +++ b/server/handlers/oauth_callback.go @@ -19,9 +19,8 @@ import ( "github.com/authorizerdev/authorizer/server/cookie" "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/db/models" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/oauth" - "github.com/authorizerdev/authorizer/server/sessionstore" "github.com/authorizerdev/authorizer/server/token" "github.com/authorizerdev/authorizer/server/utils" ) @@ -32,12 +31,12 @@ func OAuthCallbackHandler() gin.HandlerFunc { provider := c.Param("oauth_provider") state := c.Request.FormValue("state") - sessionState := sessionstore.GetState(state) - if sessionState == "" { + sessionState, err := memorystore.Provider.GetState(state) + if sessionState == "" || err != nil { log.Debug("Invalid oauth state: ", state) c.JSON(400, gin.H{"error": "invalid oauth state"}) } - sessionstore.GetState(state) + memorystore.Provider.GetState(state) // contains random token, redirect url, role sessionSplit := strings.Split(state, "___") @@ -52,7 +51,6 @@ func OAuthCallbackHandler() gin.HandlerFunc { inputRoles := strings.Split(sessionSplit[2], ",") scopes := strings.Split(sessionSplit[3], ",") - var err error user := models.User{} code := c.Request.FormValue("code") switch provider { @@ -77,7 +75,13 @@ func OAuthCallbackHandler() gin.HandlerFunc { log := log.WithField("user", user.Email) if err != nil { - if envstore.EnvStoreObj.GetBoolStoreEnvVariable(constants.EnvKeyDisableSignUp) { + isSignupDisabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyDisableSignUp) + if err != nil { + log.Debug("Failed to get signup disabled env variable: ", err) + c.JSON(400, gin.H{"error": err.Error()}) + return + } + if isSignupDisabled { log.Debug("Failed to signup as disabled") c.JSON(400, gin.H{"error": "signup is disabled for this instance"}) return @@ -87,7 +91,15 @@ func OAuthCallbackHandler() gin.HandlerFunc { // make sure inputRoles don't include protected roles hasProtectedRole := false for _, ir := range inputRoles { - if utils.StringSliceContains(envstore.EnvStoreObj.GetSliceStoreEnvVariable(constants.EnvKeyProtectedRoles), ir) { + protectedRolesString, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyProtectedRoles) + protectedRoles := []string{} + if err != nil { + log.Debug("Failed to get protected roles: ", err) + protectedRolesString = "" + } else { + protectedRoles = strings.Split(protectedRolesString, ",") + } + if utils.StringSliceContains(protectedRoles, ir) { hasProtectedRole = true } } @@ -140,7 +152,15 @@ func OAuthCallbackHandler() gin.HandlerFunc { // check if it contains protected unassigned role hasProtectedRole := false for _, ur := range unasignedRoles { - if utils.StringSliceContains(envstore.EnvStoreObj.GetSliceStoreEnvVariable(constants.EnvKeyProtectedRoles), ur) { + protectedRolesString, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyProtectedRoles) + protectedRoles := []string{} + if err != nil { + log.Debug("Failed to get protected roles: ", err) + protectedRolesString = "" + } else { + protectedRoles = strings.Split(protectedRolesString, ",") + } + if utils.StringSliceContains(protectedRoles, ur) { hasProtectedRole = true } } @@ -178,12 +198,12 @@ func OAuthCallbackHandler() gin.HandlerFunc { params := "access_token=" + authToken.AccessToken.Token + "&token_type=bearer&expires_in=" + strconv.FormatInt(expiresIn, 10) + "&state=" + stateValue + "&id_token=" + authToken.IDToken.Token cookie.SetSession(c, authToken.FingerPrintHash) - sessionstore.SetState(authToken.FingerPrintHash, authToken.FingerPrint+"@"+user.ID) - sessionstore.SetState(authToken.AccessToken.Token, authToken.FingerPrint+"@"+user.ID) + memorystore.Provider.SetState(authToken.FingerPrintHash, authToken.FingerPrint+"@"+user.ID) + memorystore.Provider.SetState(authToken.AccessToken.Token, authToken.FingerPrint+"@"+user.ID) if authToken.RefreshToken != nil { params = params + `&refresh_token=` + authToken.RefreshToken.Token - sessionstore.SetState(authToken.RefreshToken.Token, authToken.FingerPrint+"@"+user.ID) + memorystore.Provider.SetState(authToken.RefreshToken.Token, authToken.FingerPrint+"@"+user.ID) } go db.Provider.AddSession(models.Session{ diff --git a/server/handlers/oauth_login.go b/server/handlers/oauth_login.go index 3dc3351..ca8e628 100644 --- a/server/handlers/oauth_login.go +++ b/server/handlers/oauth_login.go @@ -8,16 +8,16 @@ import ( log "github.com/sirupsen/logrus" "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/oauth" - "github.com/authorizerdev/authorizer/server/sessionstore" - "github.com/authorizerdev/authorizer/server/utils" + "github.com/authorizerdev/authorizer/server/parsers" + "github.com/authorizerdev/authorizer/server/validators" ) // OAuthLoginHandler set host in the oauth state that is useful for redirecting to oauth_callback func OAuthLoginHandler() gin.HandlerFunc { return func(c *gin.Context) { - hostname := utils.GetHost(c) + hostname := parsers.GetHost(c) // deprecating redirectURL instead use redirect_uri redirectURI := strings.TrimSpace(c.Query("redirectURL")) if redirectURI == "" { @@ -56,7 +56,25 @@ func OAuthLoginHandler() gin.HandlerFunc { // use protected roles verification for admin login only. // though if not associated with user, it will be rejected from oauth_callback - if !utils.IsValidRoles(rolesSplit, append([]string{}, append(envstore.EnvStoreObj.GetSliceStoreEnvVariable(constants.EnvKeyRoles), envstore.EnvStoreObj.GetSliceStoreEnvVariable(constants.EnvKeyProtectedRoles)...)...)) { + rolesString, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyRoles) + roles := []string{} + if err != nil { + log.Debug("Error getting roles: ", err) + rolesString = "" + } else { + roles = strings.Split(rolesString, ",") + } + + protectedRolesString, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyProtectedRoles) + protectedRoles := []string{} + if err != nil { + log.Debug("Error getting protected roles: ", err) + protectedRolesString = "" + } else { + protectedRoles = strings.Split(protectedRolesString, ",") + } + + if !validators.IsValidRoles(rolesSplit, append([]string{}, append(roles, protectedRoles...)...)) { log.Debug("Invalid roles: ", roles) c.JSON(400, gin.H{ "error": "invalid role", @@ -64,7 +82,16 @@ func OAuthLoginHandler() gin.HandlerFunc { return } } else { - roles = strings.Join(envstore.EnvStoreObj.GetSliceStoreEnvVariable(constants.EnvKeyDefaultRoles), ",") + defaultRoles, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyDefaultRoles) + if err != nil { + log.Debug("Error getting default roles: ", err) + c.JSON(400, gin.H{ + "error": "invalid role", + }) + return + } + roles = defaultRoles + } oauthStateString := state + "___" + redirectURI + "___" + roles + "___" + strings.Join(scope, ",") @@ -78,7 +105,14 @@ func OAuthLoginHandler() gin.HandlerFunc { isProviderConfigured = false break } - sessionstore.SetState(oauthStateString, constants.SignupMethodGoogle) + err := memorystore.Provider.SetState(oauthStateString, constants.SignupMethodGoogle) + if err != nil { + log.Debug("Error setting state: ", err) + c.JSON(500, gin.H{ + "error": "internal server error", + }) + return + } // during the init of OAuthProvider authorizer url might be empty oauth.OAuthProviders.GoogleConfig.RedirectURL = hostname + "/oauth_callback/google" url := oauth.OAuthProviders.GoogleConfig.AuthCodeURL(oauthStateString) @@ -89,7 +123,14 @@ func OAuthLoginHandler() gin.HandlerFunc { isProviderConfigured = false break } - sessionstore.SetState(oauthStateString, constants.SignupMethodGithub) + err := memorystore.Provider.SetState(oauthStateString, constants.SignupMethodGithub) + if err != nil { + log.Debug("Error setting state: ", err) + c.JSON(500, gin.H{ + "error": "internal server error", + }) + return + } oauth.OAuthProviders.GithubConfig.RedirectURL = hostname + "/oauth_callback/github" url := oauth.OAuthProviders.GithubConfig.AuthCodeURL(oauthStateString) c.Redirect(http.StatusTemporaryRedirect, url) @@ -99,7 +140,14 @@ func OAuthLoginHandler() gin.HandlerFunc { isProviderConfigured = false break } - sessionstore.SetState(oauthStateString, constants.SignupMethodFacebook) + err := memorystore.Provider.SetState(oauthStateString, constants.SignupMethodFacebook) + if err != nil { + log.Debug("Error setting state: ", err) + c.JSON(500, gin.H{ + "error": "internal server error", + }) + return + } oauth.OAuthProviders.FacebookConfig.RedirectURL = hostname + "/oauth_callback/facebook" url := oauth.OAuthProviders.FacebookConfig.AuthCodeURL(oauthStateString) c.Redirect(http.StatusTemporaryRedirect, url) diff --git a/server/handlers/openid_config.go b/server/handlers/openid_config.go index 5b98d03..781caf1 100644 --- a/server/handlers/openid_config.go +++ b/server/handlers/openid_config.go @@ -4,15 +4,15 @@ import ( "github.com/gin-gonic/gin" "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/envstore" - "github.com/authorizerdev/authorizer/server/utils" + "github.com/authorizerdev/authorizer/server/memorystore" + "github.com/authorizerdev/authorizer/server/parsers" ) // OpenIDConfigurationHandler handler for open-id configurations func OpenIDConfigurationHandler() gin.HandlerFunc { return func(c *gin.Context) { - issuer := utils.GetHost(c) - jwtType := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtType) + issuer := parsers.GetHost(c) + jwtType, _ := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyJwtType) c.JSON(200, gin.H{ "issuer": issuer, diff --git a/server/handlers/revoke.go b/server/handlers/revoke.go index f6d2bfc..9cc5b07 100644 --- a/server/handlers/revoke.go +++ b/server/handlers/revoke.go @@ -8,8 +8,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/envstore" - "github.com/authorizerdev/authorizer/server/sessionstore" + "github.com/authorizerdev/authorizer/server/memorystore" ) // Revoke handler to revoke refresh token @@ -37,7 +36,7 @@ func RevokeHandler() gin.HandlerFunc { return } - if clientID != envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyClientID) { + if client, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyClientID); client != clientID || err != nil { log.Debug("Client ID is invalid: ", clientID) gc.JSON(http.StatusBadRequest, gin.H{ "error": "invalid_client_id", @@ -46,7 +45,7 @@ func RevokeHandler() gin.HandlerFunc { return } - sessionstore.RemoveState(refreshToken) + memorystore.Provider.RemoveState(refreshToken) gc.JSON(http.StatusOK, gin.H{ "message": "Token revoked successfully", diff --git a/server/handlers/token.go b/server/handlers/token.go index 895a672..4bcbe83 100644 --- a/server/handlers/token.go +++ b/server/handlers/token.go @@ -13,8 +13,7 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/cookie" "github.com/authorizerdev/authorizer/server/db" - "github.com/authorizerdev/authorizer/server/envstore" - "github.com/authorizerdev/authorizer/server/sessionstore" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/token" ) @@ -62,7 +61,7 @@ func TokenHandler() gin.HandlerFunc { return } - if clientID != envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyClientID) { + if client, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyClientID); clientID != client || err != nil { log.Debug("Client ID is invalid: ", clientID) gc.JSON(http.StatusBadRequest, gin.H{ "error": "invalid_client_id", @@ -98,8 +97,8 @@ func TokenHandler() gin.HandlerFunc { encryptedCode := strings.ReplaceAll(base64.URLEncoding.EncodeToString(hash.Sum(nil)), "+", "-") encryptedCode = strings.ReplaceAll(encryptedCode, "/", "_") encryptedCode = strings.ReplaceAll(encryptedCode, "=", "") - sessionData := sessionstore.GetState(encryptedCode) - if sessionData == "" { + sessionData, err := memorystore.Provider.GetState(encryptedCode) + if sessionData == "" || err != nil { log.Debug("Session data is empty") gc.JSON(http.StatusBadRequest, gin.H{ "error": "invalid_code_verifier", @@ -132,7 +131,7 @@ func TokenHandler() gin.HandlerFunc { return } // rollover the session for security - sessionstore.RemoveState(sessionDataSplit[1]) + memorystore.Provider.RemoveState(sessionDataSplit[1]) userID = claims.Subject roles = claims.Roles scope = claims.Scope @@ -164,7 +163,7 @@ func TokenHandler() gin.HandlerFunc { scope = append(scope, v.(string)) } // remove older refresh token and rotate it for security - sessionstore.RemoveState(refreshToken) + memorystore.Provider.RemoveState(refreshToken) } user, err := db.Provider.GetUserByID(userID) @@ -186,8 +185,8 @@ func TokenHandler() gin.HandlerFunc { }) return } - sessionstore.SetState(authToken.FingerPrintHash, authToken.FingerPrint+"@"+user.ID) - sessionstore.SetState(authToken.AccessToken.Token, authToken.FingerPrint+"@"+user.ID) + memorystore.Provider.SetState(authToken.FingerPrintHash, authToken.FingerPrint+"@"+user.ID) + memorystore.Provider.SetState(authToken.AccessToken.Token, authToken.FingerPrint+"@"+user.ID) cookie.SetSession(gc, authToken.FingerPrintHash) expiresIn := authToken.AccessToken.ExpiresAt - time.Now().Unix() @@ -205,7 +204,7 @@ func TokenHandler() gin.HandlerFunc { if authToken.RefreshToken != nil { res["refresh_token"] = authToken.RefreshToken.Token - sessionstore.SetState(authToken.RefreshToken.Token, authToken.FingerPrint+"@"+user.ID) + memorystore.Provider.SetState(authToken.RefreshToken.Token, authToken.FingerPrint+"@"+user.ID) } gc.JSON(http.StatusOK, res) diff --git a/server/handlers/verify_email.go b/server/handlers/verify_email.go index 0d34b7d..7adb672 100644 --- a/server/handlers/verify_email.go +++ b/server/handlers/verify_email.go @@ -12,7 +12,8 @@ import ( "github.com/authorizerdev/authorizer/server/cookie" "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/db/models" - "github.com/authorizerdev/authorizer/server/sessionstore" + "github.com/authorizerdev/authorizer/server/memorystore" + "github.com/authorizerdev/authorizer/server/parsers" "github.com/authorizerdev/authorizer/server/token" "github.com/authorizerdev/authorizer/server/utils" ) @@ -40,7 +41,7 @@ func VerifyEmailHandler() gin.HandlerFunc { } // verify if token exists in db - hostname := utils.GetHost(c) + hostname := parsers.GetHost(c) claim, err := token.ParseJWTToken(tokenInQuery, hostname, verificationRequest.Nonce, verificationRequest.Email) if err != nil { log.Debug("Error parsing token: ", err) @@ -99,12 +100,12 @@ func VerifyEmailHandler() gin.HandlerFunc { params := "access_token=" + authToken.AccessToken.Token + "&token_type=bearer&expires_in=" + strconv.FormatInt(expiresIn, 10) + "&state=" + state + "&id_token=" + authToken.IDToken.Token cookie.SetSession(c, authToken.FingerPrintHash) - sessionstore.SetState(authToken.FingerPrintHash, authToken.FingerPrint+"@"+user.ID) - sessionstore.SetState(authToken.AccessToken.Token, authToken.FingerPrint+"@"+user.ID) + memorystore.Provider.SetState(authToken.FingerPrintHash, authToken.FingerPrint+"@"+user.ID) + memorystore.Provider.SetState(authToken.AccessToken.Token, authToken.FingerPrint+"@"+user.ID) if authToken.RefreshToken != nil { params = params + `&refresh_token=${refresh_token}` - sessionstore.SetState(authToken.RefreshToken.Token, authToken.FingerPrint+"@"+user.ID) + memorystore.Provider.SetState(authToken.RefreshToken.Token, authToken.FingerPrint+"@"+user.ID) } if redirectURL == "" { diff --git a/server/main.go b/server/main.go index 347dbae..e00500a 100644 --- a/server/main.go +++ b/server/main.go @@ -6,13 +6,13 @@ import ( "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" + "github.com/authorizerdev/authorizer/server/cli" "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/env" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/oauth" "github.com/authorizerdev/authorizer/server/routes" - "github.com/authorizerdev/authorizer/server/sessionstore" ) var VERSION string @@ -27,23 +27,22 @@ func (u LogUTCFormatter) Format(e *log.Entry) ([]byte, error) { } func main() { - envstore.ARG_DB_URL = flag.String("database_url", "", "Database connection string") - envstore.ARG_DB_TYPE = flag.String("database_type", "", "Database type, possible values are postgres,mysql,sqlite") - envstore.ARG_ENV_FILE = flag.String("env_file", "", "Env file path") - envstore.ARG_LOG_LEVEL = flag.String("log_level", "info", "Log level, possible values are debug,info,warn,error,fatal,panic") + cli.ARG_DB_URL = flag.String("database_url", "", "Database connection string") + cli.ARG_DB_TYPE = flag.String("database_type", "", "Database type, possible values are postgres,mysql,sqlite") + cli.ARG_ENV_FILE = flag.String("env_file", "", "Env file path") + cli.ARG_LOG_LEVEL = flag.String("log_level", "info", "Log level, possible values are debug,info,warn,error,fatal,panic") + cli.ARG_REDIS_URL = flag.String("redis_url", "", "Redis connection string") flag.Parse() // global log level logrus.SetFormatter(LogUTCFormatter{&logrus.JSONFormatter{}}) - logrus.SetReportCaller(true) // log instance for gin server log := logrus.New() log.SetFormatter(LogUTCFormatter{&logrus.JSONFormatter{}}) - log.SetReportCaller(true) var logLevel logrus.Level - switch *envstore.ARG_LOG_LEVEL { + switch *cli.ARG_LOG_LEVEL { case "debug": logLevel = logrus.DebugLevel case "info": @@ -62,14 +61,26 @@ func main() { logrus.SetLevel(logLevel) log.SetLevel(logLevel) + // show file path in log for debug or other log levels. + if logLevel != logrus.InfoLevel { + logrus.SetReportCaller(true) + log.SetReportCaller(true) + } + constants.VERSION = VERSION - // initialize required envs (mainly db & env file path) - err := env.InitRequiredEnv() + // initialize required envs (mainly db, env file path and redis) + err := memorystore.InitRequiredEnv() if err != nil { log.Fatal("Error while initializing required envs: ", err) } + // initialize memory store + err = memorystore.InitMemStore() + if err != nil { + log.Fatal("Error while initializing memory store: ", err) + } + // initialize db provider err = db.InitDB() if err != nil { @@ -89,12 +100,6 @@ func main() { log.Fatalln("Error while persisting env: ", err) } - // initialize session store (redis or in-memory based on env) - err = sessionstore.InitSession() - if err != nil { - log.Fatalln("Error while initializing session store: ", err) - } - // initialize oauth providers based on env err = oauth.InitOAuth() if err != nil { @@ -103,5 +108,11 @@ func main() { router := routes.InitRouter(log) log.Info("Starting Authorizer: ", VERSION) - router.Run(":" + envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyPort)) + port, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyPort) + if err != nil { + log.Info("Error while getting port from env using default port 8080: ", err) + port = "8080" + } + + router.Run(":" + port) } diff --git a/server/memorystore/memory_store.go b/server/memorystore/memory_store.go new file mode 100644 index 0000000..df4091b --- /dev/null +++ b/server/memorystore/memory_store.go @@ -0,0 +1,76 @@ +package memorystore + +import ( + "encoding/json" + + log "github.com/sirupsen/logrus" + + "github.com/authorizerdev/authorizer/server/constants" + "github.com/authorizerdev/authorizer/server/memorystore/providers" + "github.com/authorizerdev/authorizer/server/memorystore/providers/inmemory" + "github.com/authorizerdev/authorizer/server/memorystore/providers/redis" +) + +// Provider returns the current database provider +var Provider providers.Provider + +// InitMemStore initializes the memory store +func InitMemStore() error { + var err error + + defaultEnvs := map[string]interface{}{ + // string envs + constants.EnvKeyJwtRoleClaim: "role", + constants.EnvKeyOrganizationName: "Authorizer", + constants.EnvKeyOrganizationLogo: "https://www.authorizer.dev/images/logo.png", + + // boolean envs + constants.EnvKeyDisableBasicAuthentication: false, + constants.EnvKeyDisableMagicLinkLogin: false, + constants.EnvKeyDisableEmailVerification: false, + constants.EnvKeyDisableLoginPage: false, + constants.EnvKeyDisableSignUp: false, + } + + requiredEnvs := RequiredEnvStoreObj.GetRequiredEnv() + requiredEnvMap := make(map[string]interface{}) + requiredEnvBytes, err := json.Marshal(requiredEnvs) + if err != nil { + log.Debug("Error while marshalling required envs: ", err) + return err + } + err = json.Unmarshal(requiredEnvBytes, &requiredEnvMap) + if err != nil { + log.Debug("Error while unmarshalling required envs: ", err) + return err + } + + // merge default envs with required envs + for key, val := range requiredEnvMap { + defaultEnvs[key] = val + } + + redisURL := requiredEnvs.RedisURL + if redisURL != "" && !requiredEnvs.disableRedisForEnv { + log.Info("Initializing Redis memory store") + Provider, err = redis.NewRedisProvider(redisURL) + if err != nil { + return err + } + + // set default envs in redis + Provider.UpdateEnvStore(defaultEnvs) + + return nil + } + + log.Info("using in memory store to save sessions") + // if redis url is not set use in memory store + Provider, err = inmemory.NewInMemoryProvider() + if err != nil { + return err + } + // set default envs in local env + Provider.UpdateEnvStore(defaultEnvs) + return nil +} diff --git a/server/memorystore/providers/inmemory/envstore.go b/server/memorystore/providers/inmemory/envstore.go new file mode 100644 index 0000000..ce49d59 --- /dev/null +++ b/server/memorystore/providers/inmemory/envstore.go @@ -0,0 +1,53 @@ +package inmemory + +import ( + "os" + "sync" +) + +// EnvStore struct to store the env variables +type EnvStore struct { + mutex sync.Mutex + store map[string]interface{} +} + +// UpdateEnvStore to update the whole env store object +func (e *EnvStore) UpdateStore(store map[string]interface{}) { + if os.Getenv("ENV") != "test" { + e.mutex.Lock() + defer e.mutex.Unlock() + } + // just override the keys + new keys + + for key, value := range store { + e.store[key] = value + } +} + +// GetStore returns the env store +func (e *EnvStore) GetStore() map[string]interface{} { + if os.Getenv("ENV") != "test" { + e.mutex.Lock() + defer e.mutex.Unlock() + } + + return e.store +} + +// Get returns the value of the key in evn store +func (e *EnvStore) Get(key string) interface{} { + if os.Getenv("ENV") != "test" { + e.mutex.Lock() + defer e.mutex.Unlock() + } + return e.store[key] +} + +// Set sets the value of the key in env store +func (e *EnvStore) Set(key string, value interface{}) { + if os.Getenv("ENV") != "test" { + e.mutex.Lock() + defer e.mutex.Unlock() + } + e.store[key] = value +} diff --git a/server/memorystore/providers/inmemory/provider.go b/server/memorystore/providers/inmemory/provider.go new file mode 100644 index 0000000..0dec662 --- /dev/null +++ b/server/memorystore/providers/inmemory/provider.go @@ -0,0 +1,25 @@ +package inmemory + +import ( + "sync" +) + +type provider struct { + mutex sync.Mutex + sessionStore map[string]map[string]string + stateStore map[string]string + envStore *EnvStore +} + +// NewInMemoryStore returns a new in-memory store. +func NewInMemoryProvider() (*provider, error) { + return &provider{ + mutex: sync.Mutex{}, + sessionStore: map[string]map[string]string{}, + stateStore: map[string]string{}, + envStore: &EnvStore{ + mutex: sync.Mutex{}, + store: map[string]interface{}{}, + }, + }, nil +} diff --git a/server/memorystore/providers/inmemory/store.go b/server/memorystore/providers/inmemory/store.go new file mode 100644 index 0000000..4d74c2b --- /dev/null +++ b/server/memorystore/providers/inmemory/store.go @@ -0,0 +1,121 @@ +package inmemory + +import ( + "fmt" + "os" + "strings" +) + +// ClearStore clears the in-memory store. +func (c *provider) ClearStore() error { + if os.Getenv("ENV") != "test" { + c.mutex.Lock() + defer c.mutex.Unlock() + } + c.sessionStore = map[string]map[string]string{} + + return nil +} + +// GetUserSessions returns all the user session token from the in-memory store. +func (c *provider) GetUserSessions(userId string) map[string]string { + if os.Getenv("ENV") != "test" { + c.mutex.Lock() + defer c.mutex.Unlock() + } + res := map[string]string{} + for k, v := range c.stateStore { + split := strings.Split(v, "@") + if split[1] == userId { + res[k] = split[0] + } + } + + return res +} + +// DeleteAllUserSession deletes all the user sessions from in-memory store. +func (c *provider) DeleteAllUserSession(userId string) error { + if os.Getenv("ENV") != "test" { + c.mutex.Lock() + defer c.mutex.Unlock() + } + sessions := c.GetUserSessions(userId) + for k := range sessions { + c.RemoveState(k) + } + + return nil +} + +// SetState sets the state in the in-memory store. +func (c *provider) SetState(key, state string) error { + if os.Getenv("ENV") != "test" { + c.mutex.Lock() + defer c.mutex.Unlock() + } + c.stateStore[key] = state + + return nil +} + +// GetState gets the state from the in-memory store. +func (c *provider) GetState(key string) (string, error) { + if os.Getenv("ENV") != "test" { + c.mutex.Lock() + defer c.mutex.Unlock() + } + + state := "" + if stateVal, ok := c.stateStore[key]; ok { + state = stateVal + } + + return state, nil +} + +// RemoveState removes the state from the in-memory store. +func (c *provider) RemoveState(key string) error { + if os.Getenv("ENV") != "test" { + c.mutex.Lock() + defer c.mutex.Unlock() + } + delete(c.stateStore, key) + + return nil +} + +// UpdateEnvStore to update the whole env store object +func (c *provider) UpdateEnvStore(store map[string]interface{}) error { + c.envStore.UpdateStore(store) + return nil +} + +// GetEnvStore returns the env store object +func (c *provider) GetEnvStore() (map[string]interface{}, error) { + return c.envStore.GetStore(), nil +} + +// UpdateEnvVariable to update the particular env variable +func (c *provider) UpdateEnvVariable(key string, value interface{}) error { + c.envStore.Set(key, value) + return nil +} + +// GetStringStoreEnvVariable to get the env variable from string store object +func (c *provider) GetStringStoreEnvVariable(key string) (string, error) { + res := c.envStore.Get(key) + if res == nil { + return "", nil + } + return fmt.Sprintf("%v", res), nil +} + +// GetBoolStoreEnvVariable to get the env variable from bool store object +func (c *provider) GetBoolStoreEnvVariable(key string) (bool, error) { + res := c.envStore.Get(key) + if res == nil { + return false, nil + } + return res.(bool), nil +} diff --git a/server/memorystore/providers/providers.go b/server/memorystore/providers/providers.go new file mode 100644 index 0000000..b9730f3 --- /dev/null +++ b/server/memorystore/providers/providers.go @@ -0,0 +1,30 @@ +package providers + +// Provider defines current memory store provider +type Provider interface { + // DeleteAllSessions deletes all the sessions from the session store + DeleteAllUserSession(userId string) error + // GetUserSessions returns all the user sessions from the session store + GetUserSessions(userId string) map[string]string + // ClearStore clears the session store for authorizer tokens + ClearStore() error + // SetState sets the login state (key, value form) in the session store + SetState(key, state string) error + // GetState returns the state from the session store + GetState(key string) (string, error) + // RemoveState removes the social login state from the session store + RemoveState(key string) error + + // methods for env store + + // UpdateEnvStore to update the whole env store object + UpdateEnvStore(store map[string]interface{}) error + // GetEnvStore() returns the env store object + GetEnvStore() (map[string]interface{}, error) + // UpdateEnvVariable to update the particular env variable + UpdateEnvVariable(key string, value interface{}) error + // GetStringStoreEnvVariable to get the string env variable from env store + GetStringStoreEnvVariable(key string) (string, error) + // GetBoolStoreEnvVariable to get the bool env variable from env store + GetBoolStoreEnvVariable(key string) (bool, error) +} diff --git a/server/memorystore/providers/redis/provider.go b/server/memorystore/providers/redis/provider.go new file mode 100644 index 0000000..a91a300 --- /dev/null +++ b/server/memorystore/providers/redis/provider.go @@ -0,0 +1,78 @@ +package redis + +import ( + "context" + "strings" + "time" + + "github.com/go-redis/redis/v8" + log "github.com/sirupsen/logrus" +) + +// RedisClient is the interface for redis client & redis cluster client +type RedisClient interface { + HMSet(ctx context.Context, key string, values ...interface{}) *redis.BoolCmd + Del(ctx context.Context, keys ...string) *redis.IntCmd + HDel(ctx context.Context, key string, fields ...string) *redis.IntCmd + HMGet(ctx context.Context, key string, fields ...string) *redis.SliceCmd + HSet(ctx context.Context, key string, values ...interface{}) *redis.IntCmd + HGet(ctx context.Context, key, field string) *redis.StringCmd + HGetAll(ctx context.Context, key string) *redis.StringStringMapCmd + Set(ctx context.Context, key string, value interface{}, expiration time.Duration) *redis.StatusCmd + Get(ctx context.Context, key string) *redis.StringCmd + Scan(ctx context.Context, cursor uint64, match string, count int64) *redis.ScanCmd +} + +type provider struct { + ctx context.Context + store RedisClient +} + +// NewRedisProvider returns a new redis provider +func NewRedisProvider(redisURL string) (*provider, error) { + redisURLHostPortsList := strings.Split(redisURL, ",") + + if len(redisURLHostPortsList) > 1 { + opt, err := redis.ParseURL(redisURLHostPortsList[0]) + if err != nil { + log.Debug("error parsing redis url: ", err) + return nil, err + } + urls := []string{opt.Addr} + urlList := redisURLHostPortsList[1:] + urls = append(urls, urlList...) + clusterOpt := &redis.ClusterOptions{Addrs: urls} + + rdb := redis.NewClusterClient(clusterOpt) + ctx := context.Background() + _, err = rdb.Ping(ctx).Result() + if err != nil { + log.Debug("error connecting to redis: ", err) + return nil, err + } + + return &provider{ + ctx: ctx, + store: rdb, + }, nil + } + + opt, err := redis.ParseURL(redisURL) + if err != nil { + log.Debug("error parsing redis url: ", err) + return nil, err + } + + rdb := redis.NewClient(opt) + ctx := context.Background() + _, err = rdb.Ping(ctx).Result() + if err != nil { + log.Debug("error connecting to redis: ", err) + return nil, err + } + + return &provider{ + ctx: ctx, + store: rdb, + }, nil +} diff --git a/server/memorystore/providers/redis/store.go b/server/memorystore/providers/redis/store.go new file mode 100644 index 0000000..43dd761 --- /dev/null +++ b/server/memorystore/providers/redis/store.go @@ -0,0 +1,158 @@ +package redis + +import ( + "strconv" + "strings" + + "github.com/authorizerdev/authorizer/server/constants" + log "github.com/sirupsen/logrus" +) + +var ( + // session store prefix + sessionStorePrefix = "authorizer_session:" + // env store prefix + envStorePrefix = "authorizer_env" +) + +// ClearStore clears the redis store for authorizer related tokens +func (c *provider) ClearStore() error { + err := c.store.Del(c.ctx, sessionStorePrefix+"*").Err() + if err != nil { + log.Debug("Error clearing redis store: ", err) + return err + } + + return nil +} + +// GetUserSessions returns all the user session token from the redis store. +func (c *provider) GetUserSessions(userID string) map[string]string { + data, err := c.store.HGetAll(c.ctx, "*").Result() + if err != nil { + log.Debug("error getting token from redis store: ", err) + } + + res := map[string]string{} + for k, v := range data { + split := strings.Split(v, "@") + if split[1] == userID { + res[k] = split[0] + } + } + + return res +} + +// DeleteAllUserSession deletes all the user session from redis +func (c *provider) DeleteAllUserSession(userId string) error { + sessions := c.GetUserSessions(userId) + for k, v := range sessions { + if k == "token" { + err := c.store.Del(c.ctx, v).Err() + if err != nil { + log.Debug("Error deleting redis token: ", err) + return err + } + } + } + + return nil +} + +// SetState sets the state in redis store. +func (c *provider) SetState(key, value string) error { + err := c.store.Set(c.ctx, sessionStorePrefix+key, value, 0).Err() + if err != nil { + log.Debug("Error saving redis token: ", err) + return err + } + + return nil +} + +// GetState gets the state from redis store. +func (c *provider) GetState(key string) (string, error) { + var res string + err := c.store.Get(c.ctx, sessionStorePrefix+key).Scan(&res) + if err != nil { + log.Debug("error getting token from redis store: ", err) + } + + return res, err +} + +// RemoveState removes the state from redis store. +func (c *provider) RemoveState(key string) error { + err := c.store.Del(c.ctx, sessionStorePrefix+key).Err() + if err != nil { + log.Fatalln("Error deleting redis token: ", err) + return err + } + + return nil +} + +// UpdateEnvStore to update the whole env store object +func (c *provider) UpdateEnvStore(store map[string]interface{}) error { + for key, value := range store { + err := c.store.HSet(c.ctx, envStorePrefix, key, value).Err() + if err != nil { + return err + } + } + return nil +} + +// GetEnvStore returns the whole env store object +func (c *provider) GetEnvStore() (map[string]interface{}, error) { + res := make(map[string]interface{}) + data, err := c.store.HGetAll(c.ctx, envStorePrefix).Result() + if err != nil { + return nil, err + } + for key, value := range data { + if key == constants.EnvKeyDisableBasicAuthentication || key == constants.EnvKeyDisableEmailVerification || key == constants.EnvKeyDisableLoginPage || key == constants.EnvKeyDisableMagicLinkLogin || key == constants.EnvKeyDisableRedisForEnv || key == constants.EnvKeyDisableSignUp { + boolValue, err := strconv.ParseBool(value) + if err != nil { + return res, err + } + res[key] = boolValue + } else { + res[key] = value + } + } + return res, nil +} + +// UpdateEnvVariable to update the particular env variable +func (c *provider) UpdateEnvVariable(key string, value interface{}) error { + err := c.store.HSet(c.ctx, envStorePrefix, key, value).Err() + if err != nil { + log.Debug("Error saving redis token: ", err) + return err + } + return nil +} + +// GetStringStoreEnvVariable to get the string env variable from env store +func (c *provider) GetStringStoreEnvVariable(key string) (string, error) { + var res string + err := c.store.HGet(c.ctx, envStorePrefix, key).Scan(&res) + if err != nil { + return "", nil + } + + return res, nil +} + +// GetBoolStoreEnvVariable to get the bool env variable from env store +func (c *provider) GetBoolStoreEnvVariable(key string) (bool, error) { + var res bool + err := c.store.HGet(c.ctx, envStorePrefix, key).Scan(res) + if err != nil { + return false, nil + } + + return res, nil +} diff --git a/server/memorystore/required_env_store.go b/server/memorystore/required_env_store.go new file mode 100644 index 0000000..a5f3a81 --- /dev/null +++ b/server/memorystore/required_env_store.go @@ -0,0 +1,149 @@ +package memorystore + +import ( + "errors" + "os" + "strings" + "sync" + + "github.com/joho/godotenv" + log "github.com/sirupsen/logrus" + + "github.com/authorizerdev/authorizer/server/cli" + "github.com/authorizerdev/authorizer/server/constants" +) + +// RequiredEnv holds information about required envs +type RequiredEnv struct { + EnvPath string `json:"ENV_PATH"` + DatabaseURL string `json:"DATABASE_URL"` + DatabaseType string `json:"DATABASE_TYPE"` + DatabaseName string `json:"DATABASE_NAME"` + DatabaseHost string `json:"DATABASE_HOST"` + DatabasePort string `json:"DATABASE_PORT"` + DatabaseUsername string `json:"DATABASE_USERNAME"` + DatabasePassword string `json:"DATABASE_PASSWORD"` + DatabaseCert string `json:"DATABASE_CERT"` + DatabaseCertKey string `json:"DATABASE_CERT_KEY"` + DatabaseCACert string `json:"DATABASE_CA_CERT"` + RedisURL string `json:"REDIS_URL"` + disableRedisForEnv bool `json:"DISABLE_REDIS_FOR_ENV"` +} + +// RequiredEnvObj is a simple in-memory store for sessions. +type RequiredEnvStore struct { + mutex sync.Mutex + requiredEnv RequiredEnv +} + +// GetRequiredEnv to get required env +func (r *RequiredEnvStore) GetRequiredEnv() RequiredEnv { + r.mutex.Lock() + defer r.mutex.Unlock() + + return r.requiredEnv +} + +// SetRequiredEnv to set required env +func (r *RequiredEnvStore) SetRequiredEnv(requiredEnv RequiredEnv) { + r.mutex.Lock() + defer r.mutex.Unlock() + r.requiredEnv = requiredEnv +} + +var RequiredEnvStoreObj *RequiredEnvStore + +// InitRequiredEnv to initialize EnvData and through error if required env are not present +func InitRequiredEnv() error { + envPath := os.Getenv(constants.EnvKeyEnvPath) + + if envPath == "" { + if envPath == "" { + envPath = `.env` + } + } + + if cli.ARG_ENV_FILE != nil && *cli.ARG_ENV_FILE != "" { + envPath = *cli.ARG_ENV_FILE + } + log.Info("env path: ", envPath) + + err := godotenv.Load(envPath) + if err != nil { + log.Infof("using OS env instead of %s file", envPath) + } + + dbURL := os.Getenv(constants.EnvKeyDatabaseURL) + dbType := os.Getenv(constants.EnvKeyDatabaseType) + dbName := os.Getenv(constants.EnvKeyDatabaseName) + dbPort := os.Getenv(constants.EnvKeyDatabasePort) + dbHost := os.Getenv(constants.EnvKeyDatabaseHost) + dbUsername := os.Getenv(constants.EnvKeyDatabaseUsername) + dbPassword := os.Getenv(constants.EnvKeyDatabasePassword) + dbCert := os.Getenv(constants.EnvKeyDatabaseCert) + dbCertKey := os.Getenv(constants.EnvKeyDatabaseCertKey) + dbCACert := os.Getenv(constants.EnvKeyDatabaseCACert) + redisURL := os.Getenv(constants.EnvKeyRedisURL) + disableRedisForEnv := os.Getenv(constants.EnvKeyDisableRedisForEnv) == "true" + + if strings.TrimSpace(redisURL) == "" { + if cli.ARG_REDIS_URL != nil && *cli.ARG_REDIS_URL != "" { + redisURL = *cli.ARG_REDIS_URL + } + } + + // set default db name for non sql dbs + if dbName == "" { + dbName = "authorizer" + } + + if strings.TrimSpace(dbType) == "" { + if cli.ARG_DB_TYPE != nil && *cli.ARG_DB_TYPE != "" { + dbType = strings.TrimSpace(*cli.ARG_DB_TYPE) + } + + if dbType == "" { + log.Debug("DATABASE_TYPE is not set") + return errors.New("invalid database type. DATABASE_TYPE is empty") + } + } + + if strings.TrimSpace(dbURL) == "" { + if cli.ARG_DB_URL != nil && *cli.ARG_DB_URL != "" { + dbURL = strings.TrimSpace(*cli.ARG_DB_URL) + } + + if dbURL == "" && dbPort == "" && dbHost == "" && dbUsername == "" && dbPassword == "" { + log.Debug("DATABASE_URL is not set") + return errors.New("invalid database url. DATABASE_URL is required") + } + } + + if dbName == "" { + if dbName == "" { + dbName = "authorizer" + } + } + + requiredEnv := RequiredEnv{ + EnvPath: envPath, + DatabaseURL: dbURL, + DatabaseType: dbType, + DatabaseName: dbName, + DatabaseHost: dbHost, + DatabasePort: dbPort, + DatabaseUsername: dbUsername, + DatabasePassword: dbPassword, + DatabaseCert: dbCert, + DatabaseCertKey: dbCertKey, + DatabaseCACert: dbCACert, + RedisURL: redisURL, + disableRedisForEnv: disableRedisForEnv, + } + + RequiredEnvStoreObj = &RequiredEnvStore{ + requiredEnv: requiredEnv, + } + + return nil +} diff --git a/server/middlewares/cors.go b/server/middlewares/cors.go index ca06721..ee3b9c3 100644 --- a/server/middlewares/cors.go +++ b/server/middlewares/cors.go @@ -1,7 +1,7 @@ package middlewares import ( - "github.com/authorizerdev/authorizer/server/utils" + "github.com/authorizerdev/authorizer/server/validators" "github.com/gin-gonic/gin" ) @@ -10,7 +10,7 @@ func CORSMiddleware() gin.HandlerFunc { return func(c *gin.Context) { origin := c.Request.Header.Get("Origin") - if utils.IsValidOrigin(origin) { + if validators.IsValidOrigin(origin) { c.Writer.Header().Set("Access-Control-Allow-Origin", origin) } diff --git a/server/oauth/oauth.go b/server/oauth/oauth.go index 3618a9a..27bfa69 100644 --- a/server/oauth/oauth.go +++ b/server/oauth/oauth.go @@ -9,7 +9,7 @@ import ( githubOAuth2 "golang.org/x/oauth2/github" "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" ) // OAuthProviders is a struct that contains reference all the OAuth providers @@ -34,32 +34,58 @@ var ( // InitOAuth initializes the OAuth providers based on EnvData func InitOAuth() error { ctx := context.Background() - if envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyGoogleClientID) != "" && envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyGoogleClientSecret) != "" { + googleClientID, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyGoogleClientID) + if err != nil { + googleClientID = "" + } + googleClientSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyGoogleClientSecret) + if err != nil { + googleClientSecret = "" + } + if googleClientID != "" && googleClientSecret != "" { p, err := oidc.NewProvider(ctx, "https://accounts.google.com") if err != nil { return err } OIDCProviders.GoogleOIDC = p OAuthProviders.GoogleConfig = &oauth2.Config{ - ClientID: envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyGoogleClientID), - ClientSecret: envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyGoogleClientSecret), + ClientID: googleClientID, + ClientSecret: googleClientSecret, RedirectURL: "/oauth_callback/google", Endpoint: OIDCProviders.GoogleOIDC.Endpoint(), Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, } } - if envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyGithubClientID) != "" && envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyGithubClientSecret) != "" { + + githubClientID, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyGithubClientID) + if err != nil { + githubClientID = "" + } + githubClientSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyGithubClientSecret) + if err != nil { + githubClientSecret = "" + } + if githubClientID != "" && githubClientSecret != "" { OAuthProviders.GithubConfig = &oauth2.Config{ - ClientID: envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyGithubClientID), - ClientSecret: envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyGithubClientSecret), + ClientID: githubClientID, + ClientSecret: githubClientSecret, RedirectURL: "/oauth_callback/github", Endpoint: githubOAuth2.Endpoint, } } - if envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyFacebookClientID) != "" && envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyGoogleClientID) != "" { + + facebookClientID, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyFacebookClientID) + if err != nil { + facebookClientID = "" + } + facebookClientSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyFacebookClientSecret) + if err != nil { + facebookClientSecret = "" + } + if facebookClientID != "" && facebookClientSecret != "" { OAuthProviders.FacebookConfig = &oauth2.Config{ - ClientID: envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyFacebookClientID), - ClientSecret: envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyFacebookClientSecret), + ClientID: facebookClientID, + ClientSecret: facebookClientSecret, RedirectURL: "/oauth_callback/facebook", Endpoint: facebookOAuth2.Endpoint, Scopes: []string{"public_profile", "email"}, diff --git a/server/utils/urls.go b/server/parsers/url.go similarity index 85% rename from server/utils/urls.go rename to server/parsers/url.go index f97582b..19202c1 100644 --- a/server/utils/urls.go +++ b/server/parsers/url.go @@ -1,11 +1,11 @@ -package utils +package parsers import ( "net/url" "strings" "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/gin-gonic/gin" ) @@ -19,7 +19,10 @@ func GetHost(c *gin.Context) string { return authorizerURL } - authorizerURL = envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAuthorizerURL) + authorizerURL, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAuthorizerURL) + if err == nil { + authorizerURL = "" + } if authorizerURL != "" { return authorizerURL } @@ -89,8 +92,8 @@ func GetDomainName(uri string) string { // GetAppURL to get /app/ url if not configured by user func GetAppURL(gc *gin.Context) string { - envAppURL := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAppURL) - if envAppURL == "" { + envAppURL, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAppURL) + if envAppURL == "" || err != nil { envAppURL = GetHost(gc) + "/app" } return envAppURL diff --git a/server/resolvers/admin_login.go b/server/resolvers/admin_login.go index 7de2421..23965bb 100644 --- a/server/resolvers/admin_login.go +++ b/server/resolvers/admin_login.go @@ -9,8 +9,8 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/cookie" "github.com/authorizerdev/authorizer/server/crypto" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/utils" ) @@ -24,7 +24,11 @@ func AdminLoginResolver(ctx context.Context, params model.AdminLoginInput) (*mod return res, err } - adminSecret := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret) + adminSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret) + if err != nil { + log.Debug("Error getting admin secret: ", err) + return res, err + } if params.AdminSecret != adminSecret { log.Debug("Admin secret is not correct") return res, fmt.Errorf(`invalid admin secret`) diff --git a/server/resolvers/admin_session.go b/server/resolvers/admin_session.go index 2952844..d5cb8d1 100644 --- a/server/resolvers/admin_session.go +++ b/server/resolvers/admin_session.go @@ -9,8 +9,8 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/cookie" "github.com/authorizerdev/authorizer/server/crypto" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/token" "github.com/authorizerdev/authorizer/server/utils" ) @@ -30,7 +30,12 @@ func AdminSessionResolver(ctx context.Context) (*model.Response, error) { return res, fmt.Errorf("unauthorized") } - hashedKey, err := crypto.EncryptPassword(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret)) + adminSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret) + if err != nil { + log.Debug("Error getting admin secret: ", err) + return res, fmt.Errorf("unauthorized") + } + hashedKey, err := crypto.EncryptPassword(adminSecret) if err != nil { log.Debug("Failed to encrypt key: ", err) return res, err diff --git a/server/resolvers/admin_signup.go b/server/resolvers/admin_signup.go index 399e95d..a8d0c3f 100644 --- a/server/resolvers/admin_signup.go +++ b/server/resolvers/admin_signup.go @@ -2,7 +2,6 @@ package resolvers import ( "context" - "encoding/json" "fmt" "strings" @@ -12,8 +11,8 @@ import ( "github.com/authorizerdev/authorizer/server/cookie" "github.com/authorizerdev/authorizer/server/crypto" "github.com/authorizerdev/authorizer/server/db" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/utils" ) @@ -39,7 +38,11 @@ func AdminSignupResolver(ctx context.Context, params model.AdminSignupInput) (*m return res, err } - adminSecret := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret) + adminSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret) + if err != nil { + log.Debug("Error getting admin secret: ", err) + adminSecret = "" + } if adminSecret != "" { log.Debug("Admin secret is already set") @@ -47,18 +50,11 @@ func AdminSignupResolver(ctx context.Context, params model.AdminSignupInput) (*m return res, err } - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyAdminSecret, params.AdminSecret) + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyAdminSecret, params.AdminSecret) // consvert EnvData to JSON - var storeData envstore.Store - - jsonBytes, err := json.Marshal(envstore.EnvStoreObj.GetEnvStoreClone()) + storeData, err := memorystore.Provider.GetEnvStore() if err != nil { - log.Debug("Failed to marshal envstore: ", err) - return res, err - } - - if err := json.Unmarshal(jsonBytes, &storeData); err != nil { - log.Debug("Failed to unmarshal envstore: ", err) + log.Debug("Error getting env store: ", err) return res, err } diff --git a/server/resolvers/delete_user.go b/server/resolvers/delete_user.go index 4fadbfe..df64443 100644 --- a/server/resolvers/delete_user.go +++ b/server/resolvers/delete_user.go @@ -8,7 +8,7 @@ import ( "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/graph/model" - "github.com/authorizerdev/authorizer/server/sessionstore" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/token" "github.com/authorizerdev/authorizer/server/utils" ) @@ -38,7 +38,7 @@ func DeleteUserResolver(ctx context.Context, params model.DeleteUserInput) (*mod return res, err } - go sessionstore.DeleteAllUserSession(fmt.Sprintf("%x", user.ID)) + go memorystore.Provider.DeleteAllUserSession(fmt.Sprintf("%x", user.ID)) err = db.Provider.DeleteUser(user) if err != nil { diff --git a/server/resolvers/env.go b/server/resolvers/env.go index c1ddcff..d63aff4 100644 --- a/server/resolvers/env.go +++ b/server/resolvers/env.go @@ -3,12 +3,13 @@ package resolvers import ( "context" "fmt" + "strings" log "github.com/sirupsen/logrus" "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/token" "github.com/authorizerdev/authorizer/server/utils" ) @@ -16,7 +17,7 @@ import ( // EnvResolver is a resolver for config query // This is admin only query func EnvResolver(ctx context.Context) (*model.Env, error) { - var res *model.Env + res := &model.Env{} gc, err := utils.GinContextFromContext(ctx) if err != nil { @@ -30,99 +31,124 @@ func EnvResolver(ctx context.Context) (*model.Env, error) { } // get clone of store - store := envstore.EnvStoreObj.GetEnvStoreClone() - accessTokenExpiryTime := store.StringEnv[constants.EnvKeyAccessTokenExpiryTime] - adminSecret := store.StringEnv[constants.EnvKeyAdminSecret] - clientID := store.StringEnv[constants.EnvKeyClientID] - clientSecret := store.StringEnv[constants.EnvKeyClientSecret] - databaseURL := store.StringEnv[constants.EnvKeyDatabaseURL] - databaseName := store.StringEnv[constants.EnvKeyDatabaseName] - databaseType := store.StringEnv[constants.EnvKeyDatabaseType] - databaseUsername := store.StringEnv[constants.EnvKeyDatabaseUsername] - databasePassword := store.StringEnv[constants.EnvKeyDatabasePassword] - databaseHost := store.StringEnv[constants.EnvKeyDatabaseHost] - databasePort := store.StringEnv[constants.EnvKeyDatabasePort] - customAccessTokenScript := store.StringEnv[constants.EnvKeyCustomAccessTokenScript] - smtpHost := store.StringEnv[constants.EnvKeySmtpHost] - smtpPort := store.StringEnv[constants.EnvKeySmtpPort] - smtpUsername := store.StringEnv[constants.EnvKeySmtpUsername] - smtpPassword := store.StringEnv[constants.EnvKeySmtpPassword] - senderEmail := store.StringEnv[constants.EnvKeySenderEmail] - jwtType := store.StringEnv[constants.EnvKeyJwtType] - jwtSecret := store.StringEnv[constants.EnvKeyJwtSecret] - jwtRoleClaim := store.StringEnv[constants.EnvKeyJwtRoleClaim] - jwtPublicKey := store.StringEnv[constants.EnvKeyJwtPublicKey] - jwtPrivateKey := store.StringEnv[constants.EnvKeyJwtPrivateKey] - allowedOrigins := store.SliceEnv[constants.EnvKeyAllowedOrigins] - appURL := store.StringEnv[constants.EnvKeyAppURL] - redisURL := store.StringEnv[constants.EnvKeyRedisURL] - cookieName := store.StringEnv[constants.EnvKeyCookieName] - resetPasswordURL := store.StringEnv[constants.EnvKeyResetPasswordURL] - disableEmailVerification := store.BoolEnv[constants.EnvKeyDisableEmailVerification] - disableBasicAuthentication := store.BoolEnv[constants.EnvKeyDisableBasicAuthentication] - disableMagicLinkLogin := store.BoolEnv[constants.EnvKeyDisableMagicLinkLogin] - disableLoginPage := store.BoolEnv[constants.EnvKeyDisableLoginPage] - disableSignUp := store.BoolEnv[constants.EnvKeyDisableSignUp] - roles := store.SliceEnv[constants.EnvKeyRoles] - defaultRoles := store.SliceEnv[constants.EnvKeyDefaultRoles] - protectedRoles := store.SliceEnv[constants.EnvKeyProtectedRoles] - googleClientID := store.StringEnv[constants.EnvKeyGoogleClientID] - googleClientSecret := store.StringEnv[constants.EnvKeyGoogleClientSecret] - facebookClientID := store.StringEnv[constants.EnvKeyFacebookClientID] - facebookClientSecret := store.StringEnv[constants.EnvKeyFacebookClientSecret] - githubClientID := store.StringEnv[constants.EnvKeyGithubClientID] - githubClientSecret := store.StringEnv[constants.EnvKeyGithubClientSecret] - organizationName := store.StringEnv[constants.EnvKeyOrganizationName] - organizationLogo := store.StringEnv[constants.EnvKeyOrganizationLogo] - - if accessTokenExpiryTime == "" { - accessTokenExpiryTime = "30m" + store, err := memorystore.Provider.GetEnvStore() + if err != nil { + log.Debug("Failed to get env store: ", err) + return res, err } - res = &model.Env{ - AccessTokenExpiryTime: &accessTokenExpiryTime, - AdminSecret: &adminSecret, - DatabaseName: databaseName, - DatabaseURL: databaseURL, - DatabaseType: databaseType, - DatabaseUsername: databaseUsername, - DatabasePassword: databasePassword, - DatabaseHost: databaseHost, - DatabasePort: databasePort, - ClientID: clientID, - ClientSecret: clientSecret, - CustomAccessTokenScript: &customAccessTokenScript, - SMTPHost: &smtpHost, - SMTPPort: &smtpPort, - SMTPPassword: &smtpPassword, - SMTPUsername: &smtpUsername, - SenderEmail: &senderEmail, - JwtType: &jwtType, - JwtSecret: &jwtSecret, - JwtPrivateKey: &jwtPrivateKey, - JwtPublicKey: &jwtPublicKey, - JwtRoleClaim: &jwtRoleClaim, - AllowedOrigins: allowedOrigins, - AppURL: &appURL, - RedisURL: &redisURL, - CookieName: &cookieName, - ResetPasswordURL: &resetPasswordURL, - DisableEmailVerification: &disableEmailVerification, - DisableBasicAuthentication: &disableBasicAuthentication, - DisableMagicLinkLogin: &disableMagicLinkLogin, - DisableLoginPage: &disableLoginPage, - DisableSignUp: &disableSignUp, - Roles: roles, - ProtectedRoles: protectedRoles, - DefaultRoles: defaultRoles, - GoogleClientID: &googleClientID, - GoogleClientSecret: &googleClientSecret, - GithubClientID: &githubClientID, - GithubClientSecret: &githubClientSecret, - FacebookClientID: &facebookClientID, - FacebookClientSecret: &facebookClientSecret, - OrganizationName: &organizationName, - OrganizationLogo: &organizationLogo, + if val, ok := store[constants.EnvKeyAccessTokenExpiryTime]; ok { + res.AccessTokenExpiryTime = utils.NewStringRef(val.(string)) } + if val, ok := store[constants.EnvKeyAdminSecret]; ok { + res.AdminSecret = utils.NewStringRef(val.(string)) + } + if val, ok := store[constants.EnvKeyClientID]; ok { + res.ClientID = val.(string) + } + if val, ok := store[constants.EnvKeyClientSecret]; ok { + res.ClientSecret = val.(string) + } + if val, ok := store[constants.EnvKeyDatabaseURL]; ok { + res.DatabaseURL = utils.NewStringRef(val.(string)) + } + if val, ok := store[constants.EnvKeyDatabaseName]; ok { + res.DatabaseName = utils.NewStringRef(val.(string)) + } + if val, ok := store[constants.EnvKeyDatabaseType]; ok { + res.DatabaseType = utils.NewStringRef(val.(string)) + } + if val, ok := store[constants.EnvKeyDatabaseUsername]; ok { + res.DatabaseUsername = utils.NewStringRef(val.(string)) + } + if val, ok := store[constants.EnvKeyDatabasePassword]; ok { + res.DatabasePassword = utils.NewStringRef(val.(string)) + } + if val, ok := store[constants.EnvKeyDatabaseHost]; ok { + res.DatabaseHost = utils.NewStringRef(val.(string)) + } + if val, ok := store[constants.EnvKeyDatabasePort]; ok { + res.DatabasePort = utils.NewStringRef(val.(string)) + } + if val, ok := store[constants.EnvKeyCustomAccessTokenScript]; ok { + res.CustomAccessTokenScript = utils.NewStringRef(val.(string)) + } + if val, ok := store[constants.EnvKeySmtpHost]; ok { + res.SMTPHost = utils.NewStringRef(val.(string)) + } + if val, ok := store[constants.EnvKeySmtpPort]; ok { + res.SMTPPort = utils.NewStringRef(val.(string)) + } + if val, ok := store[constants.EnvKeySmtpUsername]; ok { + res.SMTPUsername = utils.NewStringRef(val.(string)) + } + if val, ok := store[constants.EnvKeySmtpPassword]; ok { + res.SMTPPassword = utils.NewStringRef(val.(string)) + } + if val, ok := store[constants.EnvKeySenderEmail]; ok { + res.SenderEmail = utils.NewStringRef(val.(string)) + } + if val, ok := store[constants.EnvKeyJwtType]; ok { + res.JwtType = utils.NewStringRef(val.(string)) + } + if val, ok := store[constants.EnvKeyJwtSecret]; ok { + res.JwtSecret = utils.NewStringRef(val.(string)) + } + if val, ok := store[constants.EnvKeyJwtRoleClaim]; ok { + res.JwtRoleClaim = utils.NewStringRef(val.(string)) + } + if val, ok := store[constants.EnvKeyJwtPublicKey]; ok { + res.JwtPublicKey = utils.NewStringRef(val.(string)) + } + if val, ok := store[constants.EnvKeyJwtPrivateKey]; ok { + res.JwtPrivateKey = utils.NewStringRef(val.(string)) + } + if val, ok := store[constants.EnvKeyAppURL]; ok { + res.AppURL = utils.NewStringRef(val.(string)) + } + if val, ok := store[constants.EnvKeyRedisURL]; ok { + res.RedisURL = utils.NewStringRef(val.(string)) + } + if val, ok := store[constants.EnvKeyResetPasswordURL]; ok { + res.ResetPasswordURL = utils.NewStringRef(val.(string)) + } + if val, ok := store[constants.EnvKeyGoogleClientID]; ok { + res.GoogleClientID = utils.NewStringRef(val.(string)) + } + if val, ok := store[constants.EnvKeyGoogleClientSecret]; ok { + res.GoogleClientSecret = utils.NewStringRef(val.(string)) + } + if val, ok := store[constants.EnvKeyFacebookClientID]; ok { + res.FacebookClientID = utils.NewStringRef(val.(string)) + } + if val, ok := store[constants.EnvKeyFacebookClientSecret]; ok { + res.FacebookClientSecret = utils.NewStringRef(val.(string)) + } + if val, ok := store[constants.EnvKeyGithubClientID]; ok { + res.GithubClientID = utils.NewStringRef(val.(string)) + } + if val, ok := store[constants.EnvKeyGithubClientSecret]; ok { + res.GithubClientSecret = utils.NewStringRef(val.(string)) + } + if val, ok := store[constants.EnvKeyOrganizationName]; ok { + res.OrganizationName = utils.NewStringRef(val.(string)) + } + if val, ok := store[constants.EnvKeyOrganizationLogo]; ok { + res.OrganizationLogo = utils.NewStringRef(val.(string)) + } + + // string slice vars + res.AllowedOrigins = strings.Split(store[constants.EnvKeyAllowedOrigins].(string), ",") + res.Roles = strings.Split(store[constants.EnvKeyRoles].(string), ",") + res.DefaultRoles = strings.Split(store[constants.EnvKeyDefaultRoles].(string), ",") + res.ProtectedRoles = strings.Split(store[constants.EnvKeyProtectedRoles].(string), ",") + + // bool vars + res.DisableEmailVerification = store[constants.EnvKeyDisableEmailVerification].(bool) + res.DisableBasicAuthentication = store[constants.EnvKeyDisableBasicAuthentication].(bool) + res.DisableMagicLinkLogin = store[constants.EnvKeyDisableMagicLinkLogin].(bool) + res.DisableLoginPage = store[constants.EnvKeyDisableLoginPage].(bool) + res.DisableSignUp = store[constants.EnvKeyDisableSignUp].(bool) + return res, nil } diff --git a/server/resolvers/forgot_password.go b/server/resolvers/forgot_password.go index ec7049e..4bcf107 100644 --- a/server/resolvers/forgot_password.go +++ b/server/resolvers/forgot_password.go @@ -12,10 +12,12 @@ import ( "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/db/models" "github.com/authorizerdev/authorizer/server/email" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" + "github.com/authorizerdev/authorizer/server/parsers" "github.com/authorizerdev/authorizer/server/token" "github.com/authorizerdev/authorizer/server/utils" + "github.com/authorizerdev/authorizer/server/validators" ) // ForgotPasswordResolver is a resolver for forgot password mutation @@ -28,13 +30,18 @@ func ForgotPasswordResolver(ctx context.Context, params model.ForgotPasswordInpu return res, err } - if envstore.EnvStoreObj.GetBoolStoreEnvVariable(constants.EnvKeyDisableBasicAuthentication) { + isBasicAuthDisabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyDisableBasicAuthentication) + if err != nil { + log.Debug("Error getting basic auth disabled: ", err) + isBasicAuthDisabled = true + } + if isBasicAuthDisabled { log.Debug("Basic authentication is disabled") return res, fmt.Errorf(`basic authentication is disabled for this instance`) } params.Email = strings.ToLower(params.Email) - if !utils.IsValidEmail(params.Email) { + if !validators.IsValidEmail(params.Email) { log.Debug("Invalid email address: ", params.Email) return res, fmt.Errorf("invalid email") } @@ -48,13 +55,13 @@ func ForgotPasswordResolver(ctx context.Context, params model.ForgotPasswordInpu return res, fmt.Errorf(`user with this email not found`) } - hostname := utils.GetHost(gc) + hostname := parsers.GetHost(gc) _, nonceHash, err := utils.GenerateNonce() if err != nil { log.Debug("Failed to generate nonce: ", err) return res, err } - redirectURL := utils.GetAppURL(gc) + "/reset-password" + redirectURL := parsers.GetAppURL(gc) + "/reset-password" if params.RedirectURI != nil { redirectURL = *params.RedirectURI } diff --git a/server/resolvers/generate_jwt_keys.go b/server/resolvers/generate_jwt_keys.go index 8f0050e..323e006 100644 --- a/server/resolvers/generate_jwt_keys.go +++ b/server/resolvers/generate_jwt_keys.go @@ -6,8 +6,8 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/crypto" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/token" "github.com/authorizerdev/authorizer/server/utils" log "github.com/sirupsen/logrus" @@ -26,7 +26,11 @@ func GenerateJWTKeysResolver(ctx context.Context, params model.GenerateJWTKeysIn return nil, fmt.Errorf("unauthorized") } - clientID := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyClientID) + clientID, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyClientID) + if err != nil { + log.Debug("Error getting client id: ", err) + return nil, err + } if crypto.IsHMACA(params.Type) { secret, _, err := crypto.NewHMACKey(params.Type, clientID) if err != nil { diff --git a/server/resolvers/invite_members.go b/server/resolvers/invite_members.go index 30dc79a..c454dc8 100644 --- a/server/resolvers/invite_members.go +++ b/server/resolvers/invite_members.go @@ -13,10 +13,12 @@ import ( "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/db/models" emailservice "github.com/authorizerdev/authorizer/server/email" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" + "github.com/authorizerdev/authorizer/server/parsers" "github.com/authorizerdev/authorizer/server/token" "github.com/authorizerdev/authorizer/server/utils" + "github.com/authorizerdev/authorizer/server/validators" ) // InviteMembersResolver resolver to invite members @@ -33,12 +35,20 @@ func InviteMembersResolver(ctx context.Context, params model.InviteMemberInput) } // this feature is only allowed if email server is configured - if envstore.EnvStoreObj.GetBoolStoreEnvVariable(constants.EnvKeyDisableEmailVerification) { + isEmailVerificationDisabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyDisableEmailVerification) + if err != nil { + log.Debug("Error getting email verification disabled: ", err) + isEmailVerificationDisabled = true + } + + if isEmailVerificationDisabled { log.Debug("Email server is not configured") return nil, errors.New("email sending is disabled") } - if envstore.EnvStoreObj.GetBoolStoreEnvVariable(constants.EnvKeyDisableBasicAuthentication) && envstore.EnvStoreObj.GetBoolStoreEnvVariable(constants.EnvKeyDisableMagicLinkLogin) { + isBasicAuthDisabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyDisableBasicAuthentication) + isMagicLinkLoginDisabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyDisableMagicLinkLogin) + if isBasicAuthDisabled && isMagicLinkLoginDisabled { log.Debug("Basic authentication and Magic link login is disabled.") return nil, errors.New("either basic authentication or magic link login is required") } @@ -46,7 +56,7 @@ func InviteMembersResolver(ctx context.Context, params model.InviteMemberInput) // filter valid emails emails := []string{} for _, email := range params.Emails { - if utils.IsValidEmail(email) { + if validators.IsValidEmail(email) { emails = append(emails, email) } } @@ -77,13 +87,22 @@ func InviteMembersResolver(ctx context.Context, params model.InviteMemberInput) // invite new emails for _, email := range newEmails { + defaultRolesString, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyDefaultRoles) + defaultRoles := []string{} + if err != nil { + log.Debug("Error getting default roles: ", err) + defaultRolesString = "" + } else { + defaultRoles = strings.Split(defaultRolesString, ",") + } + user := models.User{ Email: email, - Roles: strings.Join(envstore.EnvStoreObj.GetSliceStoreEnvVariable(constants.EnvKeyDefaultRoles), ","), + Roles: strings.Join(defaultRoles, ","), } - hostname := utils.GetHost(gc) + hostname := parsers.GetHost(gc) verifyEmailURL := hostname + "/verify_email" - appURL := utils.GetAppURL(gc) + appURL := parsers.GetAppURL(gc) redirectURL := appURL if params.RedirectURI != nil { @@ -109,7 +128,7 @@ func InviteMembersResolver(ctx context.Context, params model.InviteMemberInput) } // use magic link login if that option is on - if !envstore.EnvStoreObj.GetBoolStoreEnvVariable(constants.EnvKeyDisableMagicLinkLogin) { + if !isMagicLinkLoginDisabled { user.SignupMethods = constants.SignupMethodMagicLinkLogin verificationRequest.Identifier = constants.VerificationTypeMagicLinkLogin } else { diff --git a/server/resolvers/login.go b/server/resolvers/login.go index eda8c9d..54ff030 100644 --- a/server/resolvers/login.go +++ b/server/resolvers/login.go @@ -13,11 +13,11 @@ import ( "github.com/authorizerdev/authorizer/server/cookie" "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/db/models" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" - "github.com/authorizerdev/authorizer/server/sessionstore" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/token" "github.com/authorizerdev/authorizer/server/utils" + "github.com/authorizerdev/authorizer/server/validators" ) // LoginResolver is a resolver for login mutation @@ -30,7 +30,13 @@ func LoginResolver(ctx context.Context, params model.LoginInput) (*model.AuthRes return res, err } - if envstore.EnvStoreObj.GetBoolStoreEnvVariable(constants.EnvKeyDisableBasicAuthentication) { + isBasiAuthDisabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyDisableBasicAuthentication) + if err != nil { + log.Debug("Error getting basic auth disabled: ", err) + isBasiAuthDisabled = true + } + + if isBasiAuthDisabled { log.Debug("Basic authentication is disabled.") return res, fmt.Errorf(`basic authentication is disabled for this instance`) } @@ -66,10 +72,19 @@ func LoginResolver(ctx context.Context, params model.LoginInput) (*model.AuthRes log.Debug("Failed to compare password: ", err) return res, fmt.Errorf(`invalid password`) } - roles := envstore.EnvStoreObj.GetSliceStoreEnvVariable(constants.EnvKeyDefaultRoles) + + defaultRolesString, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyDefaultRoles) + roles := []string{} + if err != nil { + log.Debug("Error getting default roles: ", err) + defaultRolesString = "" + } else { + roles = strings.Split(defaultRolesString, ",") + } + currentRoles := strings.Split(user.Roles, ",") if len(params.Roles) > 0 { - if !utils.IsValidRoles(params.Roles, currentRoles) { + if !validators.IsValidRoles(params.Roles, currentRoles) { log.Debug("Invalid roles: ", params.Roles) return res, fmt.Errorf(`invalid roles`) } @@ -102,12 +117,12 @@ func LoginResolver(ctx context.Context, params model.LoginInput) (*model.AuthRes } cookie.SetSession(gc, authToken.FingerPrintHash) - sessionstore.SetState(authToken.FingerPrintHash, authToken.FingerPrint+"@"+user.ID) - sessionstore.SetState(authToken.AccessToken.Token, authToken.FingerPrint+"@"+user.ID) + memorystore.Provider.SetState(authToken.FingerPrintHash, authToken.FingerPrint+"@"+user.ID) + memorystore.Provider.SetState(authToken.AccessToken.Token, authToken.FingerPrint+"@"+user.ID) if authToken.RefreshToken != nil { res.RefreshToken = &authToken.RefreshToken.Token - sessionstore.SetState(authToken.RefreshToken.Token, authToken.FingerPrint+"@"+user.ID) + memorystore.Provider.SetState(authToken.RefreshToken.Token, authToken.FingerPrint+"@"+user.ID) } go db.Provider.AddSession(models.Session{ diff --git a/server/resolvers/logout.go b/server/resolvers/logout.go index 9683237..2c81f63 100644 --- a/server/resolvers/logout.go +++ b/server/resolvers/logout.go @@ -8,7 +8,7 @@ import ( "github.com/authorizerdev/authorizer/server/cookie" "github.com/authorizerdev/authorizer/server/crypto" "github.com/authorizerdev/authorizer/server/graph/model" - "github.com/authorizerdev/authorizer/server/sessionstore" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/utils" ) @@ -37,7 +37,7 @@ func LogoutResolver(ctx context.Context) (*model.Response, error) { fingerPrint := string(decryptedFingerPrint) - sessionstore.RemoveState(fingerPrint) + memorystore.Provider.RemoveState(fingerPrint) cookie.DeleteSession(gc) res = &model.Response{ diff --git a/server/resolvers/magic_link_login.go b/server/resolvers/magic_link_login.go index d79fc46..713c850 100644 --- a/server/resolvers/magic_link_login.go +++ b/server/resolvers/magic_link_login.go @@ -12,10 +12,12 @@ import ( "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/db/models" "github.com/authorizerdev/authorizer/server/email" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" + "github.com/authorizerdev/authorizer/server/parsers" "github.com/authorizerdev/authorizer/server/token" "github.com/authorizerdev/authorizer/server/utils" + "github.com/authorizerdev/authorizer/server/validators" ) // MagicLinkLoginResolver is a resolver for magic link login mutation @@ -28,14 +30,20 @@ func MagicLinkLoginResolver(ctx context.Context, params model.MagicLinkLoginInpu return res, err } - if envstore.EnvStoreObj.GetBoolStoreEnvVariable(constants.EnvKeyDisableMagicLinkLogin) { + isMagicLinkLoginDisabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyDisableMagicLinkLogin) + if err != nil { + log.Debug("Error getting magic link login disabled: ", err) + isMagicLinkLoginDisabled = true + } + + if isMagicLinkLoginDisabled { log.Debug("Magic link login is disabled.") return res, fmt.Errorf(`magic link login is disabled for this instance`) } params.Email = strings.ToLower(params.Email) - if !utils.IsValidEmail(params.Email) { + if !validators.IsValidEmail(params.Email) { log.Debug("Invalid email") return res, fmt.Errorf(`invalid email address`) } @@ -53,7 +61,11 @@ func MagicLinkLoginResolver(ctx context.Context, params model.MagicLinkLoginInpu // find user with email existingUser, err := db.Provider.GetUserByEmail(params.Email) if err != nil { - if envstore.EnvStoreObj.GetBoolStoreEnvVariable(constants.EnvKeyDisableSignUp) { + isSignupDisabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyDisableSignUp) + if err != nil { + log.Debug("Error getting signup disabled: ", err) + } + if isSignupDisabled { log.Debug("Signup is disabled.") return res, fmt.Errorf(`signup is disabled for this instance`) } @@ -62,14 +74,28 @@ func MagicLinkLoginResolver(ctx context.Context, params model.MagicLinkLoginInpu // define roles for new user if len(params.Roles) > 0 { // check if roles exists - if !utils.IsValidRoles(params.Roles, envstore.EnvStoreObj.GetSliceStoreEnvVariable(constants.EnvKeyRoles)) { + rolesString, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyRoles) + roles := []string{} + if err != nil { + log.Debug("Error getting roles: ", err) + return res, err + } else { + roles = strings.Split(rolesString, ",") + } + if !validators.IsValidRoles(params.Roles, roles) { log.Debug("Invalid roles: ", params.Roles) return res, fmt.Errorf(`invalid roles`) } else { inputRoles = params.Roles } } else { - inputRoles = envstore.EnvStoreObj.GetSliceStoreEnvVariable(constants.EnvKeyDefaultRoles) + inputRolesString, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyDefaultRoles) + if err != nil { + log.Debug("Error getting default roles: ", err) + return res, fmt.Errorf(`invalid roles`) + } else { + inputRoles = strings.Split(inputRolesString, ",") + } } user.Roles = strings.Join(inputRoles, ",") @@ -88,7 +114,13 @@ func MagicLinkLoginResolver(ctx context.Context, params model.MagicLinkLoginInpu // find the unassigned roles if len(params.Roles) <= 0 { - inputRoles = envstore.EnvStoreObj.GetSliceStoreEnvVariable(constants.EnvKeyDefaultRoles) + inputRolesString, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyDefaultRoles) + if err != nil { + log.Debug("Error getting default roles: ", err) + return res, fmt.Errorf(`invalid default roles`) + } else { + inputRoles = strings.Split(inputRolesString, ",") + } } existingRoles := strings.Split(existingUser.Roles, ",") unasignedRoles := []string{} @@ -101,8 +133,16 @@ func MagicLinkLoginResolver(ctx context.Context, params model.MagicLinkLoginInpu if len(unasignedRoles) > 0 { // check if it contains protected unassigned role hasProtectedRole := false + protectedRolesString, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyProtectedRoles) + protectedRoles := []string{} + if err != nil { + log.Debug("Error getting protected roles: ", err) + return res, err + } else { + protectedRoles = strings.Split(protectedRolesString, ",") + } for _, ur := range unasignedRoles { - if utils.StringSliceContains(envstore.EnvStoreObj.GetSliceStoreEnvVariable(constants.EnvKeyProtectedRoles), ur) { + if utils.StringSliceContains(protectedRoles, ur) { hasProtectedRole = true } } @@ -129,8 +169,13 @@ func MagicLinkLoginResolver(ctx context.Context, params model.MagicLinkLoginInpu } } - hostname := utils.GetHost(gc) - if !envstore.EnvStoreObj.GetBoolStoreEnvVariable(constants.EnvKeyDisableEmailVerification) { + hostname := parsers.GetHost(gc) + isEmailVerificationDisabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyDisableEmailVerification) + if err != nil { + log.Debug("Error getting email verification disabled: ", err) + isEmailVerificationDisabled = true + } + if !isEmailVerificationDisabled { // insert verification request _, nonceHash, err := utils.GenerateNonce() if err != nil { @@ -144,7 +189,7 @@ func MagicLinkLoginResolver(ctx context.Context, params model.MagicLinkLoginInpu if params.Scope != nil && len(params.Scope) > 0 { redirectURLParams = redirectURLParams + "&scope=" + strings.Join(params.Scope, " ") } - redirectURL := utils.GetAppURL(gc) + redirectURL := parsers.GetAppURL(gc) if params.RedirectURI != nil { redirectURL = *params.RedirectURI } diff --git a/server/resolvers/meta.go b/server/resolvers/meta.go index eab3846..18fe561 100644 --- a/server/resolvers/meta.go +++ b/server/resolvers/meta.go @@ -3,12 +3,90 @@ package resolvers import ( "context" + log "github.com/sirupsen/logrus" + + "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/graph/model" - "github.com/authorizerdev/authorizer/server/utils" + "github.com/authorizerdev/authorizer/server/memorystore" ) // MetaResolver is a resolver for meta query func MetaResolver(ctx context.Context) (*model.Meta, error) { - metaInfo := utils.GetMetaInfo() + clientID, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyClientID) + if err != nil { + return nil, err + } + + googleClientID, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyGoogleClientID) + if err != nil { + log.Debug("Failed to get Google Client ID from environment variable", err) + googleClientID = "" + } + + googleClientSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyGoogleClientSecret) + if err != nil { + log.Debug("Failed to get Google Client Secret from environment variable", err) + googleClientSecret = "" + } + + facebookClientID, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyFacebookClientID) + if err != nil { + log.Debug("Failed to get Facebook Client ID from environment variable", err) + facebookClientID = "" + } + + facebookClientSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyFacebookClientSecret) + if err != nil { + log.Debug("Failed to get Facebook Client Secret from environment variable", err) + facebookClientSecret = "" + } + + githubClientID, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyGithubClientID) + if err != nil { + log.Debug("Failed to get Github Client ID from environment variable", err) + githubClientID = "" + } + + githubClientSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyGithubClientSecret) + if err != nil { + log.Debug("Failed to get Github Client Secret from environment variable", err) + githubClientSecret = "" + } + + isBasicAuthDisabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyDisableBasicAuthentication) + if err != nil { + log.Debug("Failed to get Disable Basic Authentication from environment variable", err) + isBasicAuthDisabled = true + } + + isEmailVerificationDisabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyDisableEmailVerification) + if err != nil { + log.Debug("Failed to get Disable Email Verification from environment variable", err) + isEmailVerificationDisabled = true + } + + isMagicLinkLoginDisabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyDisableMagicLinkLogin) + if err != nil { + log.Debug("Failed to get Disable Magic Link Login from environment variable", err) + isMagicLinkLoginDisabled = true + } + + isSignUpDisabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyDisableSignUp) + if err != nil { + log.Debug("Failed to get Disable Signup from environment variable", err) + isSignUpDisabled = true + } + + metaInfo := model.Meta{ + Version: constants.VERSION, + ClientID: clientID, + IsGoogleLoginEnabled: googleClientID != "" && googleClientSecret != "", + IsGithubLoginEnabled: githubClientID != "" && githubClientSecret != "", + IsFacebookLoginEnabled: facebookClientID != "" && facebookClientSecret != "", + IsBasicAuthenticationEnabled: !isBasicAuthDisabled, + IsEmailVerificationEnabled: !isEmailVerificationDisabled, + IsMagicLinkLoginEnabled: !isMagicLinkLoginDisabled, + IsSignUpEnabled: !isSignUpDisabled, + } return &metaInfo, nil } diff --git a/server/resolvers/resend_verify_email.go b/server/resolvers/resend_verify_email.go index 6ae6f34..79abb4b 100644 --- a/server/resolvers/resend_verify_email.go +++ b/server/resolvers/resend_verify_email.go @@ -12,8 +12,10 @@ import ( "github.com/authorizerdev/authorizer/server/db/models" "github.com/authorizerdev/authorizer/server/email" "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/parsers" "github.com/authorizerdev/authorizer/server/token" "github.com/authorizerdev/authorizer/server/utils" + "github.com/authorizerdev/authorizer/server/validators" ) // ResendVerifyEmailResolver is a resolver for resend verify email mutation @@ -27,12 +29,12 @@ func ResendVerifyEmailResolver(ctx context.Context, params model.ResendVerifyEma } params.Email = strings.ToLower(params.Email) - if !utils.IsValidEmail(params.Email) { + if !validators.IsValidEmail(params.Email) { log.Debug("Invalid email: ", params.Email) return res, fmt.Errorf("invalid email") } - if !utils.IsValidVerificationIdentifier(params.Identifier) { + if !validators.IsValidVerificationIdentifier(params.Identifier) { log.Debug("Invalid verification identifier: ", params.Identifier) return res, fmt.Errorf("invalid identifier") } @@ -49,7 +51,7 @@ func ResendVerifyEmailResolver(ctx context.Context, params model.ResendVerifyEma log.Debug("Failed to delete verification request: ", err) } - hostname := utils.GetHost(gc) + hostname := parsers.GetHost(gc) _, nonceHash, err := utils.GenerateNonce() if err != nil { log.Debug("Failed to generate nonce: ", err) diff --git a/server/resolvers/reset_password.go b/server/resolvers/reset_password.go index 77f1c96..9defd06 100644 --- a/server/resolvers/reset_password.go +++ b/server/resolvers/reset_password.go @@ -11,10 +11,12 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/crypto" "github.com/authorizerdev/authorizer/server/db" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" + "github.com/authorizerdev/authorizer/server/parsers" "github.com/authorizerdev/authorizer/server/token" "github.com/authorizerdev/authorizer/server/utils" + "github.com/authorizerdev/authorizer/server/validators" ) // ResetPasswordResolver is a resolver for reset password mutation @@ -26,7 +28,13 @@ func ResetPasswordResolver(ctx context.Context, params model.ResetPasswordInput) log.Debug("Failed to get GinContext: ", err) return res, err } - if envstore.EnvStoreObj.GetBoolStoreEnvVariable(constants.EnvKeyDisableBasicAuthentication) { + + isBasicAuthDisabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyDisableBasicAuthentication) + if err != nil { + log.Debug("Error getting basic auth disabled: ", err) + isBasicAuthDisabled = true + } + if isBasicAuthDisabled { log.Debug("Basic authentication is disabled") return res, fmt.Errorf(`basic authentication is disabled for this instance`) } @@ -42,13 +50,13 @@ func ResetPasswordResolver(ctx context.Context, params model.ResetPasswordInput) return res, fmt.Errorf(`passwords don't match`) } - if !utils.IsValidPassword(params.Password) { + if !validators.IsValidPassword(params.Password) { log.Debug("Invalid password") return res, fmt.Errorf(`password is not valid. It needs to be at least 6 characters long and contain at least one number, one uppercase letter, one lowercase letter and one special character`) } // verify if token exists in db - hostname := utils.GetHost(gc) + hostname := parsers.GetHost(gc) claim, err := token.ParseJWTToken(params.Token, hostname, verificationRequest.Nonce, verificationRequest.Email) if err != nil { log.Debug("Failed to parse token: ", err) diff --git a/server/resolvers/revoke.go b/server/resolvers/revoke.go index 1ab1cb9..694e36b 100644 --- a/server/resolvers/revoke.go +++ b/server/resolvers/revoke.go @@ -4,12 +4,12 @@ import ( "context" "github.com/authorizerdev/authorizer/server/graph/model" - "github.com/authorizerdev/authorizer/server/sessionstore" + "github.com/authorizerdev/authorizer/server/memorystore" ) // RevokeResolver resolver to revoke refresh token func RevokeResolver(ctx context.Context, params model.OAuthRevokeInput) (*model.Response, error) { - sessionstore.RemoveState(params.RefreshToken) + memorystore.Provider.RemoveState(params.RefreshToken) return &model.Response{ Message: "Token revoked", }, nil diff --git a/server/resolvers/revoke_access.go b/server/resolvers/revoke_access.go index a7b6ab0..9b24c71 100644 --- a/server/resolvers/revoke_access.go +++ b/server/resolvers/revoke_access.go @@ -9,7 +9,7 @@ import ( "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/graph/model" - "github.com/authorizerdev/authorizer/server/sessionstore" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/token" "github.com/authorizerdev/authorizer/server/utils" ) @@ -47,7 +47,7 @@ func RevokeAccessResolver(ctx context.Context, params model.UpdateAccessInput) ( return res, err } - go sessionstore.DeleteAllUserSession(fmt.Sprintf("%x", user.ID)) + go memorystore.Provider.DeleteAllUserSession(fmt.Sprintf("%x", user.ID)) res = &model.Response{ Message: `user access revoked successfully`, diff --git a/server/resolvers/session.go b/server/resolvers/session.go index 0698b64..89b7e11 100644 --- a/server/resolvers/session.go +++ b/server/resolvers/session.go @@ -11,7 +11,7 @@ import ( "github.com/authorizerdev/authorizer/server/cookie" "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/graph/model" - "github.com/authorizerdev/authorizer/server/sessionstore" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/token" "github.com/authorizerdev/authorizer/server/utils" ) @@ -76,9 +76,9 @@ func SessionResolver(ctx context.Context, params *model.SessionQueryInput) (*mod } // rollover the session for security - sessionstore.RemoveState(sessionToken) - sessionstore.SetState(authToken.FingerPrintHash, authToken.FingerPrint+"@"+user.ID) - sessionstore.SetState(authToken.AccessToken.Token, authToken.FingerPrint+"@"+user.ID) + memorystore.Provider.RemoveState(sessionToken) + memorystore.Provider.SetState(authToken.FingerPrintHash, authToken.FingerPrint+"@"+user.ID) + memorystore.Provider.SetState(authToken.AccessToken.Token, authToken.FingerPrint+"@"+user.ID) cookie.SetSession(gc, authToken.FingerPrintHash) expiresIn := authToken.AccessToken.ExpiresAt - time.Now().Unix() @@ -96,7 +96,7 @@ func SessionResolver(ctx context.Context, params *model.SessionQueryInput) (*mod if authToken.RefreshToken != nil { res.RefreshToken = &authToken.RefreshToken.Token - sessionstore.SetState(authToken.RefreshToken.Token, authToken.FingerPrint+"@"+user.ID) + memorystore.Provider.SetState(authToken.RefreshToken.Token, authToken.FingerPrint+"@"+user.ID) } return res, nil diff --git a/server/resolvers/signup.go b/server/resolvers/signup.go index b8cffce..a649521 100644 --- a/server/resolvers/signup.go +++ b/server/resolvers/signup.go @@ -14,11 +14,12 @@ import ( "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/db/models" "github.com/authorizerdev/authorizer/server/email" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" - "github.com/authorizerdev/authorizer/server/sessionstore" + "github.com/authorizerdev/authorizer/server/memorystore" + "github.com/authorizerdev/authorizer/server/parsers" "github.com/authorizerdev/authorizer/server/token" "github.com/authorizerdev/authorizer/server/utils" + "github.com/authorizerdev/authorizer/server/validators" ) // SignupResolver is a resolver for signup mutation @@ -31,12 +32,23 @@ func SignupResolver(ctx context.Context, params model.SignUpInput) (*model.AuthR return res, err } - if envstore.EnvStoreObj.GetBoolStoreEnvVariable(constants.EnvKeyDisableSignUp) { + isSignupDisabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyDisableSignUp) + if err != nil { + log.Debug("Error getting signup disabled: ", err) + isSignupDisabled = true + } + if isSignupDisabled { log.Debug("Signup is disabled") return res, fmt.Errorf(`signup is disabled for this instance`) } - if envstore.EnvStoreObj.GetBoolStoreEnvVariable(constants.EnvKeyDisableBasicAuthentication) { + isBasicAuthDisabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyDisableBasicAuthentication) + if err != nil { + log.Debug("Error getting basic auth disabled: ", err) + isBasicAuthDisabled = true + } + + if isBasicAuthDisabled { log.Debug("Basic authentication is disabled") return res, fmt.Errorf(`basic authentication is disabled for this instance`) } @@ -46,14 +58,14 @@ func SignupResolver(ctx context.Context, params model.SignUpInput) (*model.AuthR return res, fmt.Errorf(`password and confirm password does not match`) } - if !utils.IsValidPassword(params.Password) { + if !validators.IsValidPassword(params.Password) { log.Debug("Invalid password") return res, fmt.Errorf(`password is not valid. It needs to be at least 6 characters long and contain at least one number, one uppercase letter, one lowercase letter and one special character`) } params.Email = strings.ToLower(params.Email) - if !utils.IsValidEmail(params.Email) { + if !validators.IsValidEmail(params.Email) { log.Debug("Invalid email: ", params.Email) return res, fmt.Errorf(`invalid email address`) } @@ -80,14 +92,28 @@ func SignupResolver(ctx context.Context, params model.SignUpInput) (*model.AuthR if len(params.Roles) > 0 { // check if roles exists - if !utils.IsValidRoles(envstore.EnvStoreObj.GetSliceStoreEnvVariable(constants.EnvKeyRoles), params.Roles) { + rolesString, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyRoles) + roles := []string{} + if err != nil { + log.Debug("Error getting roles: ", err) + return res, err + } else { + roles = strings.Split(rolesString, ",") + } + if !validators.IsValidRoles(roles, params.Roles) { log.Debug("Invalid roles: ", params.Roles) return res, fmt.Errorf(`invalid roles`) } else { inputRoles = params.Roles } } else { - inputRoles = envstore.EnvStoreObj.GetSliceStoreEnvVariable(constants.EnvKeyDefaultRoles) + inputRolesString, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyDefaultRoles) + if err != nil { + log.Debug("Error getting default roles: ", err) + return res, err + } else { + inputRoles = strings.Split(inputRolesString, ",") + } } user := models.User{ @@ -132,7 +158,12 @@ func SignupResolver(ctx context.Context, params model.SignUpInput) (*model.AuthR } user.SignupMethods = constants.SignupMethodBasicAuth - if envstore.EnvStoreObj.GetBoolStoreEnvVariable(constants.EnvKeyDisableEmailVerification) { + isEmailVerificationDisabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyDisableEmailVerification) + if err != nil { + log.Debug("Error getting email verification disabled: ", err) + isEmailVerificationDisabled = true + } + if isEmailVerificationDisabled { now := time.Now().Unix() user.EmailVerifiedAt = &now } @@ -144,8 +175,8 @@ func SignupResolver(ctx context.Context, params model.SignUpInput) (*model.AuthR roles := strings.Split(user.Roles, ",") userToReturn := user.AsAPIUser() - hostname := utils.GetHost(gc) - if !envstore.EnvStoreObj.GetBoolStoreEnvVariable(constants.EnvKeyDisableEmailVerification) { + hostname := parsers.GetHost(gc) + if !isEmailVerificationDisabled { // insert verification request _, nonceHash, err := utils.GenerateNonce() if err != nil { @@ -153,7 +184,7 @@ func SignupResolver(ctx context.Context, params model.SignUpInput) (*model.AuthR return res, err } verificationType := constants.VerificationTypeBasicAuthSignup - redirectURL := utils.GetAppURL(gc) + redirectURL := parsers.GetAppURL(gc) if params.RedirectURI != nil { redirectURL = *params.RedirectURI } @@ -194,7 +225,7 @@ func SignupResolver(ctx context.Context, params model.SignUpInput) (*model.AuthR return res, err } - sessionstore.SetState(authToken.FingerPrintHash, authToken.FingerPrint+"@"+user.ID) + memorystore.Provider.SetState(authToken.FingerPrintHash, authToken.FingerPrint+"@"+user.ID) cookie.SetSession(gc, authToken.FingerPrintHash) go db.Provider.AddSession(models.Session{ UserID: user.ID, diff --git a/server/resolvers/update_env.go b/server/resolvers/update_env.go index be298f9..d7023a0 100644 --- a/server/resolvers/update_env.go +++ b/server/resolvers/update_env.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "reflect" + "strings" log "github.com/sirupsen/logrus" @@ -13,10 +14,9 @@ import ( "github.com/authorizerdev/authorizer/server/cookie" "github.com/authorizerdev/authorizer/server/crypto" "github.com/authorizerdev/authorizer/server/db" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/oauth" - "github.com/authorizerdev/authorizer/server/sessionstore" "github.com/authorizerdev/authorizer/server/token" "github.com/authorizerdev/authorizer/server/utils" ) @@ -37,10 +37,14 @@ func UpdateEnvResolver(ctx context.Context, params model.UpdateEnvInput) (*model return res, fmt.Errorf("unauthorized") } - updatedData := envstore.EnvStoreObj.GetEnvStoreClone() + updatedData, err := memorystore.Provider.GetEnvStore() + if err != nil { + log.Debug("Failed to get env store: ", err) + return res, err + } isJWTUpdated := false - algo := updatedData.StringEnv[constants.EnvKeyJwtType] + algo := updatedData[constants.EnvKeyJwtType].(string) if params.JwtType != nil { algo = *params.JwtType if !crypto.IsHMACA(algo) && !crypto.IsECDSA(algo) && !crypto.IsRSA(algo) { @@ -48,7 +52,7 @@ func UpdateEnvResolver(ctx context.Context, params model.UpdateEnvInput) (*model return res, fmt.Errorf("invalid jwt type") } - updatedData.StringEnv[constants.EnvKeyJwtType] = algo + updatedData[constants.EnvKeyJwtType] = algo isJWTUpdated = true } @@ -136,8 +140,12 @@ func UpdateEnvResolver(ctx context.Context, params model.UpdateEnvInput) (*model log.Debug("Old admin secret is required for admin secret update") return res, errors.New("admin secret and old admin secret are required for secret change") } - - if *params.OldAdminSecret != envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret) { + oldAdminSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret) + if err != nil { + log.Debug("Failed to get old admin secret: ", err) + return res, err + } + if *params.OldAdminSecret != oldAdminSecret { log.Debug("Old admin secret is invalid") return res, errors.New("old admin secret is not correct") } @@ -155,31 +163,28 @@ func UpdateEnvResolver(ctx context.Context, params model.UpdateEnvInput) (*model fieldType := reflect.TypeOf(value).String() if fieldType == "string" { - updatedData.StringEnv[key] = value.(string) + updatedData[key] = value.(string) } if fieldType == "bool" { - updatedData.BoolEnv[key] = value.(bool) + updatedData[key] = value.(bool) } if fieldType == "[]interface {}" { - stringArr := []string{} - for _, v := range value.([]interface{}) { - stringArr = append(stringArr, v.(string)) - } - updatedData.SliceEnv[key] = stringArr + stringArr := utils.ConvertInterfaceToStringSlice(value) + updatedData[key] = strings.Join(stringArr, ",") } } } // handle derivative cases like disabling email verification & magic login // in case SMTP is off but env is set to true - if updatedData.StringEnv[constants.EnvKeySmtpHost] == "" || updatedData.StringEnv[constants.EnvKeySmtpUsername] == "" || updatedData.StringEnv[constants.EnvKeySmtpPassword] == "" || updatedData.StringEnv[constants.EnvKeySenderEmail] == "" && updatedData.StringEnv[constants.EnvKeySmtpPort] == "" { - if !updatedData.BoolEnv[constants.EnvKeyDisableEmailVerification] { - updatedData.BoolEnv[constants.EnvKeyDisableEmailVerification] = true + if updatedData[constants.EnvKeySmtpHost] == "" || updatedData[constants.EnvKeySmtpUsername] == "" || updatedData[constants.EnvKeySmtpPassword] == "" || updatedData[constants.EnvKeySenderEmail] == "" && updatedData[constants.EnvKeySmtpPort] == "" { + if !updatedData[constants.EnvKeyDisableEmailVerification].(bool) { + updatedData[constants.EnvKeyDisableEmailVerification] = true } - if !updatedData.BoolEnv[constants.EnvKeyDisableMagicLinkLogin] { - updatedData.BoolEnv[constants.EnvKeyDisableMagicLinkLogin] = true + if !updatedData[constants.EnvKeyDisableMagicLinkLogin].(bool) { + updatedData[constants.EnvKeyDisableMagicLinkLogin] = true } } @@ -206,19 +211,25 @@ func UpdateEnvResolver(ctx context.Context, params model.UpdateEnvInput) (*model } // Update local store - envstore.EnvStoreObj.UpdateEnvStore(updatedData) + memorystore.Provider.UpdateEnvStore(updatedData) jwk, err := crypto.GenerateJWKBasedOnEnv() if err != nil { log.Debug("Failed to generate JWK: ", err) return res, err } // updating jwk - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyJWK, jwk) - err = sessionstore.InitSession() + err = memorystore.Provider.UpdateEnvVariable(constants.EnvKeyJWK, jwk) if err != nil { - log.Debug("Failed to init session store: ", err) + log.Debug("Failed to update JWK: ", err) return res, err } + + // TODO check how to update session store based on env change. + // err = sessionstore.InitSession() + // if err != nil { + // log.Debug("Failed to init session store: ", err) + // return res, err + // } err = oauth.InitOAuth() if err != nil { return res, err @@ -232,7 +243,12 @@ func UpdateEnvResolver(ctx context.Context, params model.UpdateEnvInput) (*model } if params.AdminSecret != nil { - hashedKey, err := crypto.EncryptPassword(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret)) + adminSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret) + if err != nil { + log.Debug("Failed to get admin secret: ", err) + return res, err + } + hashedKey, err := crypto.EncryptPassword(adminSecret) if err != nil { log.Debug("Failed to encrypt admin secret: ", err) return res, err diff --git a/server/resolvers/update_profile.go b/server/resolvers/update_profile.go index a7ddc49..9a00276 100644 --- a/server/resolvers/update_profile.go +++ b/server/resolvers/update_profile.go @@ -14,11 +14,12 @@ import ( "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/db/models" "github.com/authorizerdev/authorizer/server/email" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" - "github.com/authorizerdev/authorizer/server/sessionstore" + "github.com/authorizerdev/authorizer/server/memorystore" + "github.com/authorizerdev/authorizer/server/parsers" "github.com/authorizerdev/authorizer/server/token" "github.com/authorizerdev/authorizer/server/utils" + "github.com/authorizerdev/authorizer/server/validators" "golang.org/x/crypto/bcrypt" ) @@ -122,14 +123,14 @@ func UpdateProfileResolver(ctx context.Context, params model.UpdateProfileInput) if params.Email != nil && user.Email != *params.Email { // check if valid email - if !utils.IsValidEmail(*params.Email) { + if !validators.IsValidEmail(*params.Email) { log.Debug("Failed to validate email: ", *params.Email) return res, fmt.Errorf("invalid email address") } newEmail := strings.ToLower(*params.Email) // check if valid email - if !utils.IsValidEmail(newEmail) { + if !validators.IsValidEmail(newEmail) { log.Debug("Failed to validate new email: ", newEmail) return res, fmt.Errorf("invalid new email address") } @@ -141,12 +142,17 @@ func UpdateProfileResolver(ctx context.Context, params model.UpdateProfileInput) return res, fmt.Errorf("user with this email address already exists") } - go sessionstore.DeleteAllUserSession(user.ID) + go memorystore.Provider.DeleteAllUserSession(user.ID) go cookie.DeleteSession(gc) user.Email = newEmail - if !envstore.EnvStoreObj.GetBoolStoreEnvVariable(constants.EnvKeyDisableEmailVerification) { - hostname := utils.GetHost(gc) + isEmailVerificationDisabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyDisableEmailVerification) + if err != nil { + log.Debug("Failed to get disable email verification env variable: ", err) + return res, err + } + if !isEmailVerificationDisabled { + hostname := parsers.GetHost(gc) user.EmailVerifiedAt = nil hasEmailChanged = true // insert verification request @@ -156,7 +162,7 @@ func UpdateProfileResolver(ctx context.Context, params model.UpdateProfileInput) return res, err } verificationType := constants.VerificationTypeUpdateEmail - redirectURL := utils.GetAppURL(gc) + redirectURL := parsers.GetAppURL(gc) verificationToken, err := token.CreateVerificationToken(newEmail, verificationType, hostname, nonceHash, redirectURL) if err != nil { log.Debug("Failed to create verification token: ", err) diff --git a/server/resolvers/update_user.go b/server/resolvers/update_user.go index 3628ba4..b1b72b6 100644 --- a/server/resolvers/update_user.go +++ b/server/resolvers/update_user.go @@ -12,11 +12,12 @@ import ( "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/db/models" "github.com/authorizerdev/authorizer/server/email" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" - "github.com/authorizerdev/authorizer/server/sessionstore" + "github.com/authorizerdev/authorizer/server/memorystore" + "github.com/authorizerdev/authorizer/server/parsers" "github.com/authorizerdev/authorizer/server/token" "github.com/authorizerdev/authorizer/server/utils" + "github.com/authorizerdev/authorizer/server/validators" ) // UpdateUserResolver is a resolver for update user mutation @@ -98,7 +99,7 @@ func UpdateUserResolver(ctx context.Context, params model.UpdateUserInput) (*mod if params.Email != nil && user.Email != *params.Email { // check if valid email - if !utils.IsValidEmail(*params.Email) { + if !validators.IsValidEmail(*params.Email) { log.Debug("Invalid email: ", *params.Email) return res, fmt.Errorf("invalid email address") } @@ -112,9 +113,9 @@ func UpdateUserResolver(ctx context.Context, params model.UpdateUserInput) (*mod } // TODO figure out how to do this - go sessionstore.DeleteAllUserSession(user.ID) + go memorystore.Provider.DeleteAllUserSession(user.ID) - hostname := utils.GetHost(gc) + hostname := parsers.GetHost(gc) user.Email = newEmail user.EmailVerifiedAt = nil // insert verification request @@ -124,7 +125,7 @@ func UpdateUserResolver(ctx context.Context, params model.UpdateUserInput) (*mod return res, err } verificationType := constants.VerificationTypeUpdateEmail - redirectURL := utils.GetAppURL(gc) + redirectURL := parsers.GetAppURL(gc) verificationToken, err := token.CreateVerificationToken(newEmail, verificationType, hostname, nonceHash, redirectURL) if err != nil { log.Debug("Failed to create verification token: ", err) @@ -155,16 +156,33 @@ func UpdateUserResolver(ctx context.Context, params model.UpdateUserInput) (*mod inputRoles = append(inputRoles, *item) } - if !utils.IsValidRoles(inputRoles, append([]string{}, append(envstore.EnvStoreObj.GetSliceStoreEnvVariable(constants.EnvKeyRoles), envstore.EnvStoreObj.GetSliceStoreEnvVariable(constants.EnvKeyProtectedRoles)...)...)) { + rolesString, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyRoles) + roles := []string{} + if err != nil { + log.Debug("Error getting roles: ", err) + rolesString = "" + } else { + roles = strings.Split(rolesString, ",") + } + protectedRolesString, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyProtectedRoles) + protectedRoles := []string{} + if err != nil { + log.Debug("Error getting protected roles: ", err) + protectedRolesString = "" + } else { + protectedRoles = strings.Split(protectedRolesString, ",") + } + + if !validators.IsValidRoles(inputRoles, append([]string{}, append(roles, protectedRoles...)...)) { log.Debug("Invalid roles: ", params.Roles) return res, fmt.Errorf("invalid list of roles") } - if !utils.IsStringArrayEqual(inputRoles, currentRoles) { + if !validators.IsStringArrayEqual(inputRoles, currentRoles) { rolesToSave = strings.Join(inputRoles, ",") } - go sessionstore.DeleteAllUserSession(user.ID) + go memorystore.Provider.DeleteAllUserSession(user.ID) } if rolesToSave != "" { diff --git a/server/resolvers/validate_jwt_token.go b/server/resolvers/validate_jwt_token.go index 4733eb2..1efdad5 100644 --- a/server/resolvers/validate_jwt_token.go +++ b/server/resolvers/validate_jwt_token.go @@ -10,7 +10,8 @@ import ( log "github.com/sirupsen/logrus" "github.com/authorizerdev/authorizer/server/graph/model" - "github.com/authorizerdev/authorizer/server/sessionstore" + "github.com/authorizerdev/authorizer/server/memorystore" + "github.com/authorizerdev/authorizer/server/parsers" "github.com/authorizerdev/authorizer/server/token" "github.com/authorizerdev/authorizer/server/utils" ) @@ -38,8 +39,8 @@ func ValidateJwtTokenResolver(ctx context.Context, params model.ValidateJWTToken nonce := "" // access_token and refresh_token should be validated from session store as well if tokenType == "access_token" || tokenType == "refresh_token" { - savedSession := sessionstore.GetState(params.Token) - if savedSession == "" { + savedSession, err := memorystore.Provider.GetState(params.Token) + if savedSession == "" || err != nil { return &model.ValidateJWTTokenResponse{ IsValid: false, }, nil @@ -49,7 +50,7 @@ func ValidateJwtTokenResolver(ctx context.Context, params model.ValidateJWTToken userID = savedSessionSplit[1] } - hostname := utils.GetHost(gc) + hostname := parsers.GetHost(gc) var claimRoles []string var claims jwt.MapClaims diff --git a/server/resolvers/verify_email.go b/server/resolvers/verify_email.go index c19fdcc..3c1d9d6 100644 --- a/server/resolvers/verify_email.go +++ b/server/resolvers/verify_email.go @@ -12,7 +12,8 @@ import ( "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/db/models" "github.com/authorizerdev/authorizer/server/graph/model" - "github.com/authorizerdev/authorizer/server/sessionstore" + "github.com/authorizerdev/authorizer/server/memorystore" + "github.com/authorizerdev/authorizer/server/parsers" "github.com/authorizerdev/authorizer/server/token" "github.com/authorizerdev/authorizer/server/utils" ) @@ -34,7 +35,7 @@ func VerifyEmailResolver(ctx context.Context, params model.VerifyEmailInput) (*m } // verify if token exists in db - hostname := utils.GetHost(gc) + hostname := parsers.GetHost(gc) claim, err := token.ParseJWTToken(params.Token, hostname, verificationRequest.Nonce, verificationRequest.Email) if err != nil { log.Debug("Failed to parse token: ", err) @@ -74,8 +75,8 @@ func VerifyEmailResolver(ctx context.Context, params model.VerifyEmailInput) (*m return res, err } - sessionstore.SetState(authToken.FingerPrintHash, authToken.FingerPrint+"@"+user.ID) - sessionstore.SetState(authToken.AccessToken.Token, authToken.FingerPrint+"@"+user.ID) + memorystore.Provider.SetState(authToken.FingerPrintHash, authToken.FingerPrint+"@"+user.ID) + memorystore.Provider.SetState(authToken.AccessToken.Token, authToken.FingerPrint+"@"+user.ID) cookie.SetSession(gc, authToken.FingerPrintHash) go db.Provider.AddSession(models.Session{ UserID: user.ID, diff --git a/server/sessionstore/in_memory_session.go b/server/sessionstore/in_memory_session.go deleted file mode 100644 index edc6924..0000000 --- a/server/sessionstore/in_memory_session.go +++ /dev/null @@ -1,74 +0,0 @@ -package sessionstore - -import ( - "strings" - "sync" -) - -// InMemoryStore is a simple in-memory store for sessions. -type InMemoryStore struct { - mutex sync.Mutex - sessionStore map[string]map[string]string - stateStore map[string]string -} - -// ClearStore clears the in-memory store. -func (c *InMemoryStore) ClearStore() { - c.mutex.Lock() - defer c.mutex.Unlock() - c.sessionStore = map[string]map[string]string{} -} - -// GetUserSessions returns all the user session token from the in-memory store. -func (c *InMemoryStore) GetUserSessions(userId string) map[string]string { - // c.mutex.Lock() - // defer c.mutex.Unlock() - res := map[string]string{} - for k, v := range c.stateStore { - split := strings.Split(v, "@") - if split[1] == userId { - res[k] = split[0] - } - } - - return res -} - -// DeleteAllUserSession deletes all the user sessions from in-memory store. -func (c *InMemoryStore) DeleteAllUserSession(userId string) { - // c.mutex.Lock() - // defer c.mutex.Unlock() - sessions := GetUserSessions(userId) - for k := range sessions { - RemoveState(k) - } -} - -// SetState sets the state in the in-memory store. -func (c *InMemoryStore) SetState(key, state string) { - c.mutex.Lock() - defer c.mutex.Unlock() - - c.stateStore[key] = state -} - -// GetState gets the state from the in-memory store. -func (c *InMemoryStore) GetState(key string) string { - c.mutex.Lock() - defer c.mutex.Unlock() - - state := "" - if stateVal, ok := c.stateStore[key]; ok { - state = stateVal - } - - return state -} - -// RemoveState removes the state from the in-memory store. -func (c *InMemoryStore) RemoveState(key string) { - c.mutex.Lock() - defer c.mutex.Unlock() - - delete(c.stateStore, key) -} diff --git a/server/sessionstore/redis_client.go b/server/sessionstore/redis_client.go deleted file mode 100644 index e73cd74..0000000 --- a/server/sessionstore/redis_client.go +++ /dev/null @@ -1,18 +0,0 @@ -package sessionstore - -import ( - "context" - "time" - - "github.com/go-redis/redis/v8" -) - -type RedisSessionClient interface { - HMSet(ctx context.Context, key string, values ...interface{}) *redis.BoolCmd - Del(ctx context.Context, keys ...string) *redis.IntCmd - HDel(ctx context.Context, key string, fields ...string) *redis.IntCmd - HMGet(ctx context.Context, key string, fields ...string) *redis.SliceCmd - HGetAll(ctx context.Context, key string) *redis.StringStringMapCmd - Set(ctx context.Context, key string, value interface{}, expiration time.Duration) *redis.StatusCmd - Get(ctx context.Context, key string) *redis.StringCmd -} diff --git a/server/sessionstore/redis_store.go b/server/sessionstore/redis_store.go deleted file mode 100644 index 6ade0fa..0000000 --- a/server/sessionstore/redis_store.go +++ /dev/null @@ -1,79 +0,0 @@ -package sessionstore - -import ( - "context" - "strings" - - log "github.com/sirupsen/logrus" -) - -type RedisStore struct { - ctx context.Context - store RedisSessionClient -} - -// ClearStore clears the redis store for authorizer related tokens -func (c *RedisStore) ClearStore() { - err := c.store.Del(c.ctx, "authorizer_*").Err() - if err != nil { - log.Debug("Error clearing redis store: ", err) - } -} - -// GetUserSessions returns all the user session token from the redis store. -func (c *RedisStore) GetUserSessions(userID string) map[string]string { - data, err := c.store.HGetAll(c.ctx, "*").Result() - if err != nil { - log.Debug("error getting token from redis store: ", err) - } - - res := map[string]string{} - for k, v := range data { - split := strings.Split(v, "@") - if split[1] == userID { - res[k] = split[0] - } - } - - return res -} - -// DeleteAllUserSession deletes all the user session from redis -func (c *RedisStore) DeleteAllUserSession(userId string) { - sessions := GetUserSessions(userId) - for k, v := range sessions { - if k == "token" { - err := c.store.Del(c.ctx, v) - if err != nil { - log.Debug("Error deleting redis token: ", err) - } - } - } -} - -// SetState sets the state in redis store. -func (c *RedisStore) SetState(key, value string) { - err := c.store.Set(c.ctx, "authorizer_"+key, value, 0).Err() - if err != nil { - log.Debug("Error saving redis token: ", err) - } -} - -// GetState gets the state from redis store. -func (c *RedisStore) GetState(key string) string { - state := "" - state, err := c.store.Get(c.ctx, "authorizer_"+key).Result() - if err != nil { - log.Debug("error getting token from redis store: ", err) - } - - return state -} - -// RemoveState removes the state from redis store. -func (c *RedisStore) RemoveState(key string) { - err := c.store.Del(c.ctx, "authorizer_"+key).Err() - if err != nil { - log.Fatalln("Error deleting redis token: ", err) - } -} diff --git a/server/sessionstore/session.go b/server/sessionstore/session.go deleted file mode 100644 index 7626e8f..0000000 --- a/server/sessionstore/session.go +++ /dev/null @@ -1,156 +0,0 @@ -package sessionstore - -import ( - "context" - "strings" - - log "github.com/sirupsen/logrus" - - "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/envstore" - "github.com/go-redis/redis/v8" -) - -// SessionStore is a struct that defines available session stores -// If redis store is available, higher preference is given to that store. -// Else in memory store is used. -type SessionStore struct { - InMemoryStoreObj *InMemoryStore - RedisMemoryStoreObj *RedisStore -} - -// SessionStoreObj is a global variable that holds the -// reference to various session store instances -var SessionStoreObj SessionStore - -// DeleteAllSessions deletes all the sessions from the session store -func DeleteAllUserSession(userId string) { - if SessionStoreObj.RedisMemoryStoreObj != nil { - SessionStoreObj.RedisMemoryStoreObj.DeleteAllUserSession(userId) - } - if SessionStoreObj.InMemoryStoreObj != nil { - SessionStoreObj.InMemoryStoreObj.DeleteAllUserSession(userId) - } -} - -// GetUserSessions returns all the user sessions from the session store -func GetUserSessions(userId string) map[string]string { - if SessionStoreObj.RedisMemoryStoreObj != nil { - return SessionStoreObj.RedisMemoryStoreObj.GetUserSessions(userId) - } - if SessionStoreObj.InMemoryStoreObj != nil { - return SessionStoreObj.InMemoryStoreObj.GetUserSessions(userId) - } - - return nil -} - -// ClearStore clears the session store for authorizer tokens -func ClearStore() { - if SessionStoreObj.RedisMemoryStoreObj != nil { - SessionStoreObj.RedisMemoryStoreObj.ClearStore() - } - if SessionStoreObj.InMemoryStoreObj != nil { - SessionStoreObj.InMemoryStoreObj.ClearStore() - } -} - -// SetState sets the login state (key, value form) in the session store -func SetState(key, state string) { - if SessionStoreObj.RedisMemoryStoreObj != nil { - SessionStoreObj.RedisMemoryStoreObj.SetState(key, state) - } - if SessionStoreObj.InMemoryStoreObj != nil { - SessionStoreObj.InMemoryStoreObj.SetState(key, state) - } -} - -// GetState returns the state from the session store -func GetState(key string) string { - if SessionStoreObj.RedisMemoryStoreObj != nil { - return SessionStoreObj.RedisMemoryStoreObj.GetState(key) - } - if SessionStoreObj.InMemoryStoreObj != nil { - return SessionStoreObj.InMemoryStoreObj.GetState(key) - } - - return "" -} - -// RemoveState removes the social login state from the session store -func RemoveState(key string) { - if SessionStoreObj.RedisMemoryStoreObj != nil { - SessionStoreObj.RedisMemoryStoreObj.RemoveState(key) - } - if SessionStoreObj.InMemoryStoreObj != nil { - SessionStoreObj.InMemoryStoreObj.RemoveState(key) - } -} - -// InitializeSessionStore initializes the SessionStoreObj based on environment variables -func InitSession() error { - if envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyRedisURL) != "" { - log.Info("using redis store to save sessions") - - redisURL := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyRedisURL) - redisURLHostPortsList := strings.Split(redisURL, ",") - - if len(redisURLHostPortsList) > 1 { - opt, err := redis.ParseURL(redisURLHostPortsList[0]) - if err != nil { - log.Debug("error parsing redis url: ", err) - return err - } - urls := []string{opt.Addr} - urlList := redisURLHostPortsList[1:] - urls = append(urls, urlList...) - clusterOpt := &redis.ClusterOptions{Addrs: urls} - - rdb := redis.NewClusterClient(clusterOpt) - ctx := context.Background() - _, err = rdb.Ping(ctx).Result() - if err != nil { - log.Debug("error connecting to redis: ", err) - return err - } - SessionStoreObj.RedisMemoryStoreObj = &RedisStore{ - ctx: ctx, - store: rdb, - } - - // return on successful initialization - return nil - } - - opt, err := redis.ParseURL(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyRedisURL)) - if err != nil { - log.Debug("error parsing redis url: ", err) - return err - } - - rdb := redis.NewClient(opt) - ctx := context.Background() - _, err = rdb.Ping(ctx).Result() - if err != nil { - log.Debug("error connecting to redis: ", err) - return err - } - - SessionStoreObj.RedisMemoryStoreObj = &RedisStore{ - ctx: ctx, - store: rdb, - } - - // return on successful initialization - return nil - } - - log.Info("using in memory store to save sessions") - // if redis url is not set use in memory store - SessionStoreObj.InMemoryStoreObj = &InMemoryStore{ - sessionStore: map[string]map[string]string{}, - stateStore: map[string]string{}, - } - - return nil -} diff --git a/server/test/admin_login_test.go b/server/test/admin_login_test.go index cf949f8..4b8cc85 100644 --- a/server/test/admin_login_test.go +++ b/server/test/admin_login_test.go @@ -4,8 +4,8 @@ import ( "testing" "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/resolvers" "github.com/stretchr/testify/assert" ) @@ -20,8 +20,10 @@ func adminLoginTests(t *testing.T, s TestSetup) { assert.NotNil(t, err) + adminSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret) + assert.Nil(t, err) _, err = resolvers.AdminLoginResolver(ctx, model.AdminLoginInput{ - AdminSecret: envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret), + AdminSecret: adminSecret, }) assert.Nil(t, err) diff --git a/server/test/admin_logout_test.go b/server/test/admin_logout_test.go index 94f65c0..d37b140 100644 --- a/server/test/admin_logout_test.go +++ b/server/test/admin_logout_test.go @@ -6,7 +6,7 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/crypto" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/resolvers" "github.com/stretchr/testify/assert" ) @@ -18,9 +18,12 @@ func adminLogoutTests(t *testing.T, s TestSetup) { _, err := resolvers.AdminLogoutResolver(ctx) assert.NotNil(t, err) - h, err := crypto.EncryptPassword(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret)) + adminSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret) assert.Nil(t, err) - req.Header.Set("Cookie", fmt.Sprintf("%s=%s", envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminCookieName), h)) + + h, err := crypto.EncryptPassword(adminSecret) + assert.Nil(t, err) + req.Header.Set("Cookie", fmt.Sprintf("%s=%s", constants.AdminCookieName, h)) _, err = resolvers.AdminLogoutResolver(ctx) assert.Nil(t, err) diff --git a/server/test/admin_session_test.go b/server/test/admin_session_test.go index 96f5bfc..fb28c96 100644 --- a/server/test/admin_session_test.go +++ b/server/test/admin_session_test.go @@ -6,7 +6,7 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/crypto" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/resolvers" "github.com/stretchr/testify/assert" ) @@ -18,9 +18,12 @@ func adminSessionTests(t *testing.T, s TestSetup) { _, err := resolvers.AdminSessionResolver(ctx) assert.NotNil(t, err) - h, err := crypto.EncryptPassword(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret)) + adminSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret) assert.Nil(t, err) - req.Header.Set("Cookie", fmt.Sprintf("%s=%s", envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminCookieName), h)) + + h, err := crypto.EncryptPassword(adminSecret) + assert.Nil(t, err) + req.Header.Set("Cookie", fmt.Sprintf("%s=%s", constants.AdminCookieName, h)) _, err = resolvers.AdminSessionResolver(ctx) assert.Nil(t, err) diff --git a/server/test/admin_signup_test.go b/server/test/admin_signup_test.go index c478475..6aedc17 100644 --- a/server/test/admin_signup_test.go +++ b/server/test/admin_signup_test.go @@ -4,8 +4,8 @@ import ( "testing" "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/resolvers" "github.com/stretchr/testify/assert" ) @@ -20,7 +20,7 @@ func adminSignupTests(t *testing.T, s TestSetup) { assert.NotNil(t, err) // reset env for test to pass - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyAdminSecret, "") + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyAdminSecret, "") _, err = resolvers.AdminSignupResolver(ctx, model.AdminSignupInput{ AdminSecret: "admin123", diff --git a/server/test/delete_user_test.go b/server/test/delete_user_test.go index 2ecbe35..2e11cf2 100644 --- a/server/test/delete_user_test.go +++ b/server/test/delete_user_test.go @@ -6,8 +6,8 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/crypto" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/resolvers" "github.com/stretchr/testify/assert" ) @@ -27,10 +27,12 @@ func deleteUserTest(t *testing.T, s TestSetup) { Email: email, }) assert.NotNil(t, err, "unauthorized") - - h, err := crypto.EncryptPassword(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret)) + adminSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret) assert.Nil(t, err) - req.Header.Set("Cookie", fmt.Sprintf("%s=%s", envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminCookieName), h)) + + h, err := crypto.EncryptPassword(adminSecret) + assert.Nil(t, err) + req.Header.Set("Cookie", fmt.Sprintf("%s=%s", constants.AdminCookieName, h)) _, err = resolvers.DeleteUserResolver(ctx, model.DeleteUserInput{ Email: email, diff --git a/server/test/enable_access_test.go b/server/test/enable_access_test.go index 6d06153..cc57be3 100644 --- a/server/test/enable_access_test.go +++ b/server/test/enable_access_test.go @@ -7,8 +7,8 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/crypto" "github.com/authorizerdev/authorizer/server/db" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/resolvers" "github.com/stretchr/testify/assert" ) @@ -28,10 +28,11 @@ func enableAccessTest(t *testing.T, s TestSetup) { }) assert.NoError(t, err) assert.NotNil(t, verifyRes.AccessToken) - - h, err := crypto.EncryptPassword(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret)) + adminSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret) assert.Nil(t, err) - req.Header.Set("Cookie", fmt.Sprintf("%s=%s", envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminCookieName), h)) + h, err := crypto.EncryptPassword(adminSecret) + assert.Nil(t, err) + req.Header.Set("Cookie", fmt.Sprintf("%s=%s", constants.AdminCookieName, h)) res, err := resolvers.RevokeAccessResolver(ctx, model.UpdateAccessInput{ UserID: verifyRes.User.ID, diff --git a/server/test/env_file_test.go b/server/test/env_file_test.go index 75cc498..31d5c70 100644 --- a/server/test/env_file_test.go +++ b/server/test/env_file_test.go @@ -1,26 +1,33 @@ package test import ( + "os" "testing" + "github.com/stretchr/testify/assert" + "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/env" - "github.com/authorizerdev/authorizer/server/envstore" - "github.com/stretchr/testify/assert" + "github.com/authorizerdev/authorizer/server/memorystore" ) func TestEnvs(t *testing.T) { - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyEnvPath, "../../.env.sample") - env.InitAllEnv() - store := envstore.EnvStoreObj.GetEnvStoreClone() + err := os.Setenv(constants.EnvKeyEnvPath, "../../.env.test") + assert.Nil(t, err) + err = memorystore.InitRequiredEnv() + assert.Nil(t, err) + err = env.InitAllEnv() + assert.Nil(t, err) + store, err := memorystore.Provider.GetEnvStore() + assert.Nil(t, err) - assert.Equal(t, store.StringEnv[constants.EnvKeyEnv], "production") - assert.False(t, store.BoolEnv[constants.EnvKeyDisableEmailVerification]) - assert.False(t, store.BoolEnv[constants.EnvKeyDisableMagicLinkLogin]) - assert.False(t, store.BoolEnv[constants.EnvKeyDisableBasicAuthentication]) - assert.Equal(t, store.StringEnv[constants.EnvKeyJwtType], "RS256") - assert.Equal(t, store.StringEnv[constants.EnvKeyJwtRoleClaim], "role") - assert.EqualValues(t, store.SliceEnv[constants.EnvKeyRoles], []string{"user"}) - assert.EqualValues(t, store.SliceEnv[constants.EnvKeyDefaultRoles], []string{"user"}) - assert.EqualValues(t, store.SliceEnv[constants.EnvKeyAllowedOrigins], []string{"*"}) + assert.Equal(t, "test", store[constants.EnvKeyEnv].(string)) + assert.False(t, store[constants.EnvKeyDisableEmailVerification].(bool)) + assert.False(t, store[constants.EnvKeyDisableMagicLinkLogin].(bool)) + assert.False(t, store[constants.EnvKeyDisableBasicAuthentication].(bool)) + assert.Equal(t, "RS256", store[constants.EnvKeyJwtType].(string)) + assert.Equal(t, store[constants.EnvKeyJwtRoleClaim].(string), "role") + assert.EqualValues(t, store[constants.EnvKeyRoles].(string), "user") + assert.EqualValues(t, store[constants.EnvKeyDefaultRoles].(string), "user") + assert.EqualValues(t, store[constants.EnvKeyAllowedOrigins].(string), "*") } diff --git a/server/test/env_test.go b/server/test/env_test.go index f825a50..725a834 100644 --- a/server/test/env_test.go +++ b/server/test/env_test.go @@ -6,7 +6,7 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/crypto" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/resolvers" "github.com/stretchr/testify/assert" ) @@ -18,12 +18,14 @@ func envTests(t *testing.T, s TestSetup) { _, err := resolvers.EnvResolver(ctx) assert.NotNil(t, err) - h, err := crypto.EncryptPassword(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret)) + adminSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret) assert.Nil(t, err) - req.Header.Set("Cookie", fmt.Sprintf("%s=%s", envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminCookieName), h)) - res, err := resolvers.EnvResolver(ctx) + h, err := crypto.EncryptPassword(adminSecret) assert.Nil(t, err) - assert.Equal(t, *res.AdminSecret, envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret)) + req.Header.Set("Cookie", fmt.Sprintf("%s=%s", constants.AdminCookieName, h)) + res, err := resolvers.EnvResolver(ctx) + assert.Nil(t, err) + assert.Equal(t, *res.AdminSecret, adminSecret) }) } diff --git a/server/test/generate_jwt_keys_test.go b/server/test/generate_jwt_keys_test.go index b9acb76..e9ef639 100644 --- a/server/test/generate_jwt_keys_test.go +++ b/server/test/generate_jwt_keys_test.go @@ -6,8 +6,8 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/crypto" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/resolvers" "github.com/stretchr/testify/assert" ) @@ -30,9 +30,13 @@ func generateJWTkeyTest(t *testing.T, s TestSetup) { assert.Error(t, err) assert.Nil(t, res) }) - h, err := crypto.EncryptPassword(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret)) + + adminSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret) assert.Nil(t, err) - req.Header.Set("Cookie", fmt.Sprintf("%s=%s", envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminCookieName), h)) + + h, err := crypto.EncryptPassword(adminSecret) + assert.Nil(t, err) + req.Header.Set("Cookie", fmt.Sprintf("%s=%s", constants.AdminCookieName, h)) t.Run(`should generate HS256 secret`, func(t *testing.T) { res, err := resolvers.GenerateJWTKeysResolver(ctx, model.GenerateJWTKeysInput{ Type: "HS256", diff --git a/server/test/invite_member_test.go b/server/test/invite_member_test.go index 76cd389..42bc017 100644 --- a/server/test/invite_member_test.go +++ b/server/test/invite_member_test.go @@ -6,8 +6,8 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/crypto" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/resolvers" "github.com/stretchr/testify/assert" ) @@ -26,9 +26,12 @@ func inviteUserTest(t *testing.T, s TestSetup) { assert.Error(t, err) assert.Nil(t, res) - h, err := crypto.EncryptPassword(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret)) + adminSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret) assert.Nil(t, err) - req.Header.Set("Cookie", fmt.Sprintf("%s=%s", envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminCookieName), h)) + + h, err := crypto.EncryptPassword(adminSecret) + assert.Nil(t, err) + req.Header.Set("Cookie", fmt.Sprintf("%s=%s", constants.AdminCookieName, h)) // invalid emails test invalidEmailsTest := []string{ diff --git a/server/test/jwt_test.go b/server/test/jwt_test.go index 71f74be..9695406 100644 --- a/server/test/jwt_test.go +++ b/server/test/jwt_test.go @@ -6,7 +6,7 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/crypto" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/token" "github.com/golang-jwt/jwt" "github.com/google/uuid" @@ -15,10 +15,14 @@ import ( func TestJwt(t *testing.T) { // persist older data till test is done and then reset it - jwtType := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtType) - publicKey := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtPublicKey) - privateKey := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtPrivateKey) - clientID := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyClientID) + jwtType, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyJwtType) + assert.Nil(t, err) + publicKey, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyJwtPublicKey) + assert.Nil(t, err) + privateKey, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyJwtPrivateKey) + assert.Nil(t, err) + clientID, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyClientID) + assert.Nil(t, err) nonce := uuid.New().String() hostname := "localhost" subject := "test" @@ -33,14 +37,14 @@ func TestJwt(t *testing.T) { } t.Run("invalid jwt type", func(t *testing.T) { - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyJwtType, "invalid") + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyJwtType, "invalid") token, err := token.SignJWTToken(claims) assert.Error(t, err, "unsupported signing method") assert.Empty(t, token) }) t.Run("expired jwt token", func(t *testing.T) { - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyJwtType, "HS256") - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyJwtSecret, "test") + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyJwtType, "HS256") + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyJwtSecret, "test") expiredClaims := jwt.MapClaims{ "exp": time.Now().Add(-time.Minute * 30).Unix(), "iat": time.Now().Unix(), @@ -52,9 +56,9 @@ func TestJwt(t *testing.T) { assert.Error(t, err, err.Error(), "Token is expired") }) t.Run("HMAC algorithms", func(t *testing.T) { - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyJwtSecret, "test") + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyJwtSecret, "test") t.Run("HS256", func(t *testing.T) { - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyJwtType, "HS256") + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyJwtType, "HS256") jwtToken, err := token.SignJWTToken(claims) assert.NoError(t, err) assert.NotEmpty(t, jwtToken) @@ -63,7 +67,7 @@ func TestJwt(t *testing.T) { assert.Equal(t, c["email"].(string), claims["email"]) }) t.Run("HS384", func(t *testing.T) { - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyJwtType, "HS384") + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyJwtType, "HS384") jwtToken, err := token.SignJWTToken(claims) assert.NoError(t, err) assert.NotEmpty(t, jwtToken) @@ -72,7 +76,7 @@ func TestJwt(t *testing.T) { assert.Equal(t, c["email"].(string), claims["email"]) }) t.Run("HS512", func(t *testing.T) { - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyJwtType, "HS512") + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyJwtType, "HS512") jwtToken, err := token.SignJWTToken(claims) assert.NoError(t, err) assert.NotEmpty(t, jwtToken) @@ -86,9 +90,9 @@ func TestJwt(t *testing.T) { t.Run("RS256", func(t *testing.T) { _, privateKey, publickKey, _, err := crypto.NewRSAKey("RS256", clientID) assert.NoError(t, err) - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyJwtType, "RS256") - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyJwtPrivateKey, privateKey) - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyJwtPublicKey, publickKey) + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyJwtType, "RS256") + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyJwtPrivateKey, privateKey) + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyJwtPublicKey, publickKey) jwtToken, err := token.SignJWTToken(claims) assert.NoError(t, err) assert.NotEmpty(t, jwtToken) @@ -99,9 +103,9 @@ func TestJwt(t *testing.T) { t.Run("RS384", func(t *testing.T) { _, privateKey, publickKey, _, err := crypto.NewRSAKey("RS384", clientID) assert.NoError(t, err) - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyJwtType, "RS384") - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyJwtPrivateKey, privateKey) - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyJwtPublicKey, publickKey) + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyJwtType, "RS384") + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyJwtPrivateKey, privateKey) + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyJwtPublicKey, publickKey) jwtToken, err := token.SignJWTToken(claims) assert.NoError(t, err) assert.NotEmpty(t, jwtToken) @@ -112,9 +116,9 @@ func TestJwt(t *testing.T) { t.Run("RS512", func(t *testing.T) { _, privateKey, publickKey, _, err := crypto.NewRSAKey("RS512", clientID) assert.NoError(t, err) - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyJwtType, "RS512") - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyJwtPrivateKey, privateKey) - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyJwtPublicKey, publickKey) + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyJwtType, "RS512") + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyJwtPrivateKey, privateKey) + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyJwtPublicKey, publickKey) jwtToken, err := token.SignJWTToken(claims) assert.NoError(t, err) assert.NotEmpty(t, jwtToken) @@ -128,9 +132,9 @@ func TestJwt(t *testing.T) { t.Run("ES256", func(t *testing.T) { _, privateKey, publickKey, _, err := crypto.NewECDSAKey("ES256", clientID) assert.NoError(t, err) - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyJwtType, "ES256") - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyJwtPrivateKey, privateKey) - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyJwtPublicKey, publickKey) + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyJwtType, "ES256") + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyJwtPrivateKey, privateKey) + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyJwtPublicKey, publickKey) jwtToken, err := token.SignJWTToken(claims) assert.NoError(t, err) assert.NotEmpty(t, jwtToken) @@ -141,9 +145,9 @@ func TestJwt(t *testing.T) { t.Run("ES384", func(t *testing.T) { _, privateKey, publickKey, _, err := crypto.NewECDSAKey("ES384", clientID) assert.NoError(t, err) - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyJwtType, "ES384") - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyJwtPrivateKey, privateKey) - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyJwtPublicKey, publickKey) + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyJwtType, "ES384") + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyJwtPrivateKey, privateKey) + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyJwtPublicKey, publickKey) jwtToken, err := token.SignJWTToken(claims) assert.NoError(t, err) assert.NotEmpty(t, jwtToken) @@ -154,9 +158,9 @@ func TestJwt(t *testing.T) { t.Run("ES512", func(t *testing.T) { _, privateKey, publickKey, _, err := crypto.NewECDSAKey("ES512", clientID) assert.NoError(t, err) - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyJwtType, "ES512") - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyJwtPrivateKey, privateKey) - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyJwtPublicKey, publickKey) + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyJwtType, "ES512") + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyJwtPrivateKey, privateKey) + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyJwtPublicKey, publickKey) jwtToken, err := token.SignJWTToken(claims) assert.NoError(t, err) assert.NotEmpty(t, jwtToken) @@ -166,7 +170,7 @@ func TestJwt(t *testing.T) { }) }) - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyJwtType, jwtType) - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyJwtPublicKey, publicKey) - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyJwtPrivateKey, privateKey) + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyJwtType, jwtType) + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyJwtPublicKey, publicKey) + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyJwtPrivateKey, privateKey) } diff --git a/server/test/login_test.go b/server/test/login_test.go index ebfbf68..b2243e9 100644 --- a/server/test/login_test.go +++ b/server/test/login_test.go @@ -5,7 +5,6 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" "github.com/authorizerdev/authorizer/server/resolvers" "github.com/authorizerdev/authorizer/server/utils" @@ -15,7 +14,6 @@ import ( func loginTests(t *testing.T, s TestSetup) { t.Helper() t.Run(`should login`, func(t *testing.T) { - t.Logf("=> is enabled: %v", envstore.EnvStoreObj.GetBoolStoreEnvVariable(constants.EnvKeyDisableEmailVerification)) _, ctx := createContext(s) email := "login." + s.TestInfo.Email _, err := resolvers.SignupResolver(ctx, model.SignUpInput{ diff --git a/server/test/logout_test.go b/server/test/logout_test.go index 8956b31..3b38e3f 100644 --- a/server/test/logout_test.go +++ b/server/test/logout_test.go @@ -6,10 +6,9 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/resolvers" - "github.com/authorizerdev/authorizer/server/sessionstore" "github.com/stretchr/testify/assert" ) @@ -29,12 +28,12 @@ func logoutTests(t *testing.T, s TestSetup) { }) token := *verifyRes.AccessToken - sessions := sessionstore.GetUserSessions(verifyRes.User.ID) + sessions := memorystore.Provider.GetUserSessions(verifyRes.User.ID) cookie := "" // set all they keys in cookie one of them should be session cookie for key := range sessions { if key != token { - cookie += fmt.Sprintf("%s=%s;", envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName)+"_session", key) + cookie += fmt.Sprintf("%s=%s;", constants.AppCookieName+"_session", key) } } diff --git a/server/test/magic_link_login_test.go b/server/test/magic_link_login_test.go index a38a802..cb6df35 100644 --- a/server/test/magic_link_login_test.go +++ b/server/test/magic_link_login_test.go @@ -6,8 +6,8 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/resolvers" "github.com/stretchr/testify/assert" ) @@ -17,13 +17,13 @@ func magicLinkLoginTests(t *testing.T, s TestSetup) { t.Run(`should login with magic link`, func(t *testing.T) { req, ctx := createContext(s) email := "magic_link_login." + s.TestInfo.Email - envstore.EnvStoreObj.UpdateEnvVariable(constants.BoolStoreIdentifier, constants.EnvKeyDisableSignUp, true) + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyDisableSignUp, true) _, err := resolvers.MagicLinkLoginResolver(ctx, model.MagicLinkLoginInput{ Email: email, }) assert.NotNil(t, err, "signup disabled") - envstore.EnvStoreObj.UpdateEnvVariable(constants.BoolStoreIdentifier, constants.EnvKeyDisableSignUp, false) + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyDisableSignUp, false) _, err = resolvers.MagicLinkLoginResolver(ctx, model.MagicLinkLoginInput{ Email: email, }) diff --git a/server/test/resolvers_test.go b/server/test/resolvers_test.go index 513c1b0..55e7f73 100644 --- a/server/test/resolvers_test.go +++ b/server/test/resolvers_test.go @@ -6,7 +6,7 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/env" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" ) func TestResolvers(t *testing.T) { @@ -19,9 +19,10 @@ func TestResolvers(t *testing.T) { for dbType, dbURL := range databases { s := testSetup() - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyDatabaseURL, dbURL) - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyDatabaseType, dbType) defer s.Server.Close() + + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyDatabaseURL, dbURL) + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyDatabaseType, dbType) err := db.InitDB() if err != nil { t.Errorf("Error initializing database: %s", err) @@ -35,8 +36,8 @@ func TestResolvers(t *testing.T) { } env.PersistEnv() - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyEnv, "test") - envstore.EnvStoreObj.UpdateEnvVariable(constants.BoolStoreIdentifier, constants.EnvKeyIsProd, false) + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyEnv, "test") + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyIsProd, false) t.Run("should pass tests for "+dbType, func(t *testing.T) { // admin tests adminSignupTests(t, s) diff --git a/server/test/revoke_access_test.go b/server/test/revoke_access_test.go index 5317721..3018c07 100644 --- a/server/test/revoke_access_test.go +++ b/server/test/revoke_access_test.go @@ -7,8 +7,8 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/crypto" "github.com/authorizerdev/authorizer/server/db" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/resolvers" "github.com/stretchr/testify/assert" ) @@ -34,9 +34,12 @@ func revokeAccessTest(t *testing.T, s TestSetup) { }) assert.Error(t, err) - h, err := crypto.EncryptPassword(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret)) + adminSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret) assert.Nil(t, err) - req.Header.Set("Cookie", fmt.Sprintf("%s=%s", envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminCookieName), h)) + + h, err := crypto.EncryptPassword(adminSecret) + assert.Nil(t, err) + req.Header.Set("Cookie", fmt.Sprintf("%s=%s", constants.AdminCookieName, h)) res, err = resolvers.RevokeAccessResolver(ctx, model.UpdateAccessInput{ UserID: verifyRes.User.ID, diff --git a/server/test/session_test.go b/server/test/session_test.go index 65ce57e..2c4de19 100644 --- a/server/test/session_test.go +++ b/server/test/session_test.go @@ -7,10 +7,9 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/resolvers" - "github.com/authorizerdev/authorizer/server/sessionstore" "github.com/stretchr/testify/assert" ) @@ -34,13 +33,13 @@ func sessionTests(t *testing.T, s TestSetup) { Token: verificationRequest.Token, }) - sessions := sessionstore.GetUserSessions(verifyRes.User.ID) + sessions := memorystore.Provider.GetUserSessions(verifyRes.User.ID) cookie := "" token := *verifyRes.AccessToken // set all they keys in cookie one of them should be session cookie for key := range sessions { if key != token { - cookie += fmt.Sprintf("%s=%s;", envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName)+"_session", key) + cookie += fmt.Sprintf("%s=%s;", constants.AppCookieName+"_session", key) } } cookie = strings.TrimSuffix(cookie, ";") diff --git a/server/test/signup_test.go b/server/test/signup_test.go index 4c40da3..1e34429 100644 --- a/server/test/signup_test.go +++ b/server/test/signup_test.go @@ -5,8 +5,8 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/resolvers" "github.com/stretchr/testify/assert" ) @@ -30,7 +30,7 @@ func signupTests(t *testing.T, s TestSetup) { }) assert.NotNil(t, err, "invalid password") - envstore.EnvStoreObj.UpdateEnvVariable(constants.BoolStoreIdentifier, constants.EnvKeyDisableSignUp, true) + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyDisableSignUp, true) res, err = resolvers.SignupResolver(ctx, model.SignUpInput{ Email: email, Password: s.TestInfo.Password, @@ -38,7 +38,7 @@ func signupTests(t *testing.T, s TestSetup) { }) assert.NotNil(t, err, "singup disabled") - envstore.EnvStoreObj.UpdateEnvVariable(constants.BoolStoreIdentifier, constants.EnvKeyDisableSignUp, false) + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyDisableSignUp, false) res, err = resolvers.SignupResolver(ctx, model.SignUpInput{ Email: email, Password: s.TestInfo.Password, diff --git a/server/test/test.go b/server/test/test.go index c4cb14f..0160ca0 100644 --- a/server/test/test.go +++ b/server/test/test.go @@ -5,15 +5,17 @@ import ( "fmt" "net/http" "net/http/httptest" + "os" "time" + log "github.com/sirupsen/logrus" + "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/env" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/handlers" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/middlewares" - "github.com/authorizerdev/authorizer/server/sessionstore" "github.com/gin-gonic/gin" ) @@ -76,17 +78,35 @@ func testSetup() TestSetup { Password: "Test@123", } - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyEnvPath, "../../.env.sample") - env.InitRequiredEnv() - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeySmtpHost, "smtp.yopmail.com") - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeySmtpPort, "2525") - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeySmtpUsername, "lakhan@yopmail.com") - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeySmtpPassword, "test") - envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeySenderEmail, "info@yopmail.com") - envstore.EnvStoreObj.UpdateEnvVariable(constants.SliceStoreIdentifier, constants.EnvKeyProtectedRoles, []string{"admin"}) - db.InitDB() - env.InitAllEnv() - sessionstore.InitSession() + err := os.Setenv(constants.EnvKeyEnvPath, "../../.env.test") + if err != nil { + log.Fatal("Error loading .env.sample file") + } + err = memorystore.InitRequiredEnv() + if err != nil { + log.Fatal("Error loading required env: ", err) + } + + err = memorystore.InitMemStore() + if err != nil { + log.Fatal("Error loading memory store: ", err) + } + memorystore.Provider.UpdateEnvVariable(constants.EnvKeySmtpHost, "smtp.yopmail.com") + memorystore.Provider.UpdateEnvVariable(constants.EnvKeySmtpPort, "2525") + memorystore.Provider.UpdateEnvVariable(constants.EnvKeySmtpUsername, "lakhan@yopmail.com") + memorystore.Provider.UpdateEnvVariable(constants.EnvKeySmtpPassword, "test") + memorystore.Provider.UpdateEnvVariable(constants.EnvKeySenderEmail, "info@yopmail.com") + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyProtectedRoles, "admin") + + err = db.InitDB() + if err != nil { + log.Fatal("Error loading db: ", err) + } + + err = env.InitAllEnv() + if err != nil { + log.Fatal("Error loading env: ", err) + } w := httptest.NewRecorder() c, r := gin.CreateTestContext(w) diff --git a/server/test/update_env_test.go b/server/test/update_env_test.go index 527becc..c602b4a 100644 --- a/server/test/update_env_test.go +++ b/server/test/update_env_test.go @@ -2,12 +2,13 @@ package test import ( "fmt" + "strings" "testing" "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/crypto" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/resolvers" "github.com/stretchr/testify/assert" ) @@ -16,16 +17,19 @@ func updateEnvTests(t *testing.T, s TestSetup) { t.Helper() t.Run(`should update envs`, func(t *testing.T) { req, ctx := createContext(s) - originalAppURL := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAppURL) + originalAppURL, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAppURL) + assert.Nil(t, err) data := model.UpdateEnvInput{} - _, err := resolvers.UpdateEnvResolver(ctx, data) + _, err = resolvers.UpdateEnvResolver(ctx, data) assert.NotNil(t, err) - h, err := crypto.EncryptPassword(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret)) + adminSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret) assert.Nil(t, err) - req.Header.Set("Cookie", fmt.Sprintf("%s=%s", envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminCookieName), h)) + h, err := crypto.EncryptPassword(adminSecret) + assert.Nil(t, err) + req.Header.Set("Cookie", fmt.Sprintf("%s=%s", constants.AdminCookieName, h)) newURL := "https://test.com" disableLoginPage := true allowedOrigins := []string{"http://localhost:8080"} @@ -35,11 +39,20 @@ func updateEnvTests(t *testing.T, s TestSetup) { AllowedOrigins: allowedOrigins, } _, err = resolvers.UpdateEnvResolver(ctx, data) - assert.Nil(t, err) - assert.Equal(t, envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAppURL), newURL) - assert.True(t, envstore.EnvStoreObj.GetBoolStoreEnvVariable(constants.EnvKeyDisableLoginPage)) - assert.Equal(t, envstore.EnvStoreObj.GetSliceStoreEnvVariable(constants.EnvKeyAllowedOrigins), allowedOrigins) + + appURL, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAppURL) + assert.Nil(t, err) + assert.Equal(t, appURL, newURL) + + isLoginPageDisabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyDisableLoginPage) + assert.Nil(t, err) + assert.True(t, isLoginPageDisabled) + + storedOriginsStrings, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAllowedOrigins) + assert.Nil(t, err) + storedOrigins := strings.Split(storedOriginsStrings, ",") + assert.Equal(t, storedOrigins, allowedOrigins) disableLoginPage = false data = model.UpdateEnvInput{ diff --git a/server/test/update_user_test.go b/server/test/update_user_test.go index fd76653..ca197f7 100644 --- a/server/test/update_user_test.go +++ b/server/test/update_user_test.go @@ -6,8 +6,8 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/crypto" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/resolvers" "github.com/stretchr/testify/assert" ) @@ -33,10 +33,11 @@ func updateUserTest(t *testing.T, s TestSetup) { Roles: newRoles, }) assert.NotNil(t, err, "unauthorized") - - h, err := crypto.EncryptPassword(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret)) + adminSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret) assert.Nil(t, err) - req.Header.Set("Cookie", fmt.Sprintf("%s=%s", envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminCookieName), h)) + h, err := crypto.EncryptPassword(adminSecret) + assert.Nil(t, err) + req.Header.Set("Cookie", fmt.Sprintf("%s=%s", constants.AdminCookieName, h)) _, err = resolvers.UpdateUserResolver(ctx, model.UpdateUserInput{ ID: user.ID, Roles: newRoles, @@ -44,7 +45,7 @@ func updateUserTest(t *testing.T, s TestSetup) { // supplier is not part of envs assert.Error(t, err) adminRole = "admin" - envstore.EnvStoreObj.UpdateEnvVariable(constants.SliceStoreIdentifier, constants.EnvKeyProtectedRoles, []string{adminRole}) + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyProtectedRoles, adminRole) newRoles = []*string{&adminRole, &userRole} _, err = resolvers.UpdateUserResolver(ctx, model.UpdateUserInput{ ID: user.ID, diff --git a/server/test/urls_test.go b/server/test/urls_test.go index 2f2fbcd..3ec2a53 100644 --- a/server/test/urls_test.go +++ b/server/test/urls_test.go @@ -3,14 +3,14 @@ package test import ( "testing" - "github.com/authorizerdev/authorizer/server/utils" + "github.com/authorizerdev/authorizer/server/parsers" "github.com/stretchr/testify/assert" ) func TestGetHostName(t *testing.T) { url := "http://test.herokuapp.com:80" - host, port := utils.GetHostParts(url) + host, port := parsers.GetHostParts(url) expectedHost := "test.herokuapp.com" assert.Equal(t, host, expectedHost, "hostname should be equal") @@ -20,7 +20,7 @@ func TestGetHostName(t *testing.T) { func TestGetDomainName(t *testing.T) { url := "http://test.herokuapp.com" - got := utils.GetDomainName(url) + got := parsers.GetDomainName(url) want := "herokuapp.com" assert.Equal(t, got, want, "domain name should be equal") diff --git a/server/test/users_test.go b/server/test/users_test.go index f390ed0..96e6537 100644 --- a/server/test/users_test.go +++ b/server/test/users_test.go @@ -6,8 +6,8 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/crypto" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/resolvers" "github.com/stretchr/testify/assert" ) @@ -35,9 +35,11 @@ func usersTest(t *testing.T, s TestSetup) { usersRes, err := resolvers.UsersResolver(ctx, pagination) assert.NotNil(t, err, "unauthorized") - h, err := crypto.EncryptPassword(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret)) + adminSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret) assert.Nil(t, err) - req.Header.Set("Cookie", fmt.Sprintf("%s=%s", envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminCookieName), h)) + h, err := crypto.EncryptPassword(adminSecret) + assert.Nil(t, err) + req.Header.Set("Cookie", fmt.Sprintf("%s=%s", constants.AdminCookieName, h)) usersRes, err = resolvers.UsersResolver(ctx, pagination) assert.Nil(t, err) diff --git a/server/test/validate_jwt_token_test.go b/server/test/validate_jwt_token_test.go index 5bb4268..4207ebc 100644 --- a/server/test/validate_jwt_token_test.go +++ b/server/test/validate_jwt_token_test.go @@ -6,8 +6,8 @@ import ( "github.com/authorizerdev/authorizer/server/db/models" "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/resolvers" - "github.com/authorizerdev/authorizer/server/sessionstore" "github.com/authorizerdev/authorizer/server/token" "github.com/authorizerdev/authorizer/server/utils" "github.com/google/uuid" @@ -48,8 +48,8 @@ func validateJwtTokenTest(t *testing.T, s TestSetup) { gc, err := utils.GinContextFromContext(ctx) assert.NoError(t, err) authToken, err := token.CreateAuthToken(gc, user, roles, scope) - sessionstore.SetState(authToken.AccessToken.Token, authToken.FingerPrint+"@"+user.ID) - sessionstore.SetState(authToken.RefreshToken.Token, authToken.FingerPrint+"@"+user.ID) + memorystore.Provider.SetState(authToken.AccessToken.Token, authToken.FingerPrint+"@"+user.ID) + memorystore.Provider.SetState(authToken.RefreshToken.Token, authToken.FingerPrint+"@"+user.ID) t.Run(`should validate the access token`, func(t *testing.T) { res, err := resolvers.ValidateJwtTokenResolver(ctx, model.ValidateJWTTokenInput{ diff --git a/server/test/validator_test.go b/server/test/validator_test.go index 9c4b51a..3509c4c 100644 --- a/server/test/validator_test.go +++ b/server/test/validator_test.go @@ -4,8 +4,8 @@ import ( "testing" "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/envstore" - "github.com/authorizerdev/authorizer/server/utils" + "github.com/authorizerdev/authorizer/server/memorystore" + "github.com/authorizerdev/authorizer/server/validators" "github.com/stretchr/testify/assert" ) @@ -14,38 +14,38 @@ func TestIsValidEmail(t *testing.T) { invalidEmail1 := "lakhan" invalidEmail2 := "lakhan.me" - assert.True(t, utils.IsValidEmail(validEmail), "it should be valid email") - assert.False(t, utils.IsValidEmail(invalidEmail1), "it should be invalid email") - assert.False(t, utils.IsValidEmail(invalidEmail2), "it should be invalid email") + assert.True(t, validators.IsValidEmail(validEmail), "it should be valid email") + assert.False(t, validators.IsValidEmail(invalidEmail1), "it should be invalid email") + assert.False(t, validators.IsValidEmail(invalidEmail2), "it should be invalid email") } func TestIsValidOrigin(t *testing.T) { // don't use portocal(http/https) for ALLOWED_ORIGINS while testing, // as we trim them off while running the main function - envstore.EnvStoreObj.UpdateEnvVariable(constants.SliceStoreIdentifier, constants.EnvKeyAllowedOrigins, []string{"localhost:8080", "*.google.com", "*.google.in", "*abc.*"}) - assert.False(t, utils.IsValidOrigin("http://myapp.com"), "it should be invalid origin") - assert.False(t, utils.IsValidOrigin("http://appgoogle.com"), "it should be invalid origin") - assert.True(t, utils.IsValidOrigin("http://app.google.com"), "it should be valid origin") - assert.False(t, utils.IsValidOrigin("http://app.google.ind"), "it should be invalid origin") - assert.True(t, utils.IsValidOrigin("http://app.google.in"), "it should be valid origin") - assert.True(t, utils.IsValidOrigin("http://xyx.abc.com"), "it should be valid origin") - assert.True(t, utils.IsValidOrigin("http://xyx.abc.in"), "it should be valid origin") - assert.True(t, utils.IsValidOrigin("http://xyxabc.in"), "it should be valid origin") - assert.True(t, utils.IsValidOrigin("http://localhost:8080"), "it should be valid origin") - envstore.EnvStoreObj.UpdateEnvVariable(constants.SliceStoreIdentifier, constants.EnvKeyAllowedOrigins, []string{"*"}) + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyAllowedOrigins, "localhost:8080,*.google.com,*.google.in,*abc.*") + assert.False(t, validators.IsValidOrigin("http://myapp.com"), "it should be invalid origin") + assert.False(t, validators.IsValidOrigin("http://appgoogle.com"), "it should be invalid origin") + assert.True(t, validators.IsValidOrigin("http://app.google.com"), "it should be valid origin") + assert.False(t, validators.IsValidOrigin("http://app.google.ind"), "it should be invalid origin") + assert.True(t, validators.IsValidOrigin("http://app.google.in"), "it should be valid origin") + assert.True(t, validators.IsValidOrigin("http://xyx.abc.com"), "it should be valid origin") + assert.True(t, validators.IsValidOrigin("http://xyx.abc.in"), "it should be valid origin") + assert.True(t, validators.IsValidOrigin("http://xyxabc.in"), "it should be valid origin") + assert.True(t, validators.IsValidOrigin("http://localhost:8080"), "it should be valid origin") + memorystore.Provider.UpdateEnvVariable(constants.EnvKeyAllowedOrigins, "*") } func TestIsValidIdentifier(t *testing.T) { - assert.False(t, utils.IsValidVerificationIdentifier("test"), "it should be invalid identifier") - assert.True(t, utils.IsValidVerificationIdentifier(constants.VerificationTypeBasicAuthSignup), "it should be valid identifier") - assert.True(t, utils.IsValidVerificationIdentifier(constants.VerificationTypeUpdateEmail), "it should be valid identifier") - assert.True(t, utils.IsValidVerificationIdentifier(constants.VerificationTypeForgotPassword), "it should be valid identifier") + assert.False(t, validators.IsValidVerificationIdentifier("test"), "it should be invalid identifier") + assert.True(t, validators.IsValidVerificationIdentifier(constants.VerificationTypeBasicAuthSignup), "it should be valid identifier") + assert.True(t, validators.IsValidVerificationIdentifier(constants.VerificationTypeUpdateEmail), "it should be valid identifier") + assert.True(t, validators.IsValidVerificationIdentifier(constants.VerificationTypeForgotPassword), "it should be valid identifier") } func TestIsValidPassword(t *testing.T) { - assert.False(t, utils.IsValidPassword("test"), "it should be invalid password") - assert.False(t, utils.IsValidPassword("Te@1"), "it should be invalid password") - assert.False(t, utils.IsValidPassword("n*rp7GGTd29V{xx%{pDb@7n{](SD.!+.Mp#*$EHDGk&$pAMf7e#432Sg,Gr](j3n]jV/3F8BJJT+9u9{q=8zK:8u!rpQBaXJp%A+7r!jQj)M(vC$UX,h;;WKm$U6i#7dBnC&2ryKzKd+(y&=Ud)hErT/j;v3t..CM).8nS)9qLtV7pmP;@2QuzDyGfL7KB()k:BpjAGL@bxD%r5gcBfh7$&wutk!wzMfPFY#nkjjqyZbEHku,{jc;gvbYq2)3w=KExnYz9Vbv:;*;?f##faxkULdMpmm&yEfePixzx+[{[38zGN;3TzF;6M#Xy_tMtx:yK*n$bc(bPyGz%EYkC&]ttUF@#aZ%$QZ:u!icF@+"), "it should be invalid password") - assert.False(t, utils.IsValidPassword("test@123"), "it should be invalid password") - assert.True(t, utils.IsValidPassword("Test@123"), "it should be valid password") + assert.False(t, validators.IsValidPassword("test"), "it should be invalid password") + assert.False(t, validators.IsValidPassword("Te@1"), "it should be invalid password") + assert.False(t, validators.IsValidPassword("n*rp7GGTd29V{xx%{pDb@7n{](SD.!+.Mp#*$EHDGk&$pAMf7e#432Sg,Gr](j3n]jV/3F8BJJT+9u9{q=8zK:8u!rpQBaXJp%A+7r!jQj)M(vC$UX,h;;WKm$U6i#7dBnC&2ryKzKd+(y&=Ud)hErT/j;v3t..CM).8nS)9qLtV7pmP;@2QuzDyGfL7KB()k:BpjAGL@bxD%r5gcBfh7$&wutk!wzMfPFY#nkjjqyZbEHku,{jc;gvbYq2)3w=KExnYz9Vbv:;*;?f##faxkULdMpmm&yEfePixzx+[{[38zGN;3TzF;6M#Xy_tMtx:yK*n$bc(bPyGz%EYkC&]ttUF@#aZ%$QZ:u!icF@+"), "it should be invalid password") + assert.False(t, validators.IsValidPassword("test@123"), "it should be invalid password") + assert.True(t, validators.IsValidPassword("Test@123"), "it should be valid password") } diff --git a/server/test/verification_requests_test.go b/server/test/verification_requests_test.go index b81a35f..8cbb762 100644 --- a/server/test/verification_requests_test.go +++ b/server/test/verification_requests_test.go @@ -6,8 +6,8 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/crypto" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/resolvers" "github.com/stretchr/testify/assert" ) @@ -39,10 +39,12 @@ func verificationRequestsTest(t *testing.T, s TestSetup) { requests, err := resolvers.VerificationRequestsResolver(ctx, pagination) assert.NotNil(t, err, "unauthorized") - - h, err := crypto.EncryptPassword(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret)) + adminSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret) assert.Nil(t, err) - req.Header.Set("Cookie", fmt.Sprintf("%s=%s", envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminCookieName), h)) + + h, err := crypto.EncryptPassword(adminSecret) + assert.Nil(t, err) + req.Header.Set("Cookie", fmt.Sprintf("%s=%s", constants.AdminCookieName, h)) requests, err = resolvers.VerificationRequestsResolver(ctx, pagination) assert.Nil(t, err) diff --git a/server/token/admin_token.go b/server/token/admin_token.go index 1cdfe50..9dcbeb3 100644 --- a/server/token/admin_token.go +++ b/server/token/admin_token.go @@ -6,14 +6,18 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/cookie" "github.com/authorizerdev/authorizer/server/crypto" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/gin-gonic/gin" "golang.org/x/crypto/bcrypt" ) // CreateAdminAuthToken creates the admin token based on secret key func CreateAdminAuthToken(tokenType string, c *gin.Context) (string, error) { - return crypto.EncryptPassword(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret)) + adminSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret) + if err != nil { + return "", err + } + return crypto.EncryptPassword(adminSecret) } // GetAdminAuthToken helps in getting the admin token from the request cookie @@ -23,7 +27,11 @@ func GetAdminAuthToken(gc *gin.Context) (string, error) { return "", fmt.Errorf("unauthorized") } - err = bcrypt.CompareHashAndPassword([]byte(token), []byte(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret))) + adminSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret) + if err != nil { + return "", err + } + err = bcrypt.CompareHashAndPassword([]byte(token), []byte(adminSecret)) if err != nil { return "", fmt.Errorf(`unauthorized`) @@ -40,8 +48,11 @@ func IsSuperAdmin(gc *gin.Context) bool { if secret == "" { return false } - - return secret == envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret) + adminSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret) + if err != nil { + return false + } + return secret == adminSecret } return token != "" diff --git a/server/token/auth_token.go b/server/token/auth_token.go index 6f8930f..65cb0d1 100644 --- a/server/token/auth_token.go +++ b/server/token/auth_token.go @@ -16,8 +16,8 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/crypto" "github.com/authorizerdev/authorizer/server/db/models" - "github.com/authorizerdev/authorizer/server/envstore" - "github.com/authorizerdev/authorizer/server/sessionstore" + "github.com/authorizerdev/authorizer/server/memorystore" + "github.com/authorizerdev/authorizer/server/parsers" "github.com/authorizerdev/authorizer/server/utils" ) @@ -67,7 +67,7 @@ func CreateSessionToken(user models.User, nonce string, roles, scope []string) ( // CreateAuthToken creates a new auth token when userlogs in func CreateAuthToken(gc *gin.Context, user models.User, roles, scope []string) (*Token, error) { - hostname := utils.GetHost(gc) + hostname := parsers.GetHost(gc) nonce := uuid.New().String() _, fingerPrintHash, err := CreateSessionToken(user, nonce, roles, scope) if err != nil { @@ -107,9 +107,13 @@ func CreateRefreshToken(user models.User, roles, scopes []string, hostname, nonc // expires in 1 year expiryBound := time.Hour * 8760 expiresAt := time.Now().Add(expiryBound).Unix() + clientID, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyClientID) + if err != nil { + return "", 0, err + } customClaims := jwt.MapClaims{ "iss": hostname, - "aud": envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyClientID), + "aud": clientID, "sub": user.ID, "exp": expiresAt, "iat": time.Now().Unix(), @@ -130,16 +134,24 @@ func CreateRefreshToken(user models.User, roles, scopes []string, hostname, nonc // CreateAccessToken util to create JWT token, based on // user information, roles config and CUSTOM_ACCESS_TOKEN_SCRIPT func CreateAccessToken(user models.User, roles, scopes []string, hostName, nonce string) (string, int64, error) { - expiryBound, err := utils.ParseDurationInSeconds(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAccessTokenExpiryTime)) + expireTime, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAccessTokenExpiryTime) + if err != nil { + return "", 0, err + } + expiryBound, err := utils.ParseDurationInSeconds(expireTime) if err != nil { expiryBound = time.Minute * 30 } expiresAt := time.Now().Add(expiryBound).Unix() + clientID, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyClientID) + if err != nil { + return "", 0, err + } customClaims := jwt.MapClaims{ "iss": hostName, - "aud": envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyClientID), + "aud": clientID, "nonce": nonce, "sub": user.ID, "exp": expiresAt, @@ -180,14 +192,14 @@ func GetAccessToken(gc *gin.Context) (string, error) { // Function to validate access token for authorizer apis (profile, update_profile) func ValidateAccessToken(gc *gin.Context, accessToken string) (map[string]interface{}, error) { - var res map[string]interface{} + res := make(map[string]interface{}) if accessToken == "" { return res, fmt.Errorf(`unauthorized`) } - savedSession := sessionstore.GetState(accessToken) - if savedSession == "" { + savedSession, err := memorystore.Provider.GetState(accessToken) + if savedSession == "" || err != nil { return res, fmt.Errorf(`unauthorized`) } @@ -195,8 +207,8 @@ func ValidateAccessToken(gc *gin.Context, accessToken string) (map[string]interf nonce := savedSessionSplit[0] userID := savedSessionSplit[1] - hostname := utils.GetHost(gc) - res, err := ParseJWTToken(accessToken, hostname, nonce, userID) + hostname := parsers.GetHost(gc) + res, err = ParseJWTToken(accessToken, hostname, nonce, userID) if err != nil { return res, err } @@ -210,14 +222,14 @@ func ValidateAccessToken(gc *gin.Context, accessToken string) (map[string]interf // Function to validate refreshToken func ValidateRefreshToken(gc *gin.Context, refreshToken string) (map[string]interface{}, error) { - var res map[string]interface{} + res := make(map[string]interface{}) if refreshToken == "" { return res, fmt.Errorf(`unauthorized`) } - savedSession := sessionstore.GetState(refreshToken) - if savedSession == "" { + savedSession, err := memorystore.Provider.GetState(refreshToken) + if savedSession == "" || err != nil { return res, fmt.Errorf(`unauthorized`) } @@ -225,8 +237,8 @@ func ValidateRefreshToken(gc *gin.Context, refreshToken string) (map[string]inte nonce := savedSessionSplit[0] userID := savedSessionSplit[1] - hostname := utils.GetHost(gc) - res, err := ParseJWTToken(refreshToken, hostname, nonce, userID) + hostname := parsers.GetHost(gc) + res, err = ParseJWTToken(refreshToken, hostname, nonce, userID) if err != nil { return res, err } @@ -243,8 +255,8 @@ func ValidateBrowserSession(gc *gin.Context, encryptedSession string) (*SessionD return nil, fmt.Errorf(`unauthorized`) } - savedSession := sessionstore.GetState(encryptedSession) - if savedSession == "" { + savedSession, err := memorystore.Provider.GetState(encryptedSession) + if savedSession == "" || err != nil { return nil, fmt.Errorf(`unauthorized`) } @@ -286,7 +298,11 @@ func ValidateBrowserSession(gc *gin.Context, encryptedSession string) (*SessionD // CreateIDToken util to create JWT token, based on // user information, roles config and CUSTOM_ACCESS_TOKEN_SCRIPT func CreateIDToken(user models.User, roles []string, hostname, nonce string) (string, int64, error) { - expiryBound, err := utils.ParseDurationInSeconds(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAccessTokenExpiryTime)) + expireTime, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAccessTokenExpiryTime) + if err != nil { + return "", 0, err + } + expiryBound, err := utils.ParseDurationInSeconds(expireTime) if err != nil { expiryBound = time.Minute * 30 } @@ -298,10 +314,18 @@ func CreateIDToken(user models.User, roles []string, hostname, nonce string) (st var userMap map[string]interface{} json.Unmarshal(userBytes, &userMap) - claimKey := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtRoleClaim) + claimKey, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyJwtRoleClaim) + if err != nil { + claimKey = "roles" + } + + clientID, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyClientID) + if err != nil { + return "", 0, err + } customClaims := jwt.MapClaims{ "iss": hostname, - "aud": envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyClientID), + "aud": clientID, "nonce": nonce, "sub": user.ID, "exp": expiresAt, @@ -318,7 +342,11 @@ func CreateIDToken(user models.User, roles []string, hostname, nonce string) (st } // check for the extra access token script - accessTokenScript := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCustomAccessTokenScript) + accessTokenScript, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyCustomAccessTokenScript) + if err != nil { + log.Debug("Failed to get custom access token script: ", err) + accessTokenScript = "" + } if accessTokenScript != "" { vm := otto.New() diff --git a/server/token/jwt.go b/server/token/jwt.go index 0b87c09..350c2f5 100644 --- a/server/token/jwt.go +++ b/server/token/jwt.go @@ -5,13 +5,16 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/crypto" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/golang-jwt/jwt" ) // SignJWTToken common util to sing jwt token func SignJWTToken(claims jwt.MapClaims) (string, error) { - jwtType := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtType) + jwtType, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyJwtType) + if err != nil { + return "", err + } signingMethod := jwt.GetSigningMethod(jwtType) if signingMethod == nil { return "", errors.New("unsupported signing method") @@ -24,15 +27,27 @@ func SignJWTToken(claims jwt.MapClaims) (string, error) { switch signingMethod { case jwt.SigningMethodHS256, jwt.SigningMethodHS384, jwt.SigningMethodHS512: - return t.SignedString([]byte(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtSecret))) + jwtSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyJwtSecret) + if err != nil { + return "", err + } + return t.SignedString([]byte(jwtSecret)) case jwt.SigningMethodRS256, jwt.SigningMethodRS384, jwt.SigningMethodRS512: - key, err := crypto.ParseRsaPrivateKeyFromPemStr(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtPrivateKey)) + jwtPrivateKey, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyJwtPrivateKey) + if err != nil { + return "", err + } + key, err := crypto.ParseRsaPrivateKeyFromPemStr(jwtPrivateKey) if err != nil { return "", err } return t.SignedString(key) case jwt.SigningMethodES256, jwt.SigningMethodES384, jwt.SigningMethodES512: - key, err := crypto.ParseEcdsaPrivateKeyFromPemStr(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtPrivateKey)) + jwtPrivateKey, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyJwtPrivateKey) + if err != nil { + return "", err + } + key, err := crypto.ParseEcdsaPrivateKeyFromPemStr(jwtPrivateKey) if err != nil { return "", err } @@ -45,20 +60,30 @@ func SignJWTToken(claims jwt.MapClaims) (string, error) { // ParseJWTToken common util to parse jwt token func ParseJWTToken(token, hostname, nonce, subject string) (jwt.MapClaims, error) { - jwtType := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtType) + jwtType, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyJwtType) + if err != nil { + return nil, err + } signingMethod := jwt.GetSigningMethod(jwtType) - var err error var claims jwt.MapClaims switch signingMethod { case jwt.SigningMethodHS256, jwt.SigningMethodHS384, jwt.SigningMethodHS512: _, err = jwt.ParseWithClaims(token, &claims, func(token *jwt.Token) (interface{}, error) { - return []byte(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtSecret)), nil + jwtSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyJwtSecret) + if err != nil { + return nil, err + } + return []byte(jwtSecret), nil }) case jwt.SigningMethodRS256, jwt.SigningMethodRS384, jwt.SigningMethodRS512: _, err = jwt.ParseWithClaims(token, &claims, func(token *jwt.Token) (interface{}, error) { - key, err := crypto.ParseRsaPublicKeyFromPemStr(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtPublicKey)) + jwtPublicKey, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyJwtPublicKey) + if err != nil { + return nil, err + } + key, err := crypto.ParseRsaPublicKeyFromPemStr(jwtPublicKey) if err != nil { return nil, err } @@ -66,7 +91,11 @@ func ParseJWTToken(token, hostname, nonce, subject string) (jwt.MapClaims, error }) case jwt.SigningMethodES256, jwt.SigningMethodES384, jwt.SigningMethodES512: _, err = jwt.ParseWithClaims(token, &claims, func(token *jwt.Token) (interface{}, error) { - key, err := crypto.ParseEcdsaPublicKeyFromPemStr(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtPublicKey)) + jwtPublicKey, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyJwtPublicKey) + if err != nil { + return nil, err + } + key, err := crypto.ParseEcdsaPublicKeyFromPemStr(jwtPublicKey) if err != nil { return nil, err } @@ -86,8 +115,11 @@ func ParseJWTToken(token, hostname, nonce, subject string) (jwt.MapClaims, error intIat := int64(claims["iat"].(float64)) claims["exp"] = intExp claims["iat"] = intIat - - if claims["aud"] != envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyClientID) { + clientID, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyClientID) + if err != nil { + return claims, err + } + if claims["aud"] != clientID { return claims, errors.New("invalid audience") } @@ -109,20 +141,30 @@ func ParseJWTToken(token, hostname, nonce, subject string) (jwt.MapClaims, error // ParseJWTTokenWithoutNonce common util to parse jwt token without nonce // used to validate ID token as it is not persisted in store func ParseJWTTokenWithoutNonce(token, hostname string) (jwt.MapClaims, error) { - jwtType := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtType) + jwtType, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyJwtType) + if err != nil { + return nil, err + } signingMethod := jwt.GetSigningMethod(jwtType) - var err error var claims jwt.MapClaims switch signingMethod { case jwt.SigningMethodHS256, jwt.SigningMethodHS384, jwt.SigningMethodHS512: _, err = jwt.ParseWithClaims(token, &claims, func(token *jwt.Token) (interface{}, error) { - return []byte(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtSecret)), nil + jwtSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyJwtSecret) + if err != nil { + return nil, err + } + return []byte(jwtSecret), nil }) case jwt.SigningMethodRS256, jwt.SigningMethodRS384, jwt.SigningMethodRS512: _, err = jwt.ParseWithClaims(token, &claims, func(token *jwt.Token) (interface{}, error) { - key, err := crypto.ParseRsaPublicKeyFromPemStr(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtPublicKey)) + jwtPublicKey, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyJwtPublicKey) + if err != nil { + return nil, err + } + key, err := crypto.ParseRsaPublicKeyFromPemStr(jwtPublicKey) if err != nil { return nil, err } @@ -130,7 +172,11 @@ func ParseJWTTokenWithoutNonce(token, hostname string) (jwt.MapClaims, error) { }) case jwt.SigningMethodES256, jwt.SigningMethodES384, jwt.SigningMethodES512: _, err = jwt.ParseWithClaims(token, &claims, func(token *jwt.Token) (interface{}, error) { - key, err := crypto.ParseEcdsaPublicKeyFromPemStr(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtPublicKey)) + jwtPublicKey, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyJwtPublicKey) + if err != nil { + return nil, err + } + key, err := crypto.ParseEcdsaPublicKeyFromPemStr(jwtPublicKey) if err != nil { return nil, err } @@ -150,8 +196,11 @@ func ParseJWTTokenWithoutNonce(token, hostname string) (jwt.MapClaims, error) { intIat := int64(claims["iat"].(float64)) claims["exp"] = intExp claims["iat"] = intIat - - if claims["aud"] != envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyClientID) { + clientID, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyClientID) + if err != nil { + return claims, err + } + if claims["aud"] != clientID { return claims, errors.New("invalid audience") } diff --git a/server/token/verification_token.go b/server/token/verification_token.go index ceaccbc..75aef30 100644 --- a/server/token/verification_token.go +++ b/server/token/verification_token.go @@ -4,15 +4,19 @@ import ( "time" "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/memorystore" "github.com/golang-jwt/jwt" ) // CreateVerificationToken creates a verification JWT token func CreateVerificationToken(email, tokenType, hostname, nonceHash, redirectURL string) (string, error) { + clientID, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyClientID) + if err != nil { + return "", err + } claims := jwt.MapClaims{ "iss": hostname, - "aud": envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyClientID), + "aud": clientID, "sub": email, "exp": time.Now().Add(time.Minute * 30).Unix(), "iat": time.Now().Unix(), diff --git a/server/types/interface_slice.go b/server/types/interface_slice.go new file mode 100644 index 0000000..6b33117 --- /dev/null +++ b/server/types/interface_slice.go @@ -0,0 +1,16 @@ +package types + +import "encoding/json" + +// Type for interface slice. Used for redis store. +type InterfaceSlice []interface{} + +// MarshalBinary for interface slice. +func (s InterfaceSlice) MarshalBinary() ([]byte, error) { + return json.Marshal(s) +} + +// UnmarshalBinary for interface slice. +func (s *InterfaceSlice) UnmarshalBinary(data []byte) error { + return json.Unmarshal(data, s) +} diff --git a/server/utils/common.go b/server/utils/common.go index badd7ea..86156be 100644 --- a/server/utils/common.go +++ b/server/utils/common.go @@ -47,3 +47,14 @@ func ConvertInterfaceToSlice(slice interface{}) []interface{} { return ret } + +// ConvertInterfaceToStringSlice to convert interface to string slice +func ConvertInterfaceToStringSlice(slice interface{}) []string { + data := slice.([]interface{}) + var resSlice []string + + for _, v := range data { + resSlice = append(resSlice, v.(string)) + } + return resSlice +} diff --git a/server/utils/meta.go b/server/utils/meta.go deleted file mode 100644 index 18588c4..0000000 --- a/server/utils/meta.go +++ /dev/null @@ -1,22 +0,0 @@ -package utils - -import ( - "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/envstore" - "github.com/authorizerdev/authorizer/server/graph/model" -) - -// GetMeta helps in getting the meta data about the deployment from EnvData -func GetMetaInfo() model.Meta { - return model.Meta{ - Version: constants.VERSION, - ClientID: envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyClientID), - IsGoogleLoginEnabled: envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyGoogleClientID) != "" && envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyGoogleClientSecret) != "", - IsGithubLoginEnabled: envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyGithubClientID) != "" && envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyGithubClientSecret) != "", - IsFacebookLoginEnabled: envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyFacebookClientID) != "" && envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyFacebookClientSecret) != "", - IsBasicAuthenticationEnabled: !envstore.EnvStoreObj.GetBoolStoreEnvVariable(constants.EnvKeyDisableBasicAuthentication), - IsEmailVerificationEnabled: !envstore.EnvStoreObj.GetBoolStoreEnvVariable(constants.EnvKeyDisableEmailVerification), - IsMagicLinkLoginEnabled: !envstore.EnvStoreObj.GetBoolStoreEnvVariable(constants.EnvKeyDisableMagicLinkLogin), - IsSignUpEnabled: !envstore.EnvStoreObj.GetBoolStoreEnvVariable(constants.EnvKeyDisableSignUp), - } -} diff --git a/server/utils/refs.go b/server/utils/refs.go new file mode 100644 index 0000000..3dd87cd --- /dev/null +++ b/server/utils/refs.go @@ -0,0 +1,6 @@ +package utils + +// NewStringRef returns a reference to a string with given value +func NewStringRef(v string) *string { + return &v +} diff --git a/server/utils/validator.go b/server/utils/validator.go deleted file mode 100644 index 280d611..0000000 --- a/server/utils/validator.go +++ /dev/null @@ -1,120 +0,0 @@ -package utils - -import ( - "net/mail" - "regexp" - "strings" - - "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/envstore" -) - -// IsValidEmail validates email -func IsValidEmail(email string) bool { - _, err := mail.ParseAddress(email) - return err == nil -} - -// IsValidOrigin validates origin based on ALLOWED_ORIGINS -func IsValidOrigin(url string) bool { - allowedOrigins := envstore.EnvStoreObj.GetSliceStoreEnvVariable(constants.EnvKeyAllowedOrigins) - if len(allowedOrigins) == 1 && allowedOrigins[0] == "*" { - return true - } - - hasValidURL := false - hostName, port := GetHostParts(url) - currentOrigin := hostName + ":" + port - - for _, origin := range allowedOrigins { - replacedString := origin - // if has regex whitelisted domains - if strings.Contains(origin, "*") { - replacedString = strings.Replace(origin, ".", "\\.", -1) - replacedString = strings.Replace(replacedString, "*", ".*", -1) - - if strings.HasPrefix(replacedString, ".*") { - replacedString += "\\b" - } - - if strings.HasSuffix(replacedString, ".*") { - replacedString = "\\b" + replacedString - } - } - - if matched, _ := regexp.MatchString(replacedString, currentOrigin); matched { - hasValidURL = true - break - } - } - - return hasValidURL -} - -// IsValidRoles validates roles -func IsValidRoles(userRoles []string, roles []string) bool { - valid := true - for _, userRole := range userRoles { - if !StringSliceContains(roles, userRole) { - valid = false - break - } - } - - return valid -} - -// IsValidVerificationIdentifier validates verification identifier that is used to identify -// the type of verification request -func IsValidVerificationIdentifier(identifier string) bool { - if identifier != constants.VerificationTypeBasicAuthSignup && identifier != constants.VerificationTypeForgotPassword && identifier != constants.VerificationTypeUpdateEmail { - return false - } - return true -} - -// IsStringArrayEqual validates if string array are equal. -// This does check if the order is same -func IsStringArrayEqual(a, b []string) bool { - if len(a) != len(b) { - return false - } - for i, v := range a { - if v != b[i] { - return false - } - } - return true -} - -// ValidatePassword to validate the password against the following policy -// min char length: 6 -// max char length: 36 -// at least one upper case letter -// at least one lower case letter -// at least one digit -// at least one special character -func IsValidPassword(password string) bool { - if len(password) < 6 || len(password) > 36 { - return false - } - - hasUpperCase := false - hasLowerCase := false - hasDigit := false - hasSpecialChar := false - - for _, char := range password { - if char >= 'A' && char <= 'Z' { - hasUpperCase = true - } else if char >= 'a' && char <= 'z' { - hasLowerCase = true - } else if char >= '0' && char <= '9' { - hasDigit = true - } else { - hasSpecialChar = true - } - } - - return hasUpperCase && hasLowerCase && hasDigit && hasSpecialChar -} diff --git a/server/validators/common.go b/server/validators/common.go new file mode 100644 index 0000000..348df00 --- /dev/null +++ b/server/validators/common.go @@ -0,0 +1,15 @@ +package validators + +// IsStringArrayEqual validates if string array are equal. +// This does check if the order is same +func IsStringArrayEqual(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i, v := range a { + if v != b[i] { + return false + } + } + return true +} diff --git a/server/validators/email.go b/server/validators/email.go new file mode 100644 index 0000000..deb98ed --- /dev/null +++ b/server/validators/email.go @@ -0,0 +1,9 @@ +package validators + +import "net/mail" + +// IsValidEmail validates email +func IsValidEmail(email string) bool { + _, err := mail.ParseAddress(email) + return err == nil +} diff --git a/server/validators/password.go b/server/validators/password.go new file mode 100644 index 0000000..7925dab --- /dev/null +++ b/server/validators/password.go @@ -0,0 +1,33 @@ +package validators + +// ValidatePassword to validate the password against the following policy +// min char length: 6 +// max char length: 36 +// at least one upper case letter +// at least one lower case letter +// at least one digit +// at least one special character +func IsValidPassword(password string) bool { + if len(password) < 6 || len(password) > 36 { + return false + } + + hasUpperCase := false + hasLowerCase := false + hasDigit := false + hasSpecialChar := false + + for _, char := range password { + if char >= 'A' && char <= 'Z' { + hasUpperCase = true + } else if char >= 'a' && char <= 'z' { + hasLowerCase = true + } else if char >= '0' && char <= '9' { + hasDigit = true + } else { + hasSpecialChar = true + } + } + + return hasUpperCase && hasLowerCase && hasDigit && hasSpecialChar +} diff --git a/server/validators/roles.go b/server/validators/roles.go new file mode 100644 index 0000000..9a43a51 --- /dev/null +++ b/server/validators/roles.go @@ -0,0 +1,16 @@ +package validators + +import "github.com/authorizerdev/authorizer/server/utils" + +// IsValidRoles validates roles +func IsValidRoles(userRoles []string, roles []string) bool { + valid := true + for _, userRole := range userRoles { + if !utils.StringSliceContains(roles, userRole) { + valid = false + break + } + } + + return valid +} diff --git a/server/validators/url.go b/server/validators/url.go new file mode 100644 index 0000000..eae9aa3 --- /dev/null +++ b/server/validators/url.go @@ -0,0 +1,52 @@ +package validators + +import ( + "regexp" + "strings" + + "github.com/authorizerdev/authorizer/server/constants" + "github.com/authorizerdev/authorizer/server/memorystore" + "github.com/authorizerdev/authorizer/server/parsers" +) + +// IsValidOrigin validates origin based on ALLOWED_ORIGINS +func IsValidOrigin(url string) bool { + allowedOriginsString, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAllowedOrigins) + allowedOrigins := []string{} + if err != nil { + allowedOrigins = []string{"*"} + } else { + allowedOrigins = strings.Split(allowedOriginsString, ",") + } + if len(allowedOrigins) == 1 && allowedOrigins[0] == "*" { + return true + } + + hasValidURL := false + hostName, port := parsers.GetHostParts(url) + currentOrigin := hostName + ":" + port + + for _, origin := range allowedOrigins { + replacedString := origin + // if has regex whitelisted domains + if strings.Contains(origin, "*") { + replacedString = strings.Replace(origin, ".", "\\.", -1) + replacedString = strings.Replace(replacedString, "*", ".*", -1) + + if strings.HasPrefix(replacedString, ".*") { + replacedString += "\\b" + } + + if strings.HasSuffix(replacedString, ".*") { + replacedString = "\\b" + replacedString + } + } + + if matched, _ := regexp.MatchString(replacedString, currentOrigin); matched { + hasValidURL = true + break + } + } + + return hasValidURL +} diff --git a/server/validators/verification_requests.go b/server/validators/verification_requests.go new file mode 100644 index 0000000..48c9183 --- /dev/null +++ b/server/validators/verification_requests.go @@ -0,0 +1,12 @@ +package validators + +import "github.com/authorizerdev/authorizer/server/constants" + +// IsValidVerificationIdentifier validates verification identifier that is used to identify +// the type of verification request +func IsValidVerificationIdentifier(identifier string) bool { + if identifier != constants.VerificationTypeBasicAuthSignup && identifier != constants.VerificationTypeForgotPassword && identifier != constants.VerificationTypeUpdateEmail { + return false + } + return true +}