|
|
|
@@ -28,21 +28,21 @@ import (
|
|
|
|
|
|
|
|
|
|
// OAuthCallbackHandler handles the OAuth callback for various oauth providers
|
|
|
|
|
func OAuthCallbackHandler() gin.HandlerFunc {
|
|
|
|
|
return func(c *gin.Context) {
|
|
|
|
|
provider := c.Param("oauth_provider")
|
|
|
|
|
state := c.Request.FormValue("state")
|
|
|
|
|
return func(ctx *gin.Context) {
|
|
|
|
|
provider := ctx.Param("oauth_provider")
|
|
|
|
|
state := ctx.Request.FormValue("state")
|
|
|
|
|
|
|
|
|
|
sessionState, err := memorystore.Provider.GetState(state)
|
|
|
|
|
if sessionState == "" || err != nil {
|
|
|
|
|
log.Debug("Invalid oauth state: ", state)
|
|
|
|
|
c.JSON(400, gin.H{"error": "invalid oauth state"})
|
|
|
|
|
ctx.JSON(400, gin.H{"error": "invalid oauth state"})
|
|
|
|
|
}
|
|
|
|
|
// contains random token, redirect url, role
|
|
|
|
|
sessionSplit := strings.Split(state, "___")
|
|
|
|
|
|
|
|
|
|
if len(sessionSplit) < 3 {
|
|
|
|
|
log.Debug("Unable to get redirect url from state: ", state)
|
|
|
|
|
c.JSON(400, gin.H{"error": "invalid redirect url"})
|
|
|
|
|
ctx.JSON(400, gin.H{"error": "invalid redirect url"})
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@@ -55,7 +55,7 @@ func OAuthCallbackHandler() gin.HandlerFunc {
|
|
|
|
|
scopes := strings.Split(sessionSplit[3], ",")
|
|
|
|
|
|
|
|
|
|
user := models.User{}
|
|
|
|
|
code := c.Request.FormValue("code")
|
|
|
|
|
code := ctx.Request.FormValue("code")
|
|
|
|
|
switch provider {
|
|
|
|
|
case constants.AuthRecipeMethodGoogle:
|
|
|
|
|
user, err = processGoogleUserInfo(code)
|
|
|
|
@@ -74,23 +74,23 @@ func OAuthCallbackHandler() gin.HandlerFunc {
|
|
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Debug("Failed to process user info: ", err)
|
|
|
|
|
c.JSON(400, gin.H{"error": err.Error()})
|
|
|
|
|
ctx.JSON(400, gin.H{"error": err.Error()})
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
existingUser, err := db.Provider.GetUserByEmail(user.Email)
|
|
|
|
|
existingUser, err := db.Provider.GetUserByEmail(ctx, user.Email)
|
|
|
|
|
log := log.WithField("user", user.Email)
|
|
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
isSignupDisabled, err := memorystore.Provider.GetBoolStoreEnvVariable(constants.EnvKeyDisableSignUp)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Debug("Failed to get signup disabled env variable: ", err)
|
|
|
|
|
c.JSON(400, gin.H{"error": err.Error()})
|
|
|
|
|
ctx.JSON(400, gin.H{"error": err.Error()})
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
if isSignupDisabled {
|
|
|
|
|
log.Debug("Failed to signup as disabled")
|
|
|
|
|
c.JSON(400, gin.H{"error": "signup is disabled for this instance"})
|
|
|
|
|
ctx.JSON(400, gin.H{"error": "signup is disabled for this instance"})
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
// user not registered, register user and generate session token
|
|
|
|
@@ -113,19 +113,19 @@ func OAuthCallbackHandler() gin.HandlerFunc {
|
|
|
|
|
|
|
|
|
|
if hasProtectedRole {
|
|
|
|
|
log.Debug("Signup is not allowed with protected roles:", inputRoles)
|
|
|
|
|
c.JSON(400, gin.H{"error": "invalid role"})
|
|
|
|
|
ctx.JSON(400, gin.H{"error": "invalid role"})
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
user.Roles = strings.Join(inputRoles, ",")
|
|
|
|
|
now := time.Now().Unix()
|
|
|
|
|
user.EmailVerifiedAt = &now
|
|
|
|
|
user, _ = db.Provider.AddUser(user)
|
|
|
|
|
user, _ = db.Provider.AddUser(ctx, user)
|
|
|
|
|
} else {
|
|
|
|
|
user = existingUser
|
|
|
|
|
if user.RevokedTimestamp != nil {
|
|
|
|
|
log.Debug("User access revoked at: ", user.RevokedTimestamp)
|
|
|
|
|
c.JSON(400, gin.H{"error": "user access has been revoked"})
|
|
|
|
|
ctx.JSON(400, gin.H{"error": "user access has been revoked"})
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@@ -175,7 +175,7 @@ func OAuthCallbackHandler() gin.HandlerFunc {
|
|
|
|
|
|
|
|
|
|
if hasProtectedRole {
|
|
|
|
|
log.Debug("Invalid role. User is using protected unassigned role")
|
|
|
|
|
c.JSON(400, gin.H{"error": "invalid role"})
|
|
|
|
|
ctx.JSON(400, gin.H{"error": "invalid role"})
|
|
|
|
|
return
|
|
|
|
|
} else {
|
|
|
|
|
user.Roles = existingUser.Roles + "," + strings.Join(unasignedRoles, ",")
|
|
|
|
@@ -184,18 +184,18 @@ func OAuthCallbackHandler() gin.HandlerFunc {
|
|
|
|
|
user.Roles = existingUser.Roles
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
user, err = db.Provider.UpdateUser(user)
|
|
|
|
|
user, err = db.Provider.UpdateUser(ctx, user)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Debug("Failed to update user: ", err)
|
|
|
|
|
c.JSON(500, gin.H{"error": err.Error()})
|
|
|
|
|
ctx.JSON(500, gin.H{"error": err.Error()})
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
authToken, err := token.CreateAuthToken(c, user, inputRoles, scopes, provider)
|
|
|
|
|
authToken, err := token.CreateAuthToken(ctx, user, inputRoles, scopes, provider)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Debug("Failed to create auth token: ", err)
|
|
|
|
|
c.JSON(500, gin.H{"error": err.Error()})
|
|
|
|
|
ctx.JSON(500, gin.H{"error": err.Error()})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
expiresIn := authToken.AccessToken.ExpiresAt - time.Now().Unix()
|
|
|
|
@@ -206,7 +206,7 @@ func OAuthCallbackHandler() gin.HandlerFunc {
|
|
|
|
|
params := "access_token=" + authToken.AccessToken.Token + "&token_type=bearer&expires_in=" + strconv.FormatInt(expiresIn, 10) + "&state=" + stateValue + "&id_token=" + authToken.IDToken.Token
|
|
|
|
|
|
|
|
|
|
sessionKey := provider + ":" + user.ID
|
|
|
|
|
cookie.SetSession(c, authToken.FingerPrintHash)
|
|
|
|
|
cookie.SetSession(ctx, authToken.FingerPrintHash)
|
|
|
|
|
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash)
|
|
|
|
|
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token)
|
|
|
|
|
|
|
|
|
@@ -215,10 +215,10 @@ func OAuthCallbackHandler() gin.HandlerFunc {
|
|
|
|
|
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
go db.Provider.AddSession(models.Session{
|
|
|
|
|
go db.Provider.AddSession(ctx, models.Session{
|
|
|
|
|
UserID: user.ID,
|
|
|
|
|
UserAgent: utils.GetUserAgent(c.Request),
|
|
|
|
|
IP: utils.GetIP(c.Request),
|
|
|
|
|
UserAgent: utils.GetUserAgent(ctx.Request),
|
|
|
|
|
IP: utils.GetIP(ctx.Request),
|
|
|
|
|
})
|
|
|
|
|
if strings.Contains(redirectURL, "?") {
|
|
|
|
|
redirectURL = redirectURL + "&" + params
|
|
|
|
@@ -226,7 +226,7 @@ func OAuthCallbackHandler() gin.HandlerFunc {
|
|
|
|
|
redirectURL = redirectURL + "?" + strings.TrimPrefix(params, "&")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
c.Redirect(http.StatusFound, redirectURL)
|
|
|
|
|
ctx.Redirect(http.StatusFound, redirectURL)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|