diff --git a/server/constants/token_types.go b/server/constants/token_types.go index 993d4ea..df671bf 100644 --- a/server/constants/token_types.go +++ b/server/constants/token_types.go @@ -5,4 +5,6 @@ const ( TokenTypeRefreshToken = "refresh_token" // TokenTypeAccessToken is the access_token token type TokenTypeAccessToken = "access_token" + // TokenTypeIdentityToken is the identity_token token type + TokenTypeIdentityToken = "id_token" ) diff --git a/server/cookie/cookie.go b/server/cookie/cookie.go index f90dd21..42445fb 100644 --- a/server/cookie/cookie.go +++ b/server/cookie/cookie.go @@ -2,7 +2,6 @@ package cookie import ( "net/http" - "net/url" "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/envstore" @@ -10,8 +9,8 @@ import ( "github.com/gin-gonic/gin" ) -// SetSessionCookie sets the session cookie in the response -func SetSessionCookie(gc *gin.Context, sessionID string) { +// SetSession sets the session cookie in the response +func SetSession(gc *gin.Context, sessionID string) { secure := true httpOnly := true hostname := utils.GetHost(gc) @@ -26,21 +25,16 @@ func SetSessionCookie(gc *gin.Context, sessionID string) { gc.SetSameSite(http.SameSiteNoneMode) gc.SetCookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName)+"_session", sessionID, year, "/", host, secure, httpOnly) - gc.SetCookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName)+"_session.domain", sessionID, year, "/", domain, secure, httpOnly) + gc.SetCookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName)+"_session_domain", sessionID, year, "/", domain, secure, httpOnly) // Fallback cookie for anomaly getection on browsers that don’t support the sameSite=None attribute. gc.SetSameSite(http.SameSiteDefaultMode) gc.SetCookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName)+"_session_compat", sessionID, year, "/", host, secure, httpOnly) - gc.SetCookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName)+"_session.domain_compat", sessionID, year, "/", domain, secure, httpOnly) + gc.SetCookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName)+"_session_domain_compat", sessionID, year, "/", domain, secure, httpOnly) } -// SetCookie sets the cookie in the response. It sets 4 cookies -// 1 COOKIE_NAME.access_token jwt token for the host (temp.abc.com) -// 2 COOKIE_NAME.access_token.domain jwt token for the domain (abc.com). -// 3 COOKIE_NAME.fingerprint fingerprint hash for the refresh token verification. -// 4 COOKIE_NAME.refresh_token refresh token -// Note all sites don't allow 2nd type of cookie -func SetCookie(gc *gin.Context, accessToken, refreshToken, fingerprintHash string) { +// DeleteSession sets session cookies to expire +func DeleteSession(gc *gin.Context) { secure := true httpOnly := true hostname := utils.GetHost(gc) @@ -50,77 +44,33 @@ func SetCookie(gc *gin.Context, accessToken, refreshToken, fingerprintHash strin domain = "." + domain } - year := 60 * 60 * 24 * 365 - thirtyMin := 60 * 30 - gc.SetSameSite(http.SameSiteNoneMode) - // set cookie for host - gc.SetCookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName)+".access_token", accessToken, thirtyMin, "/", host, secure, httpOnly) + gc.SetCookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName)+"_session", "", -1, "/", host, secure, httpOnly) + gc.SetCookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName)+"_session_domain", "", -1, "/", domain, secure, httpOnly) - // in case of subdomain, set cookie for domain - gc.SetCookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName)+".access_token.domain", accessToken, thirtyMin, "/", domain, secure, httpOnly) - - // set finger print cookie (this should be accessed via cookie only) - gc.SetCookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName)+".fingerprint", fingerprintHash, year, "/", host, secure, httpOnly) - - // set refresh token cookie (this should be accessed via cookie only) - gc.SetCookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName)+".refresh_token", refreshToken, year, "/", host, secure, httpOnly) + gc.SetSameSite(http.SameSiteDefaultMode) + gc.SetCookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName)+"_session_compat", "", -1, "/", host, secure, httpOnly) + gc.SetCookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName)+"_session.domain_compat", "", -1, "/", domain, secure, httpOnly) } -// GetAccessTokenCookie to get access token cookie from the request -func GetAccessTokenCookie(gc *gin.Context) (string, error) { - cookie, err := gc.Request.Cookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName) + ".access_token") +// GetSession gets the session cookie from context +func GetSession(gc *gin.Context) (string, error) { + var cookie *http.Cookie + var err error + cookie, err = gc.Request.Cookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName) + "_session") if err != nil { - cookie, err = gc.Request.Cookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName) + ".access_token.domain") + cookie, err = gc.Request.Cookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName) + "_session_domain") if err != nil { - return "", err + cookie, err = gc.Request.Cookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName) + "_session_compat") + if err != nil { + cookie, err = gc.Request.Cookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName) + "_session_domain_compat") + } + + if err != nil { + return "", err + } } } return cookie.Value, nil } - -// GetRefreshTokenCookie to get refresh token cookie -func GetRefreshTokenCookie(gc *gin.Context) (string, error) { - cookie, err := gc.Request.Cookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName) + ".refresh_token") - if err != nil { - return "", err - } - - return cookie.Value, nil -} - -// GetFingerPrintCookie to get fingerprint cookie -func GetFingerPrintCookie(gc *gin.Context) (string, error) { - cookie, err := gc.Request.Cookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName) + ".fingerprint") - if err != nil { - return "", err - } - - // cookie escapes special characters like $ - // hence we need to unescape before comparing - decodedValue, err := url.QueryUnescape(cookie.Value) - if err != nil { - return "", err - } - - return decodedValue, nil -} - -// DeleteCookie sets response cookies to expire -func DeleteCookie(gc *gin.Context) { - secure := true - httpOnly := true - hostname := utils.GetHost(gc) - host, _ := utils.GetHostParts(hostname) - domain := utils.GetDomainName(hostname) - if domain != "localhost" { - domain = "." + domain - } - - gc.SetSameSite(http.SameSiteNoneMode) - gc.SetCookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName)+".access_token", "", -1, "/", host, secure, httpOnly) - gc.SetCookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName)+".access_token.domain", "", -1, "/", domain, secure, httpOnly) - gc.SetCookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName)+".fingerprint", "", -1, "/", host, secure, httpOnly) - gc.SetCookie(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName)+".refresh_token", "", -1, "/", host, secure, httpOnly) -} diff --git a/server/crypto/aes.go b/server/crypto/aes.go index 5d0bed4..2750486 100644 --- a/server/crypto/aes.go +++ b/server/crypto/aes.go @@ -3,71 +3,40 @@ package crypto import ( "crypto/aes" "crypto/cipher" - "crypto/rand" - "io" "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/envstore" ) -// EncryptAES encrypts data using AES algorithm -func EncryptAES(text []byte) ([]byte, error) { +var bytes = []byte{35, 46, 57, 24, 85, 35, 24, 74, 87, 35, 88, 98, 66, 32, 14, 05} + +// EncryptAES method is to encrypt or hide any classified text +func EncryptAES(text string) (string, error) { key := []byte(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyEncryptionKey)) - c, err := aes.NewCipher(key) - var res []byte + block, err := aes.NewCipher(key) if err != nil { - return res, err + return "", err } - - // gcm or Galois/Counter Mode, is a mode of operation - // for symmetric key cryptographic block ciphers - // - https://en.wikipedia.org/wiki/Galois/Counter_Mode - gcm, err := cipher.NewGCM(c) - if err != nil { - return res, err - } - - // creates a new byte array the size of the nonce - // which must be passed to Seal - nonce := make([]byte, gcm.NonceSize()) - // populates our nonce with a cryptographically secure - // random sequence - if _, err = io.ReadFull(rand.Reader, nonce); err != nil { - return res, err - } - - // here we encrypt our text using the Seal function - // Seal encrypts and authenticates plaintext, authenticates the - // additional data and appends the result to dst, returning the updated - // slice. The nonce must be NonceSize() bytes long and unique for all - // time, for a given key. - return gcm.Seal(nonce, nonce, text, nil), nil + plainText := []byte(text) + cfb := cipher.NewCFBEncrypter(block, bytes) + cipherText := make([]byte, len(plainText)) + cfb.XORKeyStream(cipherText, plainText) + return EncryptB64(string(cipherText)), nil } -// DecryptAES decrypts data using AES algorithm -func DecryptAES(ciphertext []byte) ([]byte, error) { +// DecryptAES method is to extract back the encrypted text +func DecryptAES(text string) (string, error) { key := []byte(envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyEncryptionKey)) - c, err := aes.NewCipher(key) - var res []byte + block, err := aes.NewCipher(key) if err != nil { - return res, err + return "", err } - - gcm, err := cipher.NewGCM(c) + cipherText, err := DecryptB64(text) if err != nil { - return res, err + return "", err } - - nonceSize := gcm.NonceSize() - if len(ciphertext) < nonceSize { - return res, err - } - - nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:] - plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) - if err != nil { - return res, err - } - - return plaintext, nil + cfb := cipher.NewCFBDecrypter(block, bytes) + plainText := make([]byte, len(cipherText)) + cfb.XORKeyStream(plainText, []byte(cipherText)) + return string(plainText), nil } diff --git a/server/crypto/common.go b/server/crypto/common.go index 45f8917..df3754c 100644 --- a/server/crypto/common.go +++ b/server/crypto/common.go @@ -94,7 +94,7 @@ func EncryptEnvData(data envstore.Store) (string, error) { if err != nil { return "", err } - encryptedConfig, err := EncryptAES(configData) + encryptedConfig, err := EncryptAES(string(configData)) if err != nil { return "", err } diff --git a/server/db/models/verification_requests.go b/server/db/models/verification_requests.go index 5c23301..eec5427 100644 --- a/server/db/models/verification_requests.go +++ b/server/db/models/verification_requests.go @@ -12,6 +12,7 @@ type VerificationRequest struct { CreatedAt int64 `json:"created_at" bson:"created_at"` UpdatedAt int64 `json:"updated_at" bson:"updated_at"` Email string `gorm:"uniqueIndex:idx_email_identifier" json:"email" bson:"email"` + Nonce string `gorm:"type:char(36)" json:"nonce" bson:"nonce"` } func (v *VerificationRequest) AsAPIVerificationRequest() *model.VerificationRequest { diff --git a/server/db/providers/sql/sql.go b/server/db/providers/sql/sql.go index 28af527..279b707 100644 --- a/server/db/providers/sql/sql.go +++ b/server/db/providers/sql/sql.go @@ -56,7 +56,10 @@ func NewProvider() (*provider, error) { return nil, err } - sqlDB.AutoMigrate(&models.User{}, &models.VerificationRequest{}, &models.Session{}, &models.Env{}) + err = sqlDB.AutoMigrate(&models.User{}, &models.VerificationRequest{}, &models.Session{}, &models.Env{}) + if err != nil { + return nil, err + } return &provider{ db: sqlDB, }, nil diff --git a/server/email/email.go b/server/email/email.go index 2c57c3c..7b423d2 100644 --- a/server/email/email.go +++ b/server/email/email.go @@ -31,6 +31,10 @@ func addEmailTemplate(a string, b map[string]interface{}, templateName string) s // SendMail function to send mail func SendMail(to []string, Subject, bodyMessage string) error { + // dont trigger email sending in case of test + if envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyEnv) == "test" { + return nil + } m := gomail.NewMessage() m.SetHeader("From", envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeySenderEmail)) m.SetHeader("To", to...) diff --git a/server/env/persist_env.go b/server/env/persist_env.go index a7fba71..24df138 100644 --- a/server/env/persist_env.go +++ b/server/env/persist_env.go @@ -33,17 +33,13 @@ func GetEnvData() (envstore.Store, error) { } envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyEncryptionKey, decryptedEncryptionKey) - b64DecryptedConfig, err := crypto.DecryptB64(env.EnvData) + + decryptedConfigs, err := crypto.DecryptAES(env.EnvData) if err != nil { return result, err } - decryptedConfigs, err := crypto.DecryptAES([]byte(b64DecryptedConfig)) - if err != nil { - return result, err - } - - err = json.Unmarshal(decryptedConfigs, &result) + err = json.Unmarshal([]byte(decryptedConfigs), &result) if err != nil { return result, err } @@ -85,12 +81,8 @@ func PersistEnv() error { } envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyEncryptionKey, decryptedEncryptionKey) - b64DecryptedConfig, err := crypto.DecryptB64(env.EnvData) - if err != nil { - return err - } - decryptedConfigs, err := crypto.DecryptAES([]byte(b64DecryptedConfig)) + decryptedConfigs, err := crypto.DecryptAES(env.EnvData) if err != nil { return err } @@ -98,7 +90,7 @@ func PersistEnv() error { // temp store variable var storeData envstore.Store - err = json.Unmarshal(decryptedConfigs, &storeData) + err = json.Unmarshal([]byte(decryptedConfigs), &storeData) if err != nil { return err } diff --git a/server/graph/generated/generated.go b/server/graph/generated/generated.go index 39e22a8..69b021f 100644 --- a/server/graph/generated/generated.go +++ b/server/graph/generated/generated.go @@ -44,10 +44,12 @@ type DirectiveRoot struct { type ComplexityRoot struct { AuthResponse struct { - AccessToken func(childComplexity int) int - ExpiresAt func(childComplexity int) int - Message func(childComplexity int) int - User func(childComplexity int) int + AccessToken func(childComplexity int) int + ExpiresIn func(childComplexity int) int + IDToken func(childComplexity int) int + Message func(childComplexity int) int + RefreshToken func(childComplexity int) int + User func(childComplexity int) int } Env struct { @@ -134,7 +136,6 @@ type ComplexityRoot struct { Query struct { AdminSession func(childComplexity int) int Env func(childComplexity int) int - IsValidJwt func(childComplexity int, params *model.IsValidJWTQueryInput) int Meta func(childComplexity int) int Profile func(childComplexity int) int Session func(childComplexity int, params *model.SessionQueryInput) int @@ -171,11 +172,6 @@ type ComplexityRoot struct { Users func(childComplexity int) int } - ValidJWTResponse struct { - Message func(childComplexity int) int - Valid func(childComplexity int) int - } - VerificationRequest struct { CreatedAt func(childComplexity int) int Email func(childComplexity int) int @@ -212,7 +208,6 @@ type MutationResolver interface { type QueryResolver interface { Meta(ctx context.Context) (*model.Meta, error) Session(ctx context.Context, params *model.SessionQueryInput) (*model.AuthResponse, error) - IsValidJwt(ctx context.Context, params *model.IsValidJWTQueryInput) (*model.ValidJWTResponse, error) Profile(ctx context.Context) (*model.User, error) Users(ctx context.Context, params *model.PaginatedInput) (*model.Users, error) VerificationRequests(ctx context.Context, params *model.PaginatedInput) (*model.VerificationRequests, error) @@ -242,12 +237,19 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.AuthResponse.AccessToken(childComplexity), true - case "AuthResponse.expires_at": - if e.complexity.AuthResponse.ExpiresAt == nil { + case "AuthResponse.expires_in": + if e.complexity.AuthResponse.ExpiresIn == nil { break } - return e.complexity.AuthResponse.ExpiresAt(childComplexity), true + return e.complexity.AuthResponse.ExpiresIn(childComplexity), true + + case "AuthResponse.id_token": + if e.complexity.AuthResponse.IDToken == nil { + break + } + + return e.complexity.AuthResponse.IDToken(childComplexity), true case "AuthResponse.message": if e.complexity.AuthResponse.Message == nil { @@ -256,6 +258,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.AuthResponse.Message(childComplexity), true + case "AuthResponse.refresh_token": + if e.complexity.AuthResponse.RefreshToken == nil { + break + } + + return e.complexity.AuthResponse.RefreshToken(childComplexity), true + case "AuthResponse.user": if e.complexity.AuthResponse.User == nil { break @@ -804,18 +813,6 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Query.Env(childComplexity), true - case "Query.is_valid_jwt": - if e.complexity.Query.IsValidJwt == nil { - break - } - - args, err := ec.field_Query_is_valid_jwt_args(context.TODO(), rawArgs) - if err != nil { - return 0, false - } - - return e.complexity.Query.IsValidJwt(childComplexity, args["params"].(*model.IsValidJWTQueryInput)), true - case "Query.meta": if e.complexity.Query.Meta == nil { break @@ -1006,20 +1003,6 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Users.Users(childComplexity), true - case "ValidJWTResponse.message": - if e.complexity.ValidJWTResponse.Message == nil { - break - } - - return e.complexity.ValidJWTResponse.Message(childComplexity), true - - case "ValidJWTResponse.valid": - if e.complexity.ValidJWTResponse.Valid == nil { - break - } - - return e.complexity.ValidJWTResponse.Valid(childComplexity), true - case "VerificationRequest.created_at": if e.complexity.VerificationRequest.CreatedAt == nil { break @@ -1221,7 +1204,9 @@ type Error { type AuthResponse { message: String! access_token: String - expires_at: Int64 + id_token: String + refresh_token: String + expires_in: Int64 user: User } @@ -1229,11 +1214,6 @@ type Response { message: String! } -type ValidJWTResponse { - valid: Boolean! - message: String! -} - type Env { ADMIN_SECRET: String DATABASE_NAME: String! @@ -1337,6 +1317,7 @@ input LoginInput { email: String! password: String! roles: [String!] + scope: [String!] } input VerifyEmailInput { @@ -1395,15 +1376,12 @@ input DeleteUserInput { input MagicLinkLoginInput { email: String! roles: [String!] + scope: [String!] } input SessionQueryInput { roles: [String!] -} - -input IsValidJWTQueryInput { - jwt: String - roles: [String!] + scope: [String!] } input PaginationInput { @@ -1437,7 +1415,6 @@ type Mutation { type Query { meta: Meta! session(params: SessionQueryInput): AuthResponse! - is_valid_jwt(params: IsValidJWTQueryInput): ValidJWTResponse! profile: User! # admin only apis _users(params: PaginatedInput): Users! @@ -1693,21 +1670,6 @@ func (ec *executionContext) field_Query__verification_requests_args(ctx context. return args, nil } -func (ec *executionContext) field_Query_is_valid_jwt_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { - var err error - args := map[string]interface{}{} - var arg0 *model.IsValidJWTQueryInput - if tmp, ok := rawArgs["params"]; ok { - ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("params")) - arg0, err = ec.unmarshalOIsValidJWTQueryInput2ᚖgithubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐIsValidJWTQueryInput(ctx, tmp) - if err != nil { - return nil, err - } - } - args["params"] = arg0 - return args, nil -} - func (ec *executionContext) field_Query_session_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { var err error args := map[string]interface{}{} @@ -1828,7 +1790,7 @@ func (ec *executionContext) _AuthResponse_access_token(ctx context.Context, fiel return ec.marshalOString2ᚖstring(ctx, field.Selections, res) } -func (ec *executionContext) _AuthResponse_expires_at(ctx context.Context, field graphql.CollectedField, obj *model.AuthResponse) (ret graphql.Marshaler) { +func (ec *executionContext) _AuthResponse_id_token(ctx context.Context, field graphql.CollectedField, obj *model.AuthResponse) (ret graphql.Marshaler) { defer func() { if r := recover(); r != nil { ec.Error(ctx, ec.Recover(ctx, r)) @@ -1846,7 +1808,71 @@ func (ec *executionContext) _AuthResponse_expires_at(ctx context.Context, field ctx = graphql.WithFieldContext(ctx, fc) resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { ctx = rctx // use context from middleware stack in children - return obj.ExpiresAt, nil + return obj.IDToken, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*string) + fc.Result = res + return ec.marshalOString2ᚖstring(ctx, field.Selections, res) +} + +func (ec *executionContext) _AuthResponse_refresh_token(ctx context.Context, field graphql.CollectedField, obj *model.AuthResponse) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "AuthResponse", + Field: field, + Args: nil, + IsMethod: false, + IsResolver: false, + } + + ctx = graphql.WithFieldContext(ctx, fc) + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.RefreshToken, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*string) + fc.Result = res + return ec.marshalOString2ᚖstring(ctx, field.Selections, res) +} + +func (ec *executionContext) _AuthResponse_expires_in(ctx context.Context, field graphql.CollectedField, obj *model.AuthResponse) (ret graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + fc := &graphql.FieldContext{ + Object: "AuthResponse", + Field: field, + Args: nil, + IsMethod: false, + IsResolver: false, + } + + ctx = graphql.WithFieldContext(ctx, fc) + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.ExpiresIn, nil }) if err != nil { ec.Error(ctx, err) @@ -4274,48 +4300,6 @@ func (ec *executionContext) _Query_session(ctx context.Context, field graphql.Co return ec.marshalNAuthResponse2ᚖgithubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐAuthResponse(ctx, field.Selections, res) } -func (ec *executionContext) _Query_is_valid_jwt(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { - defer func() { - if r := recover(); r != nil { - ec.Error(ctx, ec.Recover(ctx, r)) - ret = graphql.Null - } - }() - fc := &graphql.FieldContext{ - Object: "Query", - Field: field, - Args: nil, - IsMethod: true, - IsResolver: true, - } - - ctx = graphql.WithFieldContext(ctx, fc) - rawArgs := field.ArgumentMap(ec.Variables) - args, err := ec.field_Query_is_valid_jwt_args(ctx, rawArgs) - if err != nil { - ec.Error(ctx, err) - return graphql.Null - } - fc.Args = args - resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { - ctx = rctx // use context from middleware stack in children - return ec.resolvers.Query().IsValidJwt(rctx, args["params"].(*model.IsValidJWTQueryInput)) - }) - if err != nil { - ec.Error(ctx, err) - return graphql.Null - } - if resTmp == nil { - if !graphql.HasFieldError(ctx, fc) { - ec.Errorf(ctx, "must not be null") - } - return graphql.Null - } - res := resTmp.(*model.ValidJWTResponse) - fc.Result = res - return ec.marshalNValidJWTResponse2ᚖgithubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐValidJWTResponse(ctx, field.Selections, res) -} - func (ec *executionContext) _Query_profile(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { defer func() { if r := recover(); r != nil { @@ -5240,76 +5224,6 @@ func (ec *executionContext) _Users_users(ctx context.Context, field graphql.Coll return ec.marshalNUser2ᚕᚖgithubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐUserᚄ(ctx, field.Selections, res) } -func (ec *executionContext) _ValidJWTResponse_valid(ctx context.Context, field graphql.CollectedField, obj *model.ValidJWTResponse) (ret graphql.Marshaler) { - defer func() { - if r := recover(); r != nil { - ec.Error(ctx, ec.Recover(ctx, r)) - ret = graphql.Null - } - }() - fc := &graphql.FieldContext{ - Object: "ValidJWTResponse", - Field: field, - Args: nil, - IsMethod: false, - IsResolver: false, - } - - ctx = graphql.WithFieldContext(ctx, fc) - resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { - ctx = rctx // use context from middleware stack in children - return obj.Valid, nil - }) - if err != nil { - ec.Error(ctx, err) - return graphql.Null - } - if resTmp == nil { - if !graphql.HasFieldError(ctx, fc) { - ec.Errorf(ctx, "must not be null") - } - return graphql.Null - } - res := resTmp.(bool) - fc.Result = res - return ec.marshalNBoolean2bool(ctx, field.Selections, res) -} - -func (ec *executionContext) _ValidJWTResponse_message(ctx context.Context, field graphql.CollectedField, obj *model.ValidJWTResponse) (ret graphql.Marshaler) { - defer func() { - if r := recover(); r != nil { - ec.Error(ctx, ec.Recover(ctx, r)) - ret = graphql.Null - } - }() - fc := &graphql.FieldContext{ - Object: "ValidJWTResponse", - Field: field, - Args: nil, - IsMethod: false, - IsResolver: false, - } - - ctx = graphql.WithFieldContext(ctx, fc) - resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { - ctx = rctx // use context from middleware stack in children - return obj.Message, nil - }) - if err != nil { - ec.Error(ctx, err) - return graphql.Null - } - if resTmp == nil { - if !graphql.HasFieldError(ctx, fc) { - ec.Errorf(ctx, "must not be null") - } - return graphql.Null - } - res := resTmp.(string) - fc.Result = res - return ec.marshalNString2string(ctx, field.Selections, res) -} - func (ec *executionContext) _VerificationRequest_id(ctx context.Context, field graphql.CollectedField, obj *model.VerificationRequest) (ret graphql.Marshaler) { defer func() { if r := recover(); r != nil { @@ -6821,37 +6735,6 @@ func (ec *executionContext) unmarshalInputForgotPasswordInput(ctx context.Contex return it, nil } -func (ec *executionContext) unmarshalInputIsValidJWTQueryInput(ctx context.Context, obj interface{}) (model.IsValidJWTQueryInput, error) { - var it model.IsValidJWTQueryInput - asMap := map[string]interface{}{} - for k, v := range obj.(map[string]interface{}) { - asMap[k] = v - } - - for k, v := range asMap { - switch k { - case "jwt": - var err error - - ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("jwt")) - it.Jwt, err = ec.unmarshalOString2ᚖstring(ctx, v) - if err != nil { - return it, err - } - case "roles": - var err error - - ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("roles")) - it.Roles, err = ec.unmarshalOString2ᚕstringᚄ(ctx, v) - if err != nil { - return it, err - } - } - } - - return it, nil -} - func (ec *executionContext) unmarshalInputLoginInput(ctx context.Context, obj interface{}) (model.LoginInput, error) { var it model.LoginInput asMap := map[string]interface{}{} @@ -6885,6 +6768,14 @@ func (ec *executionContext) unmarshalInputLoginInput(ctx context.Context, obj in if err != nil { return it, err } + case "scope": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("scope")) + it.Scope, err = ec.unmarshalOString2ᚕstringᚄ(ctx, v) + if err != nil { + return it, err + } } } @@ -6916,6 +6807,14 @@ func (ec *executionContext) unmarshalInputMagicLinkLoginInput(ctx context.Contex if err != nil { return it, err } + case "scope": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("scope")) + it.Scope, err = ec.unmarshalOString2ᚕstringᚄ(ctx, v) + if err != nil { + return it, err + } } } @@ -7063,6 +6962,14 @@ func (ec *executionContext) unmarshalInputSessionQueryInput(ctx context.Context, if err != nil { return it, err } + case "scope": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("scope")) + it.Scope, err = ec.unmarshalOString2ᚕstringᚄ(ctx, v) + if err != nil { + return it, err + } } } @@ -7730,8 +7637,12 @@ func (ec *executionContext) _AuthResponse(ctx context.Context, sel ast.Selection } case "access_token": out.Values[i] = ec._AuthResponse_access_token(ctx, field, obj) - case "expires_at": - out.Values[i] = ec._AuthResponse_expires_at(ctx, field, obj) + case "id_token": + out.Values[i] = ec._AuthResponse_id_token(ctx, field, obj) + case "refresh_token": + out.Values[i] = ec._AuthResponse_refresh_token(ctx, field, obj) + case "expires_in": + out.Values[i] = ec._AuthResponse_expires_in(ctx, field, obj) case "user": out.Values[i] = ec._AuthResponse_user(ctx, field, obj) default: @@ -8136,20 +8047,6 @@ func (ec *executionContext) _Query(ctx context.Context, sel ast.SelectionSet) gr } return res }) - case "is_valid_jwt": - field := field - out.Concurrently(i, func() (res graphql.Marshaler) { - defer func() { - if r := recover(); r != nil { - ec.Error(ctx, ec.Recover(ctx, r)) - } - }() - res = ec._Query_is_valid_jwt(ctx, field) - if res == graphql.Null { - atomic.AddUint32(&invalids, 1) - } - return res - }) case "profile": field := field out.Concurrently(i, func() (res graphql.Marshaler) { @@ -8365,38 +8262,6 @@ func (ec *executionContext) _Users(ctx context.Context, sel ast.SelectionSet, ob return out } -var validJWTResponseImplementors = []string{"ValidJWTResponse"} - -func (ec *executionContext) _ValidJWTResponse(ctx context.Context, sel ast.SelectionSet, obj *model.ValidJWTResponse) graphql.Marshaler { - fields := graphql.CollectFields(ec.OperationContext, sel, validJWTResponseImplementors) - - out := graphql.NewFieldSet(fields) - var invalids uint32 - for i, field := range fields { - switch field.Name { - case "__typename": - out.Values[i] = graphql.MarshalString("ValidJWTResponse") - case "valid": - out.Values[i] = ec._ValidJWTResponse_valid(ctx, field, obj) - if out.Values[i] == graphql.Null { - invalids++ - } - case "message": - out.Values[i] = ec._ValidJWTResponse_message(ctx, field, obj) - if out.Values[i] == graphql.Null { - invalids++ - } - default: - panic("unknown field " + strconv.Quote(field.Name)) - } - } - out.Dispatch() - if invalids > 0 { - return graphql.Null - } - return out -} - var verificationRequestImplementors = []string{"VerificationRequest"} func (ec *executionContext) _VerificationRequest(ctx context.Context, sel ast.SelectionSet, obj *model.VerificationRequest) graphql.Marshaler { @@ -9012,20 +8877,6 @@ func (ec *executionContext) marshalNUsers2ᚖgithubᚗcomᚋauthorizerdevᚋauth return ec._Users(ctx, sel, v) } -func (ec *executionContext) marshalNValidJWTResponse2githubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐValidJWTResponse(ctx context.Context, sel ast.SelectionSet, v model.ValidJWTResponse) graphql.Marshaler { - return ec._ValidJWTResponse(ctx, sel, &v) -} - -func (ec *executionContext) marshalNValidJWTResponse2ᚖgithubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐValidJWTResponse(ctx context.Context, sel ast.SelectionSet, v *model.ValidJWTResponse) graphql.Marshaler { - if v == nil { - if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) { - ec.Errorf(ctx, "must not be null") - } - return graphql.Null - } - return ec._ValidJWTResponse(ctx, sel, v) -} - func (ec *executionContext) marshalNVerificationRequest2ᚕᚖgithubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐVerificationRequestᚄ(ctx context.Context, sel ast.SelectionSet, v []*model.VerificationRequest) graphql.Marshaler { ret := make(graphql.Array, len(v)) var wg sync.WaitGroup @@ -9395,14 +9246,6 @@ func (ec *executionContext) marshalOInt642ᚖint64(ctx context.Context, sel ast. return graphql.MarshalInt64(*v) } -func (ec *executionContext) unmarshalOIsValidJWTQueryInput2ᚖgithubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐIsValidJWTQueryInput(ctx context.Context, v interface{}) (*model.IsValidJWTQueryInput, error) { - if v == nil { - return nil, nil - } - res, err := ec.unmarshalInputIsValidJWTQueryInput(ctx, v) - return &res, graphql.ErrorOnPath(ctx, err) -} - func (ec *executionContext) unmarshalOPaginatedInput2ᚖgithubᚗcomᚋauthorizerdevᚋauthorizerᚋserverᚋgraphᚋmodelᚐPaginatedInput(ctx context.Context, v interface{}) (*model.PaginatedInput, error) { if v == nil { return nil, nil diff --git a/server/graph/model/models_gen.go b/server/graph/model/models_gen.go index b5213b2..5134dda 100644 --- a/server/graph/model/models_gen.go +++ b/server/graph/model/models_gen.go @@ -11,10 +11,12 @@ type AdminSignupInput struct { } type AuthResponse struct { - Message string `json:"message"` - AccessToken *string `json:"access_token"` - ExpiresAt *int64 `json:"expires_at"` - User *User `json:"user"` + Message string `json:"message"` + AccessToken *string `json:"access_token"` + IDToken *string `json:"id_token"` + RefreshToken *string `json:"refresh_token"` + ExpiresIn *int64 `json:"expires_in"` + User *User `json:"user"` } type DeleteUserInput struct { @@ -70,20 +72,17 @@ type ForgotPasswordInput struct { Email string `json:"email"` } -type IsValidJWTQueryInput struct { - Jwt *string `json:"jwt"` - Roles []string `json:"roles"` -} - type LoginInput struct { Email string `json:"email"` Password string `json:"password"` Roles []string `json:"roles"` + Scope []string `json:"scope"` } type MagicLinkLoginInput struct { Email string `json:"email"` Roles []string `json:"roles"` + Scope []string `json:"scope"` } type Meta struct { @@ -130,6 +129,7 @@ type Response struct { type SessionQueryInput struct { Roles []string `json:"roles"` + Scope []string `json:"scope"` } type SignUpInput struct { @@ -238,11 +238,6 @@ type Users struct { Users []*User `json:"users"` } -type ValidJWTResponse struct { - Valid bool `json:"valid"` - Message string `json:"message"` -} - type VerificationRequest struct { ID string `json:"id"` Identifier *string `json:"identifier"` diff --git a/server/graph/schema.graphqls b/server/graph/schema.graphqls index cca567e..cf01374 100644 --- a/server/graph/schema.graphqls +++ b/server/graph/schema.graphqls @@ -72,7 +72,9 @@ type Error { type AuthResponse { message: String! access_token: String - expires_at: Int64 + id_token: String + refresh_token: String + expires_in: Int64 user: User } @@ -80,11 +82,6 @@ type Response { message: String! } -type ValidJWTResponse { - valid: Boolean! - message: String! -} - type Env { ADMIN_SECRET: String DATABASE_NAME: String! @@ -188,6 +185,7 @@ input LoginInput { email: String! password: String! roles: [String!] + scope: [String!] } input VerifyEmailInput { @@ -246,15 +244,12 @@ input DeleteUserInput { input MagicLinkLoginInput { email: String! roles: [String!] + scope: [String!] } input SessionQueryInput { roles: [String!] -} - -input IsValidJWTQueryInput { - jwt: String - roles: [String!] + scope: [String!] } input PaginationInput { @@ -288,7 +283,6 @@ type Mutation { type Query { meta: Meta! session(params: SessionQueryInput): AuthResponse! - is_valid_jwt(params: IsValidJWTQueryInput): ValidJWTResponse! profile: User! # admin only apis _users(params: PaginatedInput): Users! diff --git a/server/graph/schema.resolvers.go b/server/graph/schema.resolvers.go index e3cd241..79f718a 100644 --- a/server/graph/schema.resolvers.go +++ b/server/graph/schema.resolvers.go @@ -79,10 +79,6 @@ func (r *queryResolver) Session(ctx context.Context, params *model.SessionQueryI return resolvers.SessionResolver(ctx, params) } -func (r *queryResolver) IsValidJwt(ctx context.Context, params *model.IsValidJWTQueryInput) (*model.ValidJWTResponse, error) { - return resolvers.IsValidJwtResolver(ctx, params) -} - func (r *queryResolver) Profile(ctx context.Context) (*model.User, error) { return resolvers.ProfileResolver(ctx) } diff --git a/server/handlers/oauth_callback.go b/server/handlers/oauth_callback.go index 3232c47..b16c789 100644 --- a/server/handlers/oauth_callback.go +++ b/server/handlers/oauth_callback.go @@ -144,11 +144,13 @@ func OAuthCallbackHandler() gin.HandlerFunc { } } - authToken, _ := token.CreateAuthToken(user, inputRoles) - sessionstore.SetUserSession(user.ID, authToken.FingerPrint, authToken.RefreshToken.Token) - cookie.SetCookie(c, authToken.AccessToken.Token, authToken.RefreshToken.Token, authToken.FingerPrintHash) - utils.SaveSessionInDB(user.ID, c) + // TODO use query param + scope := []string{"openid", "email", "profile"} + authToken, _ := token.CreateAuthToken(c, user, inputRoles, scope) + sessionstore.SetState(authToken.FingerPrint, user.ID) + cookie.SetSession(c, authToken.FingerPrintHash) + go utils.SaveSessionInDB(c, user.ID) c.Redirect(http.StatusTemporaryRedirect, redirectURL) } } @@ -227,7 +229,7 @@ func processGithubUserInfo(code string) (models.User, error) { GivenName: &firstName, FamilyName: &lastName, Picture: &picture, - Email: userRawData["email"], + Email: userRawData["sub"], } return user, nil @@ -260,7 +262,7 @@ func processFacebookUserInfo(code string) (models.User, error) { userRawData := make(map[string]interface{}) json.Unmarshal(body, &userRawData) - email := fmt.Sprintf("%v", userRawData["email"]) + email := fmt.Sprintf("%v", userRawData["sub"]) picObject := userRawData["picture"].(map[string]interface{})["data"] picDataObject := picObject.(map[string]interface{}) diff --git a/server/handlers/verify_email.go b/server/handlers/verify_email.go index 9aecff4..acc5dea 100644 --- a/server/handlers/verify_email.go +++ b/server/handlers/verify_email.go @@ -11,6 +11,7 @@ import ( "github.com/authorizerdev/authorizer/server/token" "github.com/authorizerdev/authorizer/server/utils" "github.com/gin-gonic/gin" + "github.com/google/uuid" ) // VerifyEmailHandler handles the verify email route. @@ -18,7 +19,7 @@ import ( func VerifyEmailHandler() gin.HandlerFunc { return func(c *gin.Context) { errorRes := gin.H{ - "message": "invalid token", + "error": "invalid token", } tokenInQuery := c.Query("token") if tokenInQuery == "" { @@ -33,13 +34,21 @@ func VerifyEmailHandler() gin.HandlerFunc { } // verify if token exists in db - claim, err := token.ParseJWTToken(tokenInQuery) + hostname := utils.GetHost(c) + encryptedNonce, err := utils.EncryptNonce(verificationRequest.Nonce) + if err != nil { + c.JSON(400, gin.H{ + "error": err.Error(), + }) + return + } + claim, err := token.ParseJWTToken(tokenInQuery, hostname, encryptedNonce, verificationRequest.Email) if err != nil { c.JSON(400, errorRes) return } - user, err := db.Provider.GetUserByEmail(claim["email"].(string)) + user, err := db.Provider.GetUserByEmail(claim["sub"].(string)) if err != nil { c.JSON(400, gin.H{ "message": err.Error(), @@ -57,16 +66,19 @@ func VerifyEmailHandler() gin.HandlerFunc { db.Provider.DeleteVerificationRequest(verificationRequest) roles := strings.Split(user.Roles, ",") - authToken, err := token.CreateAuthToken(user, roles) + scope := []string{"openid", "email", "profile"} + nonce := uuid.New().String() + _, authToken, err := token.CreateSessionToken(user, nonce, roles, scope) if err != nil { c.JSON(400, gin.H{ "message": err.Error(), }) return } - sessionstore.SetUserSession(user.ID, authToken.FingerPrint, authToken.RefreshToken.Token) - cookie.SetCookie(c, authToken.AccessToken.Token, authToken.RefreshToken.Token, authToken.FingerPrintHash) - utils.SaveSessionInDB(user.ID, c) + sessionstore.SetState(authToken, nonce+"@"+user.ID) + cookie.SetSession(c, authToken) + + go utils.SaveSessionInDB(c, user.ID) c.Redirect(http.StatusTemporaryRedirect, claim["redirect_url"].(string)) } diff --git a/server/resolvers/delete_user.go b/server/resolvers/delete_user.go index 56f4c3a..164c413 100644 --- a/server/resolvers/delete_user.go +++ b/server/resolvers/delete_user.go @@ -29,7 +29,7 @@ func DeleteUserResolver(ctx context.Context, params model.DeleteUserInput) (*mod return res, err } - sessionstore.DeleteAllUserSession(fmt.Sprintf("%x", user.ID)) + go sessionstore.DeleteAllUserSession(fmt.Sprintf("%x", user.ID)) err = db.Provider.DeleteUser(user) if err != nil { diff --git a/server/resolvers/forgot_password.go b/server/resolvers/forgot_password.go index 2ca5b52..49eaa75 100644 --- a/server/resolvers/forgot_password.go +++ b/server/resolvers/forgot_password.go @@ -39,7 +39,11 @@ func ForgotPasswordResolver(ctx context.Context, params model.ForgotPasswordInpu } hostname := utils.GetHost(gc) - verificationToken, err := token.CreateVerificationToken(params.Email, constants.VerificationTypeForgotPassword, hostname) + nonce, nonceHash, err := utils.GenerateNonce() + if err != nil { + return res, err + } + verificationToken, err := token.CreateVerificationToken(params.Email, constants.VerificationTypeForgotPassword, hostname, nonceHash) if err != nil { log.Println(`error generating token`, err) } @@ -48,12 +52,11 @@ func ForgotPasswordResolver(ctx context.Context, params model.ForgotPasswordInpu Identifier: constants.VerificationTypeForgotPassword, ExpiresAt: time.Now().Add(time.Minute * 30).Unix(), Email: params.Email, + Nonce: nonce, }) // exec it as go routin so that we can reduce the api latency - go func() { - email.SendForgotPasswordMail(params.Email, verificationToken, hostname) - }() + go email.SendForgotPasswordMail(params.Email, verificationToken, hostname) res = &model.Response{ Message: `Please check your inbox! We have sent a password reset link.`, diff --git a/server/resolvers/is_valid_jwt.go b/server/resolvers/is_valid_jwt.go deleted file mode 100644 index d2bd6bf..0000000 --- a/server/resolvers/is_valid_jwt.go +++ /dev/null @@ -1,52 +0,0 @@ -package resolvers - -import ( - "context" - "errors" - "fmt" - - "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/envstore" - "github.com/authorizerdev/authorizer/server/graph/model" - "github.com/authorizerdev/authorizer/server/token" - tokenHelper "github.com/authorizerdev/authorizer/server/token" - "github.com/authorizerdev/authorizer/server/utils" -) - -// IsValidJwtResolver resolver to return if given jwt is valid -func IsValidJwtResolver(ctx context.Context, params *model.IsValidJWTQueryInput) (*model.ValidJWTResponse, error) { - gc, err := utils.GinContextFromContext(ctx) - token, err := token.GetAccessToken(gc) - - if token == "" || err != nil { - if params != nil && *params.Jwt != "" { - token = *params.Jwt - } else { - return nil, errors.New("no jwt provided via cookie / header / params") - } - } - - claims, err := tokenHelper.ParseJWTToken(token) - if err != nil { - return nil, err - } - - claimRoleInterface := claims[envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtRoleClaim)].([]interface{}) - claimRoles := []string{} - for _, v := range claimRoleInterface { - claimRoles = append(claimRoles, v.(string)) - } - - if params != nil && params.Roles != nil && len(params.Roles) > 0 { - for _, v := range params.Roles { - if !utils.StringSliceContains(claimRoles, v) { - return nil, fmt.Errorf(`unauthorized`) - } - } - } - - return &model.ValidJWTResponse{ - Valid: true, - Message: "Valid JWT", - }, nil -} diff --git a/server/resolvers/login.go b/server/resolvers/login.go index 76f31d7..03ffbb5 100644 --- a/server/resolvers/login.go +++ b/server/resolvers/login.go @@ -59,20 +59,36 @@ func LoginResolver(ctx context.Context, params model.LoginInput) (*model.AuthRes roles = params.Roles } - authToken, err := token.CreateAuthToken(user, roles) + scope := []string{"openid", "email", "profile"} + if params.Scope != nil && len(scope) > 0 { + scope = params.Scope + } + + authToken, err := token.CreateAuthToken(gc, user, roles, scope) if err != nil { return res, err } - sessionstore.SetUserSession(user.ID, authToken.FingerPrint, authToken.RefreshToken.Token) - cookie.SetCookie(gc, authToken.AccessToken.Token, authToken.RefreshToken.Token, authToken.FingerPrintHash) - utils.SaveSessionInDB(user.ID, gc) + cookie.SetSession(gc, authToken.FingerPrintHash) + + expiresIn := int64(1800) res = &model.AuthResponse{ Message: `Logged in successfully`, AccessToken: &authToken.AccessToken.Token, - ExpiresAt: &authToken.AccessToken.ExpiresAt, + IDToken: &authToken.IDToken.Token, + ExpiresIn: &expiresIn, User: user.AsAPIUser(), } + sessionstore.SetState(authToken.FingerPrintHash, authToken.FingerPrint+"@"+user.ID) + sessionstore.SetState(authToken.AccessToken.Token, authToken.FingerPrint+"@"+user.ID) + + if authToken.RefreshToken != nil { + res.RefreshToken = &authToken.RefreshToken.Token + sessionstore.SetState(authToken.AccessToken.Token, authToken.FingerPrint+"@"+user.ID) + } + + go utils.SaveSessionInDB(gc, user.ID) + return res, nil } diff --git a/server/resolvers/logout.go b/server/resolvers/logout.go index a6626e7..d2dfbc2 100644 --- a/server/resolvers/logout.go +++ b/server/resolvers/logout.go @@ -7,7 +7,6 @@ import ( "github.com/authorizerdev/authorizer/server/crypto" "github.com/authorizerdev/authorizer/server/graph/model" "github.com/authorizerdev/authorizer/server/sessionstore" - "github.com/authorizerdev/authorizer/server/token" "github.com/authorizerdev/authorizer/server/utils" ) @@ -19,34 +18,21 @@ func LogoutResolver(ctx context.Context) (*model.Response, error) { return res, err } - // get refresh token - refreshToken, err := token.GetRefreshToken(gc) - if err != nil { - return res, err - } - // get fingerprint hash - fingerprintHash, err := token.GetFingerPrint(gc) + fingerprintHash, err := cookie.GetSession(gc) if err != nil { return res, err } - decryptedFingerPrint, err := crypto.DecryptAES([]byte(fingerprintHash)) + decryptedFingerPrint, err := crypto.DecryptAES(fingerprintHash) if err != nil { return res, err } fingerPrint := string(decryptedFingerPrint) - // verify refresh token and fingerprint - claims, err := token.ParseJWTToken(refreshToken) - if err != nil { - return res, err - } - - userID := claims["id"].(string) - sessionstore.DeleteUserSession(userID, fingerPrint) - cookie.DeleteCookie(gc) + sessionstore.RemoveState(fingerPrint) + cookie.DeleteSession(gc) res = &model.Response{ Message: "Logged out successfully", diff --git a/server/resolvers/magic_link_login.go b/server/resolvers/magic_link_login.go index 5ed1801..a5d80ed 100644 --- a/server/resolvers/magic_link_login.go +++ b/server/resolvers/magic_link_login.go @@ -109,8 +109,12 @@ func MagicLinkLoginResolver(ctx context.Context, params model.MagicLinkLoginInpu hostname := utils.GetHost(gc) if !envstore.EnvStoreObj.GetBoolStoreEnvVariable(constants.EnvKeyDisableEmailVerification) { // insert verification request + nonce, nonceHash, err := utils.GenerateNonce() + if err != nil { + return res, err + } verificationType := constants.VerificationTypeMagicLinkLogin - verificationToken, err := token.CreateVerificationToken(params.Email, verificationType, hostname) + verificationToken, err := token.CreateVerificationToken(params.Email, verificationType, hostname, nonceHash) if err != nil { log.Println(`error generating token`, err) } @@ -119,12 +123,11 @@ func MagicLinkLoginResolver(ctx context.Context, params model.MagicLinkLoginInpu Identifier: verificationType, ExpiresAt: time.Now().Add(time.Minute * 30).Unix(), Email: params.Email, + Nonce: nonce, }) // exec it as go routin so that we can reduce the api latency - go func() { - email.SendVerificationMail(params.Email, verificationToken, hostname) - }() + go email.SendVerificationMail(params.Email, verificationToken, hostname) } res = &model.Response{ diff --git a/server/resolvers/profile.go b/server/resolvers/profile.go index 8b1e44c..882f250 100644 --- a/server/resolvers/profile.go +++ b/server/resolvers/profile.go @@ -2,7 +2,6 @@ package resolvers import ( "context" - "fmt" "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/graph/model" @@ -18,13 +17,17 @@ func ProfileResolver(ctx context.Context) (*model.User, error) { return res, err } - claims, err := token.ValidateAccessToken(gc) + accessToken, err := token.GetAccessToken(gc) if err != nil { return res, err } - userID := fmt.Sprintf("%v", claims["id"]) + claims, err := token.ValidateAccessToken(gc, accessToken) + if err != nil { + return res, err + } + userID := claims["sub"].(string) user, err := db.Provider.GetUserByID(userID) if err != nil { return res, err diff --git a/server/resolvers/resend_verify_email.go b/server/resolvers/resend_verify_email.go index 214668f..f514681 100644 --- a/server/resolvers/resend_verify_email.go +++ b/server/resolvers/resend_verify_email.go @@ -44,7 +44,11 @@ func ResendVerifyEmailResolver(ctx context.Context, params model.ResendVerifyEma } hostname := utils.GetHost(gc) - verificationToken, err := token.CreateVerificationToken(params.Email, params.Identifier, hostname) + nonce, nonceHash, err := utils.GenerateNonce() + if err != nil { + return res, err + } + verificationToken, err := token.CreateVerificationToken(params.Email, params.Identifier, hostname, nonceHash) if err != nil { log.Println(`error generating token`, err) } @@ -53,12 +57,11 @@ func ResendVerifyEmailResolver(ctx context.Context, params model.ResendVerifyEma Identifier: params.Identifier, ExpiresAt: time.Now().Add(time.Minute * 30).Unix(), Email: params.Email, + Nonce: nonce, }) // exec it as go routin so that we can reduce the api latency - go func() { - email.SendVerificationMail(params.Email, verificationToken, hostname) - }() + go email.SendVerificationMail(params.Email, verificationToken, hostname) res = &model.Response{ Message: `Verification email has been sent. Please check your inbox`, diff --git a/server/resolvers/reset_password.go b/server/resolvers/reset_password.go index 2eceb79..d94707f 100644 --- a/server/resolvers/reset_password.go +++ b/server/resolvers/reset_password.go @@ -12,11 +12,16 @@ import ( "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" "github.com/authorizerdev/authorizer/server/token" + "github.com/authorizerdev/authorizer/server/utils" ) // ResetPasswordResolver is a resolver for reset password mutation func ResetPasswordResolver(ctx context.Context, params model.ResetPasswordInput) (*model.Response, error) { var res *model.Response + gc, err := utils.GinContextFromContext(ctx) + if err != nil { + return res, err + } if envstore.EnvStoreObj.GetBoolStoreEnvVariable(constants.EnvKeyDisableBasicAuthentication) { return res, fmt.Errorf(`basic authentication is disabled for this instance`) } @@ -31,12 +36,17 @@ func ResetPasswordResolver(ctx context.Context, params model.ResetPasswordInput) } // verify if token exists in db - claim, err := token.ParseJWTToken(params.Token) + hostname := utils.GetHost(gc) + encryptedNonce, err := utils.EncryptNonce(verificationRequest.Nonce) + if err != nil { + return res, err + } + claim, err := token.ParseJWTToken(params.Token, hostname, encryptedNonce, verificationRequest.Email) if err != nil { return res, fmt.Errorf(`invalid token`) } - user, err := db.Provider.GetUserByEmail(claim["email"].(string)) + user, err := db.Provider.GetUserByEmail(claim["sub"].(string)) if err != nil { return res, err } diff --git a/server/resolvers/session.go b/server/resolvers/session.go index acd72c4..3d9c668 100644 --- a/server/resolvers/session.go +++ b/server/resolvers/session.go @@ -5,7 +5,6 @@ import ( "fmt" "github.com/authorizerdev/authorizer/server/cookie" - "github.com/authorizerdev/authorizer/server/crypto" "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/graph/model" "github.com/authorizerdev/authorizer/server/sessionstore" @@ -14,6 +13,7 @@ import ( ) // SessionResolver is a resolver for session query +// TODO allow validating with code and code verifier instead of cookie (PKCE flow) func SessionResolver(ctx context.Context, params *model.SessionQueryInput) (*model.AuthResponse, error) { var res *model.AuthResponse @@ -22,48 +22,27 @@ func SessionResolver(ctx context.Context, params *model.SessionQueryInput) (*mod return res, err } - // get refresh token - refreshToken, err := token.GetRefreshToken(gc) + sessionToken, err := cookie.GetSession(gc) if err != nil { return res, err } - // get fingerprint hash - fingerprintHash, err := token.GetFingerPrint(gc) + // get session from cookie + claims, err := token.ValidateBrowserSession(gc, sessionToken) if err != nil { return res, err } - - decryptedFingerPrint, err := crypto.DecryptAES([]byte(fingerprintHash)) - if err != nil { - return res, err - } - - fingerPrint := string(decryptedFingerPrint) - - // verify refresh token and fingerprint - claims, err := token.ParseJWTToken(refreshToken) - if err != nil { - return res, err - } - - userID := claims["id"].(string) - - persistedRefresh := sessionstore.GetUserSession(userID, fingerPrint) - if refreshToken != persistedRefresh { - return res, fmt.Errorf(`unauthorized`) - } - + userID := claims.Subject user, err := db.Provider.GetUserByID(userID) if err != nil { return res, err } // refresh token has "roles" as claim - claimRoleInterface := claims["roles"].([]interface{}) + claimRoleInterface := claims.Roles claimRoles := []string{} for _, v := range claimRoleInterface { - claimRoles = append(claimRoles, v.(string)) + claimRoles = append(claimRoles, v) } if params != nil && params.Roles != nil && len(params.Roles) > 0 { @@ -74,22 +53,35 @@ func SessionResolver(ctx context.Context, params *model.SessionQueryInput) (*mod } } - // delete older session - sessionstore.DeleteUserSession(userID, fingerPrint) + scope := []string{"openid", "email", "profile"} + if params != nil && params.Scope != nil && len(scope) > 0 { + scope = params.Scope + } - authToken, err := token.CreateAuthToken(user, claimRoles) + authToken, err := token.CreateAuthToken(gc, user, claimRoles, scope) if err != nil { return res, err } - sessionstore.SetUserSession(user.ID, authToken.FingerPrint, authToken.RefreshToken.Token) - cookie.SetCookie(gc, authToken.AccessToken.Token, authToken.RefreshToken.Token, authToken.FingerPrintHash) + // rollover the session for security + sessionstore.RemoveState(sessionToken) + sessionstore.SetState(authToken.FingerPrintHash, authToken.FingerPrint+"@"+user.ID) + sessionstore.SetState(authToken.AccessToken.Token, authToken.FingerPrint+"@"+user.ID) + cookie.SetSession(gc, authToken.FingerPrintHash) + + expiresIn := int64(1800) res = &model.AuthResponse{ Message: `Session token refreshed`, AccessToken: &authToken.AccessToken.Token, - ExpiresAt: &authToken.AccessToken.ExpiresAt, + ExpiresIn: &expiresIn, + IDToken: &authToken.IDToken.Token, User: user.AsAPIUser(), } + if authToken.RefreshToken != nil { + res.RefreshToken = &authToken.RefreshToken.Token + sessionstore.SetState(authToken.AccessToken.Token, authToken.FingerPrint+"@"+user.ID) + } + return res, nil } diff --git a/server/resolvers/signup.go b/server/resolvers/signup.go index 0ce3953..69e57c3 100644 --- a/server/resolvers/signup.go +++ b/server/resolvers/signup.go @@ -123,41 +123,48 @@ func SignupResolver(ctx context.Context, params model.SignUpInput) (*model.AuthR hostname := utils.GetHost(gc) if !envstore.EnvStoreObj.GetBoolStoreEnvVariable(constants.EnvKeyDisableEmailVerification) { // insert verification request - verificationType := constants.VerificationTypeBasicAuthSignup - verificationToken, err := token.CreateVerificationToken(params.Email, verificationType, hostname) + nonce, nonceHash, err := utils.GenerateNonce() if err != nil { - log.Println(`error generating token`, err) + return res, err + } + verificationType := constants.VerificationTypeBasicAuthSignup + verificationToken, err := token.CreateVerificationToken(params.Email, verificationType, hostname, nonceHash) + if err != nil { + return res, err } db.Provider.AddVerificationRequest(models.VerificationRequest{ Token: verificationToken, Identifier: verificationType, ExpiresAt: time.Now().Add(time.Minute * 30).Unix(), Email: params.Email, + Nonce: nonce, }) // exec it as go routin so that we can reduce the api latency - go func() { - email.SendVerificationMail(params.Email, verificationToken, hostname) - }() + go email.SendVerificationMail(params.Email, verificationToken, hostname) res = &model.AuthResponse{ Message: `Verification email has been sent. Please check your inbox`, User: userToReturn, } } else { + scope := []string{"openid", "email", "profile"} - authToken, err := token.CreateAuthToken(user, roles) + authToken, err := token.CreateAuthToken(gc, user, roles, scope) if err != nil { return res, err } - sessionstore.SetUserSession(user.ID, authToken.FingerPrint, authToken.RefreshToken.Token) - cookie.SetCookie(gc, authToken.AccessToken.Token, authToken.RefreshToken.Token, authToken.FingerPrintHash) - utils.SaveSessionInDB(user.ID, gc) + + sessionstore.SetState(authToken.FingerPrintHash, authToken.FingerPrint+"@"+user.ID) + cookie.SetSession(gc, authToken.FingerPrintHash) + go utils.SaveSessionInDB(gc, user.ID) + + expiresIn := int64(1800) res = &model.AuthResponse{ Message: `Signed up successfully.`, AccessToken: &authToken.AccessToken.Token, - ExpiresAt: &authToken.AccessToken.ExpiresAt, + ExpiresIn: &expiresIn, User: userToReturn, } } diff --git a/server/resolvers/update_profile.go b/server/resolvers/update_profile.go index 2fef3e3..cc43f4c 100644 --- a/server/resolvers/update_profile.go +++ b/server/resolvers/update_profile.go @@ -13,6 +13,7 @@ import ( "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/db/models" "github.com/authorizerdev/authorizer/server/email" + "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" "github.com/authorizerdev/authorizer/server/sessionstore" "github.com/authorizerdev/authorizer/server/token" @@ -28,7 +29,11 @@ func UpdateProfileResolver(ctx context.Context, params model.UpdateProfileInput) return res, err } - claims, err := token.ValidateAccessToken(gc) + accessToken, err := token.GetAccessToken(gc) + if err != nil { + return res, err + } + claims, err := token.ValidateAccessToken(gc, accessToken) if err != nil { return res, err } @@ -38,8 +43,8 @@ func UpdateProfileResolver(ctx context.Context, params model.UpdateProfileInput) return res, fmt.Errorf("please enter at least one param to update") } - userEmail := fmt.Sprintf("%v", claims["email"]) - user, err := db.Provider.GetUserByEmail(userEmail) + userID := claims["sub"].(string) + user, err := db.Provider.GetUserByID(userID) if err != nil { return res, err } @@ -108,38 +113,44 @@ func UpdateProfileResolver(ctx context.Context, params model.UpdateProfileInput) newEmail := strings.ToLower(*params.Email) // check if user with new email exists _, err := db.Provider.GetUserByEmail(newEmail) - // err = nil means user exists if err == nil { return res, fmt.Errorf("user with this email address already exists") } - sessionstore.DeleteAllUserSession(fmt.Sprintf("%v", user.ID)) - cookie.DeleteCookie(gc) + // TODO figure out how to delete all user sessions + go sessionstore.DeleteAllUserSession(user.ID) - hostname := utils.GetHost(gc) + cookie.DeleteSession(gc) user.Email = newEmail - user.EmailVerifiedAt = nil - hasEmailChanged = true - // insert verification request - verificationType := constants.VerificationTypeUpdateEmail - verificationToken, err := token.CreateVerificationToken(newEmail, verificationType, hostname) - if err != nil { - log.Println(`error generating token`, err) + + if !envstore.EnvStoreObj.GetBoolStoreEnvVariable(constants.EnvKeyDisableEmailVerification) { + hostname := utils.GetHost(gc) + user.EmailVerifiedAt = nil + hasEmailChanged = true + // insert verification request + nonce, nonceHash, err := utils.GenerateNonce() + if err != nil { + return res, err + } + verificationType := constants.VerificationTypeUpdateEmail + verificationToken, err := token.CreateVerificationToken(newEmail, verificationType, hostname, nonceHash) + if err != nil { + log.Println(`error generating token`, err) + } + db.Provider.AddVerificationRequest(models.VerificationRequest{ + Token: verificationToken, + Identifier: verificationType, + ExpiresAt: time.Now().Add(time.Minute * 30).Unix(), + Email: newEmail, + Nonce: nonce, + }) + + // exec it as go routin so that we can reduce the api latency + go email.SendVerificationMail(newEmail, verificationToken, hostname) + } - db.Provider.AddVerificationRequest(models.VerificationRequest{ - Token: verificationToken, - Identifier: verificationType, - ExpiresAt: time.Now().Add(time.Minute * 30).Unix(), - Email: newEmail, - }) - - // exec it as go routin so that we can reduce the api latency - go func() { - email.SendVerificationMail(newEmail, verificationToken, hostname) - }() } - _, err = db.Provider.UpdateUser(user) if err != nil { log.Println("error updating user:", err) diff --git a/server/resolvers/update_user.go b/server/resolvers/update_user.go index 40b4206..948565e 100644 --- a/server/resolvers/update_user.go +++ b/server/resolvers/update_user.go @@ -8,7 +8,6 @@ import ( "time" "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/cookie" "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/db/models" "github.com/authorizerdev/authorizer/server/email" @@ -95,15 +94,19 @@ func UpdateUserResolver(ctx context.Context, params model.UpdateUserInput) (*mod return res, fmt.Errorf("user with this email address already exists") } - sessionstore.DeleteAllUserSession(fmt.Sprintf("%v", user.ID)) - cookie.DeleteCookie(gc) + // TODO figure out how to do this + go sessionstore.DeleteAllUserSession(user.ID) hostname := utils.GetHost(gc) user.Email = newEmail user.EmailVerifiedAt = nil // insert verification request + nonce, nonceHash, err := utils.GenerateNonce() + if err != nil { + return res, err + } verificationType := constants.VerificationTypeUpdateEmail - verificationToken, err := token.CreateVerificationToken(newEmail, verificationType, hostname) + verificationToken, err := token.CreateVerificationToken(newEmail, verificationType, hostname, nonceHash) if err != nil { log.Println(`error generating token`, err) } @@ -112,12 +115,12 @@ func UpdateUserResolver(ctx context.Context, params model.UpdateUserInput) (*mod Identifier: verificationType, ExpiresAt: time.Now().Add(time.Minute * 30).Unix(), Email: newEmail, + Nonce: nonce, }) // exec it as go routin so that we can reduce the api latency - go func() { - email.SendVerificationMail(newEmail, verificationToken, hostname) - }() + go email.SendVerificationMail(newEmail, verificationToken, hostname) + } rolesToSave := "" @@ -136,8 +139,7 @@ func UpdateUserResolver(ctx context.Context, params model.UpdateUserInput) (*mod rolesToSave = strings.Join(inputRoles, ",") } - sessionstore.DeleteAllUserSession(fmt.Sprintf("%v", user.ID)) - cookie.DeleteCookie(gc) + go sessionstore.DeleteAllUserSession(user.ID) } if rolesToSave != "" { diff --git a/server/resolvers/verify_email.go b/server/resolvers/verify_email.go index c244992..953db47 100644 --- a/server/resolvers/verify_email.go +++ b/server/resolvers/verify_email.go @@ -28,12 +28,17 @@ func VerifyEmailResolver(ctx context.Context, params model.VerifyEmailInput) (*m } // verify if token exists in db - claim, err := token.ParseJWTToken(params.Token) + hostname := utils.GetHost(gc) + encryptedNonce, err := utils.EncryptNonce(verificationRequest.Nonce) + if err != nil { + return res, err + } + claim, err := token.ParseJWTToken(params.Token, hostname, encryptedNonce, verificationRequest.Email) if err != nil { return res, fmt.Errorf(`invalid token: %s`, err.Error()) } - user, err := db.Provider.GetUserByEmail(claim["email"].(string)) + user, err := db.Provider.GetUserByEmail(claim["sub"].(string)) if err != nil { return res, err } @@ -41,25 +46,35 @@ func VerifyEmailResolver(ctx context.Context, params model.VerifyEmailInput) (*m // update email_verified_at in users table now := time.Now().Unix() user.EmailVerifiedAt = &now - db.Provider.UpdateUser(user) - // delete from verification table - db.Provider.DeleteVerificationRequest(verificationRequest) - - roles := strings.Split(user.Roles, ",") - authToken, err := token.CreateAuthToken(user, roles) + user, err = db.Provider.UpdateUser(user) + if err != nil { + return res, err + } + // delete from verification table + err = db.Provider.DeleteVerificationRequest(verificationRequest) if err != nil { return res, err } - sessionstore.SetUserSession(user.ID, authToken.FingerPrint, authToken.RefreshToken.Token) - cookie.SetCookie(gc, authToken.AccessToken.Token, authToken.RefreshToken.Token, authToken.FingerPrintHash) - utils.SaveSessionInDB(user.ID, gc) + roles := strings.Split(user.Roles, ",") + scope := []string{"openid", "email", "profile"} + authToken, err := token.CreateAuthToken(gc, user, roles, scope) + if err != nil { + return res, err + } + + sessionstore.SetState(authToken.FingerPrintHash, authToken.FingerPrint+"@"+user.ID) + sessionstore.SetState(authToken.AccessToken.Token, authToken.FingerPrint+"@"+user.ID) + cookie.SetSession(gc, authToken.FingerPrintHash) + go utils.SaveSessionInDB(gc, user.ID) + + expiresIn := int64(1800) res = &model.AuthResponse{ Message: `Email verified successfully.`, AccessToken: &authToken.AccessToken.Token, - ExpiresAt: &authToken.AccessToken.ExpiresAt, + IDToken: &authToken.IDToken.Token, + ExpiresIn: &expiresIn, User: user.AsAPIUser(), } - return res, nil } diff --git a/server/sessionstore/in_memory_session.go b/server/sessionstore/in_memory_session.go index 567c881..edc6924 100644 --- a/server/sessionstore/in_memory_session.go +++ b/server/sessionstore/in_memory_session.go @@ -1,6 +1,7 @@ package sessionstore import ( + "strings" "sync" ) @@ -11,42 +12,6 @@ type InMemoryStore struct { stateStore map[string]string } -// AddUserSession adds a user session to the in-memory store. -func (c *InMemoryStore) AddUserSession(userId, accessToken, refreshToken string) { - c.mutex.Lock() - defer c.mutex.Unlock() - // delete sessions > 500 // not recommended for production - if len(c.sessionStore) >= 500 { - c.sessionStore = map[string]map[string]string{} - } - // check if entry exists in map - _, exists := c.sessionStore[userId] - if exists { - tempMap := c.sessionStore[userId] - tempMap[accessToken] = refreshToken - c.sessionStore[userId] = tempMap - } else { - tempMap := map[string]string{ - accessToken: refreshToken, - } - c.sessionStore[userId] = tempMap - } -} - -// DeleteAllUserSession deletes all the user sessions from in-memory store. -func (c *InMemoryStore) DeleteAllUserSession(userId string) { - c.mutex.Lock() - defer c.mutex.Unlock() - delete(c.sessionStore, userId) -} - -// DeleteUserSession deletes the particular user session from in-memory store. -func (c *InMemoryStore) DeleteUserSession(userId, accessToken string) { - c.mutex.Lock() - defer c.mutex.Unlock() - delete(c.sessionStore[userId], accessToken) -} - // ClearStore clears the in-memory store. func (c *InMemoryStore) ClearStore() { c.mutex.Lock() @@ -54,32 +19,29 @@ func (c *InMemoryStore) ClearStore() { c.sessionStore = map[string]map[string]string{} } -// GetUserSession returns the user session token from the in-memory store. -func (c *InMemoryStore) GetUserSession(userId, accessToken string) string { - // c.mutex.Lock() - // defer c.mutex.Unlock() - - token := "" - if sessionMap, ok := c.sessionStore[userId]; ok { - if val, ok := sessionMap[accessToken]; ok { - token = val - } - } - - return token -} - // GetUserSessions returns all the user session token from the in-memory store. func (c *InMemoryStore) GetUserSessions(userId string) map[string]string { // c.mutex.Lock() // defer c.mutex.Unlock() - - sessionMap, ok := c.sessionStore[userId] - if !ok { - return nil + res := map[string]string{} + for k, v := range c.stateStore { + split := strings.Split(v, "@") + if split[1] == userId { + res[k] = split[0] + } } - return sessionMap + return res +} + +// DeleteAllUserSession deletes all the user sessions from in-memory store. +func (c *InMemoryStore) DeleteAllUserSession(userId string) { + // c.mutex.Lock() + // defer c.mutex.Unlock() + sessions := GetUserSessions(userId) + for k := range sessions { + RemoveState(k) + } } // SetState sets the state in the in-memory store. diff --git a/server/sessionstore/redis_store.go b/server/sessionstore/redis_store.go index 98a1711..7e48335 100644 --- a/server/sessionstore/redis_store.go +++ b/server/sessionstore/redis_store.go @@ -2,8 +2,8 @@ package sessionstore import ( "context" - "fmt" "log" + "strings" ) type RedisStore struct { @@ -11,32 +11,6 @@ type RedisStore struct { store RedisSessionClient } -// AddUserSession adds the user session to redis -func (c *RedisStore) AddUserSession(userId, accessToken, refreshToken string) { - err := c.store.HMSet(c.ctx, "authorizer_"+userId, map[string]string{ - accessToken: refreshToken, - }).Err() - if err != nil { - log.Fatalln("Error saving redis token:", err) - } -} - -// DeleteAllUserSession deletes all the user session from redis -func (c *RedisStore) DeleteAllUserSession(userId string) { - err := c.store.Del(c.ctx, "authorizer_"+userId).Err() - if err != nil { - log.Fatalln("Error deleting redis token:", err) - } -} - -// DeleteUserSession deletes the particular user session from redis -func (c *RedisStore) DeleteUserSession(userId, accessToken string) { - err := c.store.HDel(c.ctx, "authorizer_"+userId, accessToken).Err() - if err != nil { - log.Fatalln("Error deleting redis token:", err) - } -} - // ClearStore clears the redis store for authorizer related tokens func (c *RedisStore) ClearStore() { err := c.store.Del(c.ctx, "authorizer_*").Err() @@ -45,32 +19,40 @@ func (c *RedisStore) ClearStore() { } } -// GetUserSession returns the user session token from the redis store. -func (c *RedisStore) GetUserSession(userId, accessToken string) string { - token := "" - res, err := c.store.HMGet(c.ctx, "authorizer_"+userId, accessToken).Result() - if err != nil { - log.Println("error getting token from redis store:", err) - } - if len(res) > 0 && res[0] != nil { - token = fmt.Sprintf("%v", res[0]) - } - return token -} - // GetUserSessions returns all the user session token from the redis store. func (c *RedisStore) GetUserSessions(userID string) map[string]string { - res, err := c.store.HGetAll(c.ctx, "authorizer_"+userID).Result() + data, err := c.store.HGetAll(c.ctx, "*").Result() if err != nil { log.Println("error getting token from redis store:", err) } + res := map[string]string{} + for k, v := range data { + split := strings.Split(v, "@") + if split[1] == userID { + res[k] = split[0] + } + } + return res } +// DeleteAllUserSession deletes all the user session from redis +func (c *RedisStore) DeleteAllUserSession(userId string) { + sessions := GetUserSessions(userId) + for k, v := range sessions { + if k == "token" { + err := c.store.Del(c.ctx, v) + if err != nil { + log.Println("Error deleting redis token:", err) + } + } + } +} + // SetState sets the state in redis store. -func (c *RedisStore) SetState(key, state string) { - err := c.store.Set(c.ctx, key, state, 0).Err() +func (c *RedisStore) SetState(key, value string) { + err := c.store.Set(c.ctx, key, value, 0).Err() if err != nil { log.Fatalln("Error saving redis token:", err) } diff --git a/server/sessionstore/session.go b/server/sessionstore/session.go index f33638c..659ddaa 100644 --- a/server/sessionstore/session.go +++ b/server/sessionstore/session.go @@ -22,26 +22,6 @@ type SessionStore struct { // reference to various session store instances var SessionStoreObj SessionStore -// SetUserSession sets the user session in the session store -func SetUserSession(userId, fingerprint, refreshToken string) { - if SessionStoreObj.RedisMemoryStoreObj != nil { - SessionStoreObj.RedisMemoryStoreObj.AddUserSession(userId, fingerprint, refreshToken) - } - if SessionStoreObj.InMemoryStoreObj != nil { - SessionStoreObj.InMemoryStoreObj.AddUserSession(userId, fingerprint, refreshToken) - } -} - -// DeleteUserSession deletes the particular user session from the session store -func DeleteUserSession(userId, fingerprint string) { - if SessionStoreObj.RedisMemoryStoreObj != nil { - SessionStoreObj.RedisMemoryStoreObj.DeleteUserSession(userId, fingerprint) - } - if SessionStoreObj.InMemoryStoreObj != nil { - SessionStoreObj.InMemoryStoreObj.DeleteUserSession(userId, fingerprint) - } -} - // DeleteAllSessions deletes all the sessions from the session store func DeleteAllUserSession(userId string) { if SessionStoreObj.RedisMemoryStoreObj != nil { @@ -52,18 +32,6 @@ func DeleteAllUserSession(userId string) { } } -// GetUserSession returns the user session from the session store -func GetUserSession(userId, fingerprint string) string { - if SessionStoreObj.RedisMemoryStoreObj != nil { - return SessionStoreObj.RedisMemoryStoreObj.GetUserSession(userId, fingerprint) - } - if SessionStoreObj.InMemoryStoreObj != nil { - return SessionStoreObj.InMemoryStoreObj.GetUserSession(userId, fingerprint) - } - - return "" -} - // GetUserSessions returns all the user sessions from the session store func GetUserSessions(userId string) map[string]string { if SessionStoreObj.RedisMemoryStoreObj != nil { diff --git a/server/test/is_valid_jwt_test.go b/server/test/is_valid_jwt_test.go deleted file mode 100644 index 12abbc1..0000000 --- a/server/test/is_valid_jwt_test.go +++ /dev/null @@ -1,38 +0,0 @@ -package test - -import ( - "testing" - - "github.com/authorizerdev/authorizer/server/db/models" - "github.com/authorizerdev/authorizer/server/graph/model" - "github.com/authorizerdev/authorizer/server/resolvers" - "github.com/authorizerdev/authorizer/server/token" - "github.com/google/uuid" - "github.com/stretchr/testify/assert" -) - -func isValidJWTTests(t *testing.T, s TestSetup) { - t.Helper() - _, ctx := createContext(s) - expiredToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhbGxvd2VkX3JvbGVzIjpbIiJdLCJiaXJ0aGRhdGUiOm51bGwsImNyZWF0ZWRfYXQiOjAsImVtYWlsIjoiam9obi5kb2VAZ21haWwuY29tIiwiZW1haWxfdmVyaWZpZWQiOmZhbHNlLCJleHAiOjE2NDI5NjEwMTEsImV4dHJhIjp7IngtZXh0cmEtaWQiOiJkMmNhMjQwNy05MzZmLTQwYzQtOTQ2NS05Y2M5MWYxZTJhNDQifSwiZmFtaWx5X25hbWUiOm51bGwsImdlbmRlciI6bnVsbCwiZ2l2ZW5fbmFtZSI6bnVsbCwiaWF0IjoxNjQyOTYwOTgxLCJpZCI6ImQyY2EyNDA3LTkzNmYtNDBjNC05NDY1LTljYzkxZjFlMmE0NCIsIm1pZGRsZV9uYW1lIjpudWxsLCJuaWNrbmFtZSI6bnVsbCwicGhvbmVfbnVtYmVyIjpudWxsLCJwaG9uZV9udW1iZXJfdmVyaWZpZWQiOmZhbHNlLCJwaWN0dXJlIjpudWxsLCJwcmVmZXJyZWRfdXNlcm5hbWUiOiJqb2huLmRvZUBnbWFpbC5jb20iLCJyb2xlIjpbXSwic2lnbnVwX21ldGhvZHMiOiIiLCJ0b2tlbl90eXBlIjoiYWNjZXNzX3Rva2VuIiwidXBkYXRlZF9hdCI6MH0.FrdyeOC5e8uU1SowGj0omFJuwRnh4BrEk89S_fbEkzs" - - t.Run(`should fail for invalid jwt`, func(t *testing.T) { - _, err := resolvers.IsValidJwtResolver(ctx, &model.IsValidJWTQueryInput{ - Jwt: &expiredToken, - }) - assert.NotNil(t, err) - }) - - t.Run(`should pass with valid jwt`, func(t *testing.T) { - authToken, err := token.CreateAuthToken(models.User{ - ID: uuid.New().String(), - Email: "john.doe@gmail.com", - }, []string{}) - assert.Nil(t, err) - res, err := resolvers.IsValidJwtResolver(ctx, &model.IsValidJWTQueryInput{ - Jwt: &authToken.AccessToken.Token, - }) - assert.Nil(t, err) - assert.True(t, res.Valid) - }) -} diff --git a/server/test/jwt_test.go b/server/test/jwt_test.go index a72cedf..71f74be 100644 --- a/server/test/jwt_test.go +++ b/server/test/jwt_test.go @@ -9,6 +9,7 @@ import ( "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/token" "github.com/golang-jwt/jwt" + "github.com/google/uuid" "github.com/stretchr/testify/assert" ) @@ -18,12 +19,17 @@ func TestJwt(t *testing.T) { publicKey := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtPublicKey) privateKey := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtPrivateKey) clientID := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyClientID) + nonce := uuid.New().String() + hostname := "localhost" + subject := "test" claims := jwt.MapClaims{ "exp": time.Now().Add(time.Minute * 30).Unix(), "iat": time.Now().Unix(), "email": "test@yopmail.com", - "sub": "test", + "sub": subject, "aud": clientID, + "nonce": nonce, + "iss": hostname, } t.Run("invalid jwt type", func(t *testing.T) { @@ -42,7 +48,7 @@ func TestJwt(t *testing.T) { } jwtToken, err := token.SignJWTToken(expiredClaims) assert.NoError(t, err) - _, err = token.ParseJWTToken(jwtToken) + _, err = token.ParseJWTToken(jwtToken, hostname, nonce, subject) assert.Error(t, err, err.Error(), "Token is expired") }) t.Run("HMAC algorithms", func(t *testing.T) { @@ -52,7 +58,7 @@ func TestJwt(t *testing.T) { jwtToken, err := token.SignJWTToken(claims) assert.NoError(t, err) assert.NotEmpty(t, jwtToken) - c, err := token.ParseJWTToken(jwtToken) + c, err := token.ParseJWTToken(jwtToken, hostname, nonce, subject) assert.NoError(t, err) assert.Equal(t, c["email"].(string), claims["email"]) }) @@ -61,7 +67,7 @@ func TestJwt(t *testing.T) { jwtToken, err := token.SignJWTToken(claims) assert.NoError(t, err) assert.NotEmpty(t, jwtToken) - c, err := token.ParseJWTToken(jwtToken) + c, err := token.ParseJWTToken(jwtToken, hostname, nonce, subject) assert.NoError(t, err) assert.Equal(t, c["email"].(string), claims["email"]) }) @@ -70,7 +76,7 @@ func TestJwt(t *testing.T) { jwtToken, err := token.SignJWTToken(claims) assert.NoError(t, err) assert.NotEmpty(t, jwtToken) - c, err := token.ParseJWTToken(jwtToken) + c, err := token.ParseJWTToken(jwtToken, hostname, nonce, subject) assert.NoError(t, err) assert.Equal(t, c["email"].(string), claims["email"]) }) @@ -86,7 +92,7 @@ func TestJwt(t *testing.T) { jwtToken, err := token.SignJWTToken(claims) assert.NoError(t, err) assert.NotEmpty(t, jwtToken) - c, err := token.ParseJWTToken(jwtToken) + c, err := token.ParseJWTToken(jwtToken, hostname, nonce, subject) assert.NoError(t, err) assert.Equal(t, c["email"].(string), claims["email"]) }) @@ -99,7 +105,7 @@ func TestJwt(t *testing.T) { jwtToken, err := token.SignJWTToken(claims) assert.NoError(t, err) assert.NotEmpty(t, jwtToken) - c, err := token.ParseJWTToken(jwtToken) + c, err := token.ParseJWTToken(jwtToken, hostname, nonce, subject) assert.NoError(t, err) assert.Equal(t, c["email"].(string), claims["email"]) }) @@ -112,7 +118,7 @@ func TestJwt(t *testing.T) { jwtToken, err := token.SignJWTToken(claims) assert.NoError(t, err) assert.NotEmpty(t, jwtToken) - c, err := token.ParseJWTToken(jwtToken) + c, err := token.ParseJWTToken(jwtToken, hostname, nonce, subject) assert.NoError(t, err) assert.Equal(t, c["email"].(string), claims["email"]) }) @@ -128,7 +134,7 @@ func TestJwt(t *testing.T) { jwtToken, err := token.SignJWTToken(claims) assert.NoError(t, err) assert.NotEmpty(t, jwtToken) - c, err := token.ParseJWTToken(jwtToken) + c, err := token.ParseJWTToken(jwtToken, hostname, nonce, subject) assert.NoError(t, err) assert.Equal(t, c["email"].(string), claims["email"]) }) @@ -141,7 +147,7 @@ func TestJwt(t *testing.T) { jwtToken, err := token.SignJWTToken(claims) assert.NoError(t, err) assert.NotEmpty(t, jwtToken) - c, err := token.ParseJWTToken(jwtToken) + c, err := token.ParseJWTToken(jwtToken, hostname, nonce, subject) assert.NoError(t, err) assert.Equal(t, c["email"].(string), claims["email"]) }) @@ -154,7 +160,7 @@ func TestJwt(t *testing.T) { jwtToken, err := token.SignJWTToken(claims) assert.NoError(t, err) assert.NotEmpty(t, jwtToken) - c, err := token.ParseJWTToken(jwtToken) + c, err := token.ParseJWTToken(jwtToken, hostname, nonce, subject) assert.NoError(t, err) assert.Equal(t, c["email"].(string), claims["email"]) }) diff --git a/server/test/login_test.go b/server/test/login_test.go index 4c25f5a..ebfbf68 100644 --- a/server/test/login_test.go +++ b/server/test/login_test.go @@ -5,14 +5,17 @@ import ( "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db" + "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" "github.com/authorizerdev/authorizer/server/resolvers" + "github.com/authorizerdev/authorizer/server/utils" "github.com/stretchr/testify/assert" ) func loginTests(t *testing.T, s TestSetup) { t.Helper() t.Run(`should login`, func(t *testing.T) { + t.Logf("=> is enabled: %v", envstore.EnvStoreObj.GetBoolStoreEnvVariable(constants.EnvKeyDisableEmailVerification)) _, ctx := createContext(s) email := "login." + s.TestInfo.Email _, err := resolvers.SignupResolver(ctx, model.SignUpInput{ @@ -21,15 +24,19 @@ func loginTests(t *testing.T, s TestSetup) { ConfirmPassword: s.TestInfo.Password, }) - _, err = resolvers.LoginResolver(ctx, model.LoginInput{ + res, err := resolvers.LoginResolver(ctx, model.LoginInput{ Email: email, Password: s.TestInfo.Password, }) assert.NotNil(t, err, "should fail because email is not verified") - + assert.Nil(t, res) verificationRequest, err := db.Provider.GetVerificationRequestByEmail(email, constants.VerificationTypeBasicAuthSignup) - res, err := resolvers.VerifyEmailResolver(ctx, model.VerifyEmailInput{ + n, err := utils.EncryptNonce(verificationRequest.Nonce) + assert.NoError(t, err) + assert.NotEmpty(t, n) + assert.NotNil(t, verificationRequest) + res, err = resolvers.VerifyEmailResolver(ctx, model.VerifyEmailInput{ Token: verificationRequest.Token, }) assert.NoError(t, err) diff --git a/server/test/logout_test.go b/server/test/logout_test.go index db5688a..8956b31 100644 --- a/server/test/logout_test.go +++ b/server/test/logout_test.go @@ -2,11 +2,9 @@ package test import ( "fmt" - "net/url" "testing" "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/crypto" "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" @@ -30,18 +28,15 @@ func logoutTests(t *testing.T, s TestSetup) { Token: verificationRequest.Token, }) - sessions := sessionstore.GetUserSessions(verifyRes.User.ID) - fingerPrint := "" - refreshToken := "" - for key, val := range sessions { - fingerPrint = key - refreshToken = val - } - - fingerPrintHash, _ := crypto.EncryptAES([]byte(fingerPrint)) - token := *verifyRes.AccessToken - cookie := fmt.Sprintf("%s=%s;%s=%s;%s=%s", envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName)+".fingerprint", url.QueryEscape(string(fingerPrintHash)), envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName)+".refresh_token", refreshToken, envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName)+".access_token", token) + sessions := sessionstore.GetUserSessions(verifyRes.User.ID) + cookie := "" + // set all they keys in cookie one of them should be session cookie + for key := range sessions { + if key != token { + cookie += fmt.Sprintf("%s=%s;", envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName)+"_session", key) + } + } req.Header.Set("Cookie", cookie) _, err = resolvers.LogoutResolver(ctx) diff --git a/server/test/magic_link_login_test.go b/server/test/magic_link_login_test.go index b42378f..95e0712 100644 --- a/server/test/magic_link_login_test.go +++ b/server/test/magic_link_login_test.go @@ -1,12 +1,11 @@ package test import ( - "fmt" + "context" "testing" "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" "github.com/authorizerdev/authorizer/server/resolvers" "github.com/stretchr/testify/assert" @@ -27,12 +26,13 @@ func magicLinkLoginTests(t *testing.T, s TestSetup) { verifyRes, err := resolvers.VerifyEmailResolver(ctx, model.VerifyEmailInput{ Token: verificationRequest.Token, }) - - token := *verifyRes.AccessToken - req.Header.Set("Cookie", fmt.Sprintf("%s=%s", envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName)+".access_token", token)) + assert.NoError(t, err) + assert.NotNil(t, verifyRes.AccessToken) + s.GinContext.Request.Header.Set("Authorization", "Bearer "+*verifyRes.AccessToken) + ctx = context.WithValue(req.Context(), "GinContextKey", s.GinContext) _, err = resolvers.ProfileResolver(ctx) assert.Nil(t, err) - + s.GinContext.Request.Header.Set("Authorization", "") cleanData(email) }) } diff --git a/server/test/profile_test.go b/server/test/profile_test.go index 4801afc..8f5a283 100644 --- a/server/test/profile_test.go +++ b/server/test/profile_test.go @@ -1,12 +1,11 @@ package test import ( - "fmt" + "context" "testing" "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" "github.com/authorizerdev/authorizer/server/resolvers" "github.com/stretchr/testify/assert" @@ -14,7 +13,7 @@ import ( func profileTests(t *testing.T, s TestSetup) { t.Helper() - t.Run(`should get profile only with token`, func(t *testing.T) { + t.Run(`should get profile only access_token token`, func(t *testing.T) { req, ctx := createContext(s) email := "profile." + s.TestInfo.Email @@ -31,11 +30,14 @@ func profileTests(t *testing.T, s TestSetup) { verifyRes, err := resolvers.VerifyEmailResolver(ctx, model.VerifyEmailInput{ Token: verificationRequest.Token, }) + assert.NoError(t, err) + assert.NotNil(t, verifyRes.AccessToken) - token := *verifyRes.AccessToken - req.Header.Set("Cookie", fmt.Sprintf("%s=%s", envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName)+".access_token", token)) + s.GinContext.Request.Header.Set("Authorization", "Bearer "+*verifyRes.AccessToken) + ctx = context.WithValue(req.Context(), "GinContextKey", s.GinContext) profileRes, err := resolvers.ProfileResolver(ctx) assert.Nil(t, err) + s.GinContext.Request.Header.Set("Authorization", "") newEmail := *&profileRes.Email assert.Equal(t, email, newEmail, "emails should be equal") diff --git a/server/test/resolvers_test.go b/server/test/resolvers_test.go index a39f3b5..f47e034 100644 --- a/server/test/resolvers_test.go +++ b/server/test/resolvers_test.go @@ -15,15 +15,16 @@ func TestResolvers(t *testing.T) { // constants.DbTypeArangodb: "http://localhost:8529", // constants.DbTypeMongodb: "mongodb://localhost:27017", } - envstore.EnvStoreObj.ResetStore() envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyVersion, "test") for dbType, dbURL := range databases { + s := testSetup() envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyDatabaseURL, dbURL) envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyDatabaseType, dbType) - - s := testSetup() defer s.Server.Close() - db.InitDB() + err := db.InitDB() + if err != nil { + t.Errorf("Error initializing database: %s", err) + } // clean the persisted config for test to use fresh config envData, err := db.Provider.GetEnv() @@ -31,12 +32,10 @@ func TestResolvers(t *testing.T) { envData.EnvData = "" db.Provider.UpdateEnv(envData) } - err = env.InitAllEnv() - if err != nil { - t.Error(err) - } env.PersistEnv() + envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyEnv, "test") + envstore.EnvStoreObj.UpdateEnvVariable(constants.BoolStoreIdentifier, constants.EnvKeyIsProd, false) t.Run("should pass tests for "+dbType, func(t *testing.T) { // admin tests adminSignupTests(t, s) @@ -63,7 +62,6 @@ func TestResolvers(t *testing.T) { magicLinkLoginTests(t, s) logoutTests(t, s) metaTests(t, s) - isValidJWTTests(t, s) }) } } diff --git a/server/test/session_test.go b/server/test/session_test.go index 99b1295..65ce57e 100644 --- a/server/test/session_test.go +++ b/server/test/session_test.go @@ -2,11 +2,10 @@ package test import ( "fmt" - "net/url" + "strings" "testing" "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/crypto" "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" @@ -36,17 +35,15 @@ func sessionTests(t *testing.T, s TestSetup) { }) sessions := sessionstore.GetUserSessions(verifyRes.User.ID) - fingerPrint := "" - refreshToken := "" - for key, val := range sessions { - fingerPrint = key - refreshToken = val - } - - fingerPrintHash, _ := crypto.EncryptAES([]byte(fingerPrint)) - + cookie := "" token := *verifyRes.AccessToken - cookie := fmt.Sprintf("%s=%s;%s=%s;%s=%s", envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName)+".fingerprint", url.QueryEscape(string(fingerPrintHash)), envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName)+".refresh_token", refreshToken, envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName)+".access_token", token) + // set all they keys in cookie one of them should be session cookie + for key := range sessions { + if key != token { + cookie += fmt.Sprintf("%s=%s;", envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName)+"_session", key) + } + } + cookie = strings.TrimSuffix(cookie, ";") req.Header.Set("Cookie", cookie) diff --git a/server/test/test.go b/server/test/test.go index e99b7cd..6c39cc6 100644 --- a/server/test/test.go +++ b/server/test/test.go @@ -72,13 +72,13 @@ func testSetup() TestSetup { } envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyEnvPath, "../../.env.sample") + env.InitRequiredEnv() envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeySmtpHost, "smtp.yopmail.com") envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeySmtpPort, "2525") envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeySmtpUsername, "lakhan@yopmail.com") envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeySmtpPassword, "test") envstore.EnvStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeySenderEmail, "info@yopmail.com") envstore.EnvStoreObj.UpdateEnvVariable(constants.SliceStoreIdentifier, constants.EnvKeyProtectedRoles, []string{"admin"}) - env.InitRequiredEnv() db.InitDB() env.InitAllEnv() sessionstore.InitSession() diff --git a/server/test/update_profile_test.go b/server/test/update_profile_test.go index 04d1785..68dae47 100644 --- a/server/test/update_profile_test.go +++ b/server/test/update_profile_test.go @@ -1,12 +1,11 @@ package test import ( - "fmt" + "context" "testing" "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db" - "github.com/authorizerdev/authorizer/server/envstore" "github.com/authorizerdev/authorizer/server/graph/model" "github.com/authorizerdev/authorizer/server/resolvers" "github.com/stretchr/testify/assert" @@ -34,18 +33,16 @@ func updateProfileTests(t *testing.T, s TestSetup) { verifyRes, err := resolvers.VerifyEmailResolver(ctx, model.VerifyEmailInput{ Token: verificationRequest.Token, }) + assert.NoError(t, err) - token := *verifyRes.AccessToken - req.Header.Set("Cookie", fmt.Sprintf("%s=%s", envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyCookieName)+".access_token", token)) - _, err = resolvers.UpdateProfileResolver(ctx, model.UpdateProfileInput{ - FamilyName: &fName, - }) - assert.Nil(t, err) + s.GinContext.Request.Header.Set("Authorization", "Bearer "+*verifyRes.AccessToken) + ctx = context.WithValue(req.Context(), "GinContextKey", s.GinContext) newEmail := "new_" + email _, err = resolvers.UpdateProfileResolver(ctx, model.UpdateProfileInput{ Email: &newEmail, }) + s.GinContext.Request.Header.Set("Authorization", "") assert.Nil(t, err) _, err = resolvers.ProfileResolver(ctx) assert.NotNil(t, err, "unauthorized") diff --git a/server/test/verification_requests_test.go b/server/test/verification_requests_test.go index 58891a0..b81a35f 100644 --- a/server/test/verification_requests_test.go +++ b/server/test/verification_requests_test.go @@ -19,12 +19,15 @@ func verificationRequestsTest(t *testing.T, s TestSetup) { req, ctx := createContext(s) email := "verification_requests." + s.TestInfo.Email - resolvers.SignupResolver(ctx, model.SignUpInput{ + res, err := resolvers.SignupResolver(ctx, model.SignUpInput{ Email: email, Password: s.TestInfo.Password, ConfirmPassword: s.TestInfo.Password, }) + assert.NoError(t, err) + assert.NotNil(t, res) + limit := int64(10) page := int64(1) pagination := &model.PaginatedInput{ diff --git a/server/token/auth_token.go b/server/token/auth_token.go index b9eb25d..16abbd6 100644 --- a/server/token/auth_token.go +++ b/server/token/auth_token.go @@ -2,23 +2,23 @@ package token import ( "encoding/json" - "errors" "fmt" "log" "os" "strings" "time" - "github.com/authorizerdev/authorizer/server/constants" - "github.com/authorizerdev/authorizer/server/cookie" - "github.com/authorizerdev/authorizer/server/crypto" - "github.com/authorizerdev/authorizer/server/db/models" - "github.com/authorizerdev/authorizer/server/envstore" - "github.com/authorizerdev/authorizer/server/sessionstore" "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt" "github.com/google/uuid" "github.com/robertkrimen/otto" + + "github.com/authorizerdev/authorizer/server/constants" + "github.com/authorizerdev/authorizer/server/crypto" + "github.com/authorizerdev/authorizer/server/db/models" + "github.com/authorizerdev/authorizer/server/envstore" + "github.com/authorizerdev/authorizer/server/sessionstore" + "github.com/authorizerdev/authorizer/server/utils" ) // JWTToken is a struct to hold JWT token and its expiration time @@ -33,60 +33,219 @@ type Token struct { FingerPrintHash string `json:"fingerprint_hash"` RefreshToken *JWTToken `json:"refresh_token"` AccessToken *JWTToken `json:"access_token"` + IDToken *JWTToken `json:"id_token"` +} + +// SessionData +type SessionData struct { + Subject string `json:"sub"` + Roles []string `json:"roles"` + Scope []string `json:"scope"` + Nonce string `json:"nonce"` + IssuedAt int64 `json:"iat"` + ExpiresAt int64 `json:"exp"` +} + +// CreateSessionToken creates a new session token +func CreateSessionToken(user models.User, nonce string, roles, scope []string) (*SessionData, string, error) { + fingerPrintMap := &SessionData{ + Nonce: nonce, + Roles: roles, + Subject: user.ID, + Scope: scope, + IssuedAt: time.Now().Unix(), + ExpiresAt: time.Now().AddDate(1, 0, 0).Unix(), + } + fingerPrintBytes, _ := json.Marshal(fingerPrintMap) + fingerPrintHash, err := crypto.EncryptAES(string(fingerPrintBytes)) + if err != nil { + return nil, "", err + } + + return fingerPrintMap, fingerPrintHash, nil } // CreateAuthToken creates a new auth token when userlogs in -func CreateAuthToken(user models.User, roles []string) (*Token, error) { - fingerprint := uuid.NewString() - fingerPrintHashBytes, err := crypto.EncryptAES([]byte(fingerprint)) +func CreateAuthToken(gc *gin.Context, user models.User, roles, scope []string) (*Token, error) { + hostname := utils.GetHost(gc) + nonce := uuid.New().String() + _, fingerPrintHash, err := CreateSessionToken(user, nonce, roles, scope) if err != nil { return nil, err } - refreshToken, refreshTokenExpiresAt, err := CreateRefreshToken(user, roles) + accessToken, accessTokenExpiresAt, err := CreateAccessToken(user, roles, scope, hostname, nonce) if err != nil { return nil, err } - accessToken, accessTokenExpiresAt, err := CreateAccessToken(user, roles) + idToken, idTokenExpiresAt, err := CreateIDToken(user, roles, hostname, nonce) if err != nil { return nil, err } - return &Token{ - FingerPrint: fingerprint, - FingerPrintHash: string(fingerPrintHashBytes), - RefreshToken: &JWTToken{Token: refreshToken, ExpiresAt: refreshTokenExpiresAt}, + res := &Token{ + FingerPrint: nonce, + FingerPrintHash: fingerPrintHash, AccessToken: &JWTToken{Token: accessToken, ExpiresAt: accessTokenExpiresAt}, - }, nil + IDToken: &JWTToken{Token: idToken, ExpiresAt: idTokenExpiresAt}, + } + + if utils.StringSliceContains(scope, "offline_access") { + refreshToken, refreshTokenExpiresAt, err := CreateRefreshToken(user, roles, hostname, nonce) + if err != nil { + return nil, err + } + + res.RefreshToken = &JWTToken{Token: refreshToken, ExpiresAt: refreshTokenExpiresAt} + } + + return res, nil } // CreateRefreshToken util to create JWT token -func CreateRefreshToken(user models.User, roles []string) (string, int64, error) { +func CreateRefreshToken(user models.User, roles []string, hostname, nonce string) (string, int64, error) { // expires in 1 year expiryBound := time.Hour * 8760 expiresAt := time.Now().Add(expiryBound).Unix() - customClaims := jwt.MapClaims{ - "iss": "", - "aud": "", + "iss": hostname, + "aud": envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyClientID), "sub": user.ID, "exp": expiresAt, "iat": time.Now().Unix(), "token_type": constants.TokenTypeRefreshToken, "roles": roles, - "id": user.ID, + "nonce": nonce, } token, err := SignJWTToken(customClaims) if err != nil { return "", 0, err } + return token, expiresAt, nil } // CreateAccessToken util to create JWT token, based on // user information, roles config and CUSTOM_ACCESS_TOKEN_SCRIPT -func CreateAccessToken(user models.User, roles []string) (string, int64, error) { +func CreateAccessToken(user models.User, roles, scopes []string, hostName, nonce string) (string, int64, error) { + expiryBound := time.Minute * 30 + expiresAt := time.Now().Add(expiryBound).Unix() + + customClaims := jwt.MapClaims{ + "iss": hostName, + "aud": envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyClientID), + "nonce": nonce, + "sub": user.ID, + "exp": expiresAt, + "iat": time.Now().Unix(), + "token_type": constants.TokenTypeAccessToken, + "scope": scopes, + "roles": roles, + } + + token, err := SignJWTToken(customClaims) + if err != nil { + return "", 0, err + } + + return token, expiresAt, nil +} + +// GetAccessToken returns the access token from the request (either from header or cookie) +func GetAccessToken(gc *gin.Context) (string, error) { + // try to check in auth header for cookie + auth := gc.Request.Header.Get("Authorization") + if auth == "" { + return "", fmt.Errorf(`unauthorized`) + } + + if !strings.HasPrefix(auth, "Bearer ") { + return "", fmt.Errorf(`not a bearer token`) + } + + token := strings.TrimPrefix(auth, "Bearer ") + return token, nil +} + +// Function to validate access token for authorizer apis (profile, update_profile) +func ValidateAccessToken(gc *gin.Context, accessToken string) (map[string]interface{}, error) { + var res map[string]interface{} + + if accessToken == "" { + return res, fmt.Errorf(`unauthorized`) + } + + savedSession := sessionstore.GetState(accessToken) + if savedSession == "" { + return res, fmt.Errorf(`unauthorized`) + } + + savedSessionSplit := strings.Split(savedSession, "@") + nonce := savedSessionSplit[0] + userID := savedSessionSplit[1] + + hostname := utils.GetHost(gc) + res, err := ParseJWTToken(accessToken, hostname, nonce, userID) + if err != nil { + return res, err + } + + if res["token_type"] != constants.TokenTypeAccessToken { + return res, fmt.Errorf(`unauthorized: invalid token type`) + } + + return res, nil +} + +func ValidateBrowserSession(gc *gin.Context, encryptedSession string) (*SessionData, error) { + if encryptedSession == "" { + return nil, fmt.Errorf(`unauthorized`) + } + + savedSession := sessionstore.GetState(encryptedSession) + if savedSession == "" { + return nil, fmt.Errorf(`unauthorized`) + } + + savedSessionSplit := strings.Split(savedSession, "@") + nonce := savedSessionSplit[0] + userID := savedSessionSplit[1] + + decryptedFingerPrint, err := crypto.DecryptAES(encryptedSession) + if err != nil { + return nil, err + } + + var res SessionData + err = json.Unmarshal([]byte(decryptedFingerPrint), &res) + if err != nil { + return nil, err + } + + if res.Nonce != nonce { + return nil, fmt.Errorf(`unauthorized: invalid nonce`) + } + + if res.Subject != userID { + return nil, fmt.Errorf(`unauthorized: invalid user id`) + } + + if res.ExpiresAt < time.Now().Unix() { + return nil, fmt.Errorf(`unauthorized: token expired`) + } + + // TODO validate scope + // if !reflect.DeepEqual(res.Roles, roles) { + // return res, "", fmt.Errorf(`unauthorized`) + // } + + return &res, nil +} + +// CreateIDToken util to create JWT token, based on +// user information, roles config and CUSTOM_ACCESS_TOKEN_SCRIPT +func CreateIDToken(user models.User, roles []string, hostname, nonce string) (string, int64, error) { expiryBound := time.Minute * 30 expiresAt := time.Now().Add(expiryBound).Unix() @@ -97,13 +256,13 @@ func CreateAccessToken(user models.User, roles []string) (string, int64, error) claimKey := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtRoleClaim) customClaims := jwt.MapClaims{ - "iss": "", + "iss": hostname, "aud": envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyClientID), - "nonce": "", + "nonce": nonce, "sub": user.ID, "exp": expiresAt, "iat": time.Now().Unix(), - "token_type": constants.TokenTypeAccessToken, + "token_type": constants.TokenTypeIdentityToken, "allowed_roles": strings.Split(user.Roles, ","), claimKey: roles, } @@ -152,58 +311,18 @@ func CreateAccessToken(user models.User, roles []string) (string, int64, error) return token, expiresAt, nil } -// GetAccessToken returns the access token from the request (either from header or cookie) -func GetAccessToken(gc *gin.Context) (string, error) { - token, err := cookie.GetAccessTokenCookie(gc) - if err != nil || token == "" { - // try to check in auth header for cookie - auth := gc.Request.Header.Get("Authorization") - if auth == "" { - return "", fmt.Errorf(`unauthorized`) - } - - token = strings.TrimPrefix(auth, "Bearer ") - - } - return token, nil -} - -// GetRefreshToken returns the refresh token from cookie / request query url -func GetRefreshToken(gc *gin.Context) (string, error) { - token, err := cookie.GetRefreshTokenCookie(gc) - - if err != nil || token == "" { +// GetIDToken returns the id token from the request header +func GetIDToken(gc *gin.Context) (string, error) { + // try to check in auth header for cookie + auth := gc.Request.Header.Get("Authorization") + if auth == "" { return "", fmt.Errorf(`unauthorized`) } + if !strings.HasPrefix(auth, "Bearer ") { + return "", fmt.Errorf(`not a bearer token`) + } + + token := strings.TrimPrefix(auth, "Bearer ") return token, nil } - -// GetFingerPrint returns the finger print from cookie -func GetFingerPrint(gc *gin.Context) (string, error) { - fingerPrint, err := cookie.GetFingerPrintCookie(gc) - if err != nil || fingerPrint == "" { - return "", fmt.Errorf(`no finger print`) - } - return fingerPrint, nil -} - -func ValidateAccessToken(gc *gin.Context) (map[string]interface{}, error) { - token, err := GetAccessToken(gc) - if err != nil { - return nil, err - } - - claims, err := ParseJWTToken(token) - if err != nil { - return nil, err - } - - // also validate if there is user session present with access token - sessions := sessionstore.GetUserSessions(claims["id"].(string)) - if len(sessions) == 0 { - return nil, errors.New("unauthorized") - } - - return claims, nil -} diff --git a/server/token/jwt.go b/server/token/jwt.go index 6631ffe..90f6333 100644 --- a/server/token/jwt.go +++ b/server/token/jwt.go @@ -44,7 +44,7 @@ func SignJWTToken(claims jwt.MapClaims) (string, error) { } // ParseJWTToken common util to parse jwt token -func ParseJWTToken(token string) (jwt.MapClaims, error) { +func ParseJWTToken(token, hostname, nonce, subject string) (jwt.MapClaims, error) { jwtType := envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyJwtType) signingMethod := jwt.GetSigningMethod(jwtType) @@ -87,5 +87,21 @@ func ParseJWTToken(token string) (jwt.MapClaims, error) { claims["exp"] = intExp claims["iat"] = intIat + if claims["aud"] != envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyClientID) { + return claims, errors.New("invalid audience") + } + + if claims["nonce"] != nonce { + return claims, errors.New("invalid nonce") + } + + if claims["iss"] != hostname { + return claims, errors.New("invalid issuer") + } + + if claims["sub"] != subject { + return claims, errors.New("invalid subject") + } + return claims, nil } diff --git a/server/token/verification_token.go b/server/token/verification_token.go index ca9a64b..7765392 100644 --- a/server/token/verification_token.go +++ b/server/token/verification_token.go @@ -9,13 +9,15 @@ import ( ) // CreateVerificationToken creates a verification JWT token -func CreateVerificationToken(email, tokenType, hostname string) (string, error) { +func CreateVerificationToken(email, tokenType, hostname, nonceHash string) (string, error) { claims := jwt.MapClaims{ + "iss": hostname, + "aud": envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyClientID), + "sub": email, "exp": time.Now().Add(time.Minute * 30).Unix(), "iat": time.Now().Unix(), "token_type": tokenType, - "email": email, - "host": hostname, + "nonce": nonceHash, "redirect_url": envstore.EnvStoreObj.GetStringStoreEnvVariable(constants.EnvKeyAppURL), } diff --git a/server/utils/common.go b/server/utils/common.go index 2438d3b..d4a8d51 100644 --- a/server/utils/common.go +++ b/server/utils/common.go @@ -20,7 +20,7 @@ func StringSliceContains(s []string, e string) bool { // SaveSessionInDB saves sessions generated for a given user with meta information // Do not store token here as that could be security breach -func SaveSessionInDB(userId string, c *gin.Context) { +func SaveSessionInDB(c *gin.Context, userId string) { sessionData := models.Session{ UserID: userId, UserAgent: GetUserAgent(c.Request), diff --git a/server/utils/nonce.go b/server/utils/nonce.go new file mode 100644 index 0000000..311b325 --- /dev/null +++ b/server/utils/nonce.go @@ -0,0 +1,36 @@ +package utils + +import ( + "github.com/google/uuid" + + "github.com/authorizerdev/authorizer/server/crypto" +) + +// GenerateNonce generats random nonce string and returns +// the nonce string, nonce hash, error +func GenerateNonce() (string, string, error) { + nonce := uuid.New().String() + nonceHash, err := crypto.EncryptAES(nonce) + if err != nil { + return "", "", err + } + return nonce, nonceHash, err +} + +// EncryptNonce nonce string +func EncryptNonce(nonce string) (string, error) { + nonceHash, err := crypto.EncryptAES(nonce) + if err != nil { + return "", err + } + return nonceHash, err +} + +// DecryptNonce nonce string +func DecryptNonce(nonceHash string) (string, error) { + nonce, err := crypto.DecryptAES(nonceHash) + if err != nil { + return "", err + } + return nonce, err +}