diff --git a/server/resolvers/deactivate_account.go b/server/resolvers/deactivate_account.go index 0773e5d..539575c 100644 --- a/server/resolvers/deactivate_account.go +++ b/server/resolvers/deactivate_account.go @@ -21,15 +21,15 @@ func DeactivateAccountResolver(ctx context.Context) (*model.Response, error) { log.Debug("Failed to get GinContext: ", err) return res, err } - userID, err := token.GetUserIDFromSessionOrAccessToken(gc) + tokenData, err := token.GetUserIDFromSessionOrAccessToken(gc) if err != nil { log.Debug("Failed GetUserIDFromSessionOrAccessToken: ", err) return res, err } log := log.WithFields(log.Fields{ - "user_id": userID, + "user_id": tokenData.UserID, }) - user, err := db.Provider.GetUserByID(ctx, userID) + user, err := db.Provider.GetUserByID(ctx, tokenData.UserID) if err != nil { log.Debug("Failed to get user by id: ", err) return res, err diff --git a/server/resolvers/logout.go b/server/resolvers/logout.go index 54ce873..0988e16 100644 --- a/server/resolvers/logout.go +++ b/server/resolvers/logout.go @@ -2,12 +2,10 @@ package resolvers import ( "context" - "encoding/json" log "github.com/sirupsen/logrus" "github.com/authorizerdev/authorizer/server/cookie" - "github.com/authorizerdev/authorizer/server/crypto" "github.com/authorizerdev/authorizer/server/graph/model" "github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/token" @@ -22,31 +20,18 @@ func LogoutResolver(ctx context.Context) (*model.Response, error) { return nil, err } - // get fingerprint hash - fingerprintHash, err := cookie.GetSession(gc) + tokenData, err := token.GetUserIDFromSessionOrAccessToken(gc) if err != nil { - log.Debug("Failed to get fingerprint hash: ", err) + log.Debug("Failed GetUserIDFromSessionOrAccessToken: ", err) return nil, err } - decryptedFingerPrint, err := crypto.DecryptAES(fingerprintHash) - if err != nil { - log.Debug("Failed to decrypt fingerprint hash: ", err) - return nil, err + sessionKey := tokenData.UserID + if tokenData.LoginMethod != "" { + sessionKey = tokenData.LoginMethod + ":" + tokenData.UserID } - var sessionData token.SessionData - err = json.Unmarshal([]byte(decryptedFingerPrint), &sessionData) - if err != nil { - return nil, err - } - - sessionKey := sessionData.Subject - if sessionData.LoginMethod != "" { - sessionKey = sessionData.LoginMethod + ":" + sessionData.Subject - } - - memorystore.Provider.DeleteUserSession(sessionKey, sessionData.Nonce) + memorystore.Provider.DeleteUserSession(sessionKey, tokenData.Nonce) cookie.DeleteSession(gc) res := &model.Response{ diff --git a/server/resolvers/profile.go b/server/resolvers/profile.go index 521ce44..df7092c 100644 --- a/server/resolvers/profile.go +++ b/server/resolvers/profile.go @@ -20,15 +20,15 @@ func ProfileResolver(ctx context.Context) (*model.User, error) { log.Debug("Failed to get GinContext: ", err) return res, err } - userID, err := token.GetUserIDFromSessionOrAccessToken(gc) + tokenData, err := token.GetUserIDFromSessionOrAccessToken(gc) if err != nil { log.Debug("Failed GetUserIDFromSessionOrAccessToken: ", err) return res, err } log := log.WithFields(log.Fields{ - "user_id": userID, + "user_id": tokenData.UserID, }) - user, err := db.Provider.GetUserByID(ctx, userID) + user, err := db.Provider.GetUserByID(ctx, tokenData.UserID) if err != nil { log.Debug("Failed to get user: ", err) return res, err diff --git a/server/resolvers/update_profile.go b/server/resolvers/update_profile.go index 60ede61..3b9af37 100644 --- a/server/resolvers/update_profile.go +++ b/server/resolvers/update_profile.go @@ -36,7 +36,7 @@ func UpdateProfileResolver(ctx context.Context, params model.UpdateProfileInput) log.Debug("Failed to get GinContext: ", err) return res, err } - userID, err := token.GetUserIDFromSessionOrAccessToken(gc) + tokenData, err := token.GetUserIDFromSessionOrAccessToken(gc) if err != nil { log.Debug("Failed GetUserIDFromSessionOrAccessToken: ", err) return res, err @@ -48,9 +48,9 @@ func UpdateProfileResolver(ctx context.Context, params model.UpdateProfileInput) return res, fmt.Errorf("please enter at least one param to update") } log := log.WithFields(log.Fields{ - "user_id": userID, + "user_id": tokenData.UserID, }) - user, err := db.Provider.GetUserByID(ctx, userID) + user, err := db.Provider.GetUserByID(ctx, tokenData.UserID) if err != nil { log.Debug("Failed to get user by id: ", err) return res, err diff --git a/server/test/logout_test.go b/server/test/logout_test.go index 3d95cf5..0a79c0b 100644 --- a/server/test/logout_test.go +++ b/server/test/logout_test.go @@ -35,6 +35,30 @@ func logoutTests(t *testing.T, s TestSetup) { assert.NotNil(t, verifyRes) accessToken := *verifyRes.AccessToken assert.NotEmpty(t, accessToken) + // Test logout with access token + req.Header.Set("Authorization", "Bearer "+accessToken) + logoutRes, err := resolvers.LogoutResolver(ctx) + assert.Nil(t, err) + assert.NotNil(t, logoutRes) + assert.NotEmpty(t, logoutRes.Message) + req.Header.Set("Authorization", "") + + // Test logout with session cookie + magicLoginRes, err = resolvers.MagicLinkLoginResolver(ctx, model.MagicLinkLoginInput{ + Email: email, + }) + assert.NoError(t, err) + assert.NotNil(t, magicLoginRes) + verificationRequest, err = db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeMagicLinkLogin) + assert.NoError(t, err) + assert.NotNil(t, verificationRequest) + verifyRes, err = resolvers.VerifyEmailResolver(ctx, model.VerifyEmailInput{ + Token: verificationRequest.Token, + }) + assert.NoError(t, err) + assert.NotNil(t, verifyRes) + accessToken = *verifyRes.AccessToken + assert.NotEmpty(t, accessToken) claims, err := token.ParseJWTToken(accessToken) assert.NoError(t, err) assert.NotEmpty(t, claims) diff --git a/server/token/auth_token.go b/server/token/auth_token.go index 707865f..6570fd5 100644 --- a/server/token/auth_token.go +++ b/server/token/auth_token.go @@ -482,8 +482,15 @@ func GetIDToken(gc *gin.Context) (string, error) { return token, nil } +// SessionOrAccessTokenData is a struct to hold session or access token data +type SessionOrAccessTokenData struct { + UserID string + LoginMethod string + Nonce string +} + // GetUserIDFromSessionOrAccessToken returns the user id from the session or access token -func GetUserIDFromSessionOrAccessToken(gc *gin.Context) (string, error) { +func GetUserIDFromSessionOrAccessToken(gc *gin.Context) (*SessionOrAccessTokenData, error) { // First try to get the user id from the session isSession := true token, err := cookie.GetSession(gc) @@ -493,22 +500,30 @@ func GetUserIDFromSessionOrAccessToken(gc *gin.Context) (string, error) { token, err = GetAccessToken(gc) if err != nil || token == "" { log.Debug("Failed to get access token: ", err) - return "", fmt.Errorf(`unauthorized`) + return nil, fmt.Errorf(`unauthorized`) } } if isSession { claims, err := ValidateBrowserSession(gc, token) if err != nil { log.Debug("Failed to validate session token: ", err) - return "", fmt.Errorf(`unauthorized`) + return nil, fmt.Errorf(`unauthorized`) } - return claims.Subject, nil + return &SessionOrAccessTokenData{ + UserID: claims.Subject, + LoginMethod: claims.LoginMethod, + Nonce: claims.Nonce, + }, nil } // If not session, then validate the access token claims, err := ValidateAccessToken(gc, token) if err != nil { log.Debug("Failed to validate access token: ", err) - return "", fmt.Errorf(`unauthorized`) + return nil, fmt.Errorf(`unauthorized`) } - return claims["sub"].(string), nil + return &SessionOrAccessTokenData{ + UserID: claims["sub"].(string), + LoginMethod: claims["login_method"].(string), + Nonce: claims["nonce"].(string), + }, nil }