diff --git a/server/handlers/token.go b/server/handlers/token.go index 4a18a4d..fdfbf9f 100644 --- a/server/handlers/token.go +++ b/server/handlers/token.go @@ -15,6 +15,7 @@ import ( "github.com/gin-gonic/gin" ) +// grant type required func TokenHandler() gin.HandlerFunc { return func(gc *gin.Context) { var reqBody map[string]string @@ -29,6 +30,22 @@ func TokenHandler() gin.HandlerFunc { codeVerifier := strings.TrimSpace(reqBody["code_verifier"]) code := strings.TrimSpace(reqBody["code"]) clientID := strings.TrimSpace(reqBody["client_id"]) + grantType := strings.TrimSpace(reqBody["grant_type"]) + refreshToken := strings.TrimSpace(reqBody["refresh_token"]) + + if grantType == "" { + grantType = "authorization_code" + } + + isRefreshTokenGrant := grantType == "refresh_token" + isAuthorizationCodeGrant := grantType == "authorization_code" + + if !isRefreshTokenGrant && !isAuthorizationCodeGrant { + gc.JSON(http.StatusBadRequest, gin.H{ + "error": "invalid_grant_type", + "error_description": "grant_type is invalid", + }) + } if clientID == "" { gc.JSON(http.StatusBadRequest, gin.H{ @@ -46,58 +63,87 @@ func TokenHandler() gin.HandlerFunc { return } - if codeVerifier == "" { - gc.JSON(http.StatusBadRequest, gin.H{ - "error": "invalid_code_verifier", - "error_description": "The code verifier is required", - }) - return + var userID string + var roles, scope []string + if isAuthorizationCodeGrant { + + if codeVerifier == "" { + gc.JSON(http.StatusBadRequest, gin.H{ + "error": "invalid_code_verifier", + "error_description": "The code verifier is required", + }) + return + } + + if code == "" { + gc.JSON(http.StatusBadRequest, gin.H{ + "error": "invalid_code", + "error_description": "The code is required", + }) + return + } + + hash := sha256.New() + hash.Write([]byte(codeVerifier)) + encryptedCode := strings.ReplaceAll(base64.URLEncoding.EncodeToString(hash.Sum(nil)), "+", "-") + encryptedCode = strings.ReplaceAll(encryptedCode, "/", "_") + encryptedCode = strings.ReplaceAll(encryptedCode, "=", "") + sessionData := sessionstore.GetState(encryptedCode) + if sessionData == "" { + gc.JSON(http.StatusBadRequest, gin.H{ + "error": "invalid_code_verifier", + "error_description": "The code verifier is invalid", + }) + return + } + + // split session data + // it contains code@sessiontoken + sessionDataSplit := strings.Split(sessionData, "@") + + if sessionDataSplit[0] != code { + gc.JSON(http.StatusBadRequest, gin.H{ + "error": "invalid_code_verifier", + "error_description": "The code verifier is invalid", + }) + return + } + + // rollover the session for security + sessionstore.RemoveState(sessionDataSplit[1]) + // validate session + claims, err := token.ValidateBrowserSession(gc, sessionDataSplit[1]) + if err != nil { + gc.JSON(http.StatusUnauthorized, gin.H{ + "error": "unauthorized", + "error_description": "Invalid session data", + }) + return + } + userID = claims.Subject + roles = claims.Roles + scope = claims.Scope + } else { + // validate refresh token + if refreshToken == "" { + gc.JSON(http.StatusBadRequest, gin.H{ + "error": "invalid_refresh_token", + "error_description": "The refresh token is invalid", + }) + } + + claims, err := token.ValidateRefreshToken(gc, refreshToken) + if err != nil { + gc.JSON(http.StatusUnauthorized, gin.H{ + "error": "unauthorized", + "error_description": err.Error(), + }) + } + userID = claims["sub"].(string) + roles = claims["roles"].([]string) + scope = claims["scope"].([]string) } - if code == "" { - gc.JSON(http.StatusBadRequest, gin.H{ - "error": "invalid_code", - "error_description": "The code is required", - }) - return - } - - hash := sha256.New() - hash.Write([]byte(codeVerifier)) - encryptedCode := strings.ReplaceAll(base64.URLEncoding.EncodeToString(hash.Sum(nil)), "+", "-") - encryptedCode = strings.ReplaceAll(encryptedCode, "/", "_") - encryptedCode = strings.ReplaceAll(encryptedCode, "=", "") - sessionData := sessionstore.GetState(encryptedCode) - if sessionData == "" { - gc.JSON(http.StatusBadRequest, gin.H{ - "error": "invalid_code_verifier", - "error_description": "The code verifier is invalid", - }) - return - } - - // split session data - // it contains code@sessiontoken - sessionDataSplit := strings.Split(sessionData, "@") - - if sessionDataSplit[0] != code { - gc.JSON(http.StatusBadRequest, gin.H{ - "error": "invalid_code_verifier", - "error_description": "The code verifier is invalid", - }) - return - } - - // validate session - claims, err := token.ValidateBrowserSession(gc, sessionDataSplit[1]) - if err != nil { - gc.JSON(http.StatusUnauthorized, gin.H{ - "error": "unauthorized", - "error_description": "Invalid session data", - }) - return - } - userID := claims.Subject user, err := db.Provider.GetUserByID(userID) if err != nil { gc.JSON(http.StatusUnauthorized, gin.H{ @@ -106,9 +152,8 @@ func TokenHandler() gin.HandlerFunc { }) return } - // rollover the session for security - sessionstore.RemoveState(sessionDataSplit[1]) - authToken, err := token.CreateAuthToken(gc, user, claims.Roles, claims.Scope) + + authToken, err := token.CreateAuthToken(gc, user, roles, scope) if err != nil { gc.JSON(http.StatusUnauthorized, gin.H{ "error": "unauthorized", @@ -124,7 +169,8 @@ func TokenHandler() gin.HandlerFunc { res := map[string]interface{}{ "access_token": authToken.AccessToken.Token, "id_token": authToken.IDToken.Token, - "scope": claims.Scope, + "scope": scope, + "roles": roles, "expires_in": expiresIn, } diff --git a/server/token/auth_token.go b/server/token/auth_token.go index 16abbd6..350da17 100644 --- a/server/token/auth_token.go +++ b/server/token/auth_token.go @@ -91,7 +91,7 @@ func CreateAuthToken(gc *gin.Context, user models.User, roles, scope []string) ( } if utils.StringSliceContains(scope, "offline_access") { - refreshToken, refreshTokenExpiresAt, err := CreateRefreshToken(user, roles, hostname, nonce) + refreshToken, refreshTokenExpiresAt, err := CreateRefreshToken(user, roles, scope, hostname, nonce) if err != nil { return nil, err } @@ -103,7 +103,7 @@ func CreateAuthToken(gc *gin.Context, user models.User, roles, scope []string) ( } // CreateRefreshToken util to create JWT token -func CreateRefreshToken(user models.User, roles []string, hostname, nonce string) (string, int64, error) { +func CreateRefreshToken(user models.User, roles, scopes []string, hostname, nonce string) (string, int64, error) { // expires in 1 year expiryBound := time.Hour * 8760 expiresAt := time.Now().Add(expiryBound).Unix() @@ -115,6 +115,7 @@ func CreateRefreshToken(user models.User, roles []string, hostname, nonce string "iat": time.Now().Unix(), "token_type": constants.TokenTypeRefreshToken, "roles": roles, + "scope": scopes, "nonce": nonce, } @@ -198,6 +199,36 @@ func ValidateAccessToken(gc *gin.Context, accessToken string) (map[string]interf return res, nil } +// Function to validate refreshToken +func ValidateRefreshToken(gc *gin.Context, refreshToken string) (map[string]interface{}, error) { + var res map[string]interface{} + + if refreshToken == "" { + return res, fmt.Errorf(`unauthorized`) + } + + savedSession := sessionstore.GetState(refreshToken) + if savedSession == "" { + return res, fmt.Errorf(`unauthorized`) + } + + savedSessionSplit := strings.Split(savedSession, "@") + nonce := savedSessionSplit[0] + userID := savedSessionSplit[1] + + hostname := utils.GetHost(gc) + res, err := ParseJWTToken(refreshToken, hostname, nonce, userID) + if err != nil { + return res, err + } + + if res["token_type"] != constants.TokenTypeRefreshToken { + return res, fmt.Errorf(`unauthorized: invalid token type`) + } + + return res, nil +} + func ValidateBrowserSession(gc *gin.Context, encryptedSession string) (*SessionData, error) { if encryptedSession == "" { return nil, fmt.Errorf(`unauthorized`)