diff --git a/server/env.go b/server/env.go index 31dd673..d41fc8e 100644 --- a/server/env.go +++ b/server/env.go @@ -163,7 +163,7 @@ func InitEnv() { roles = append(roles, trimVal) } - if utils.StringContains(defaultRoleSplit, trimVal) { + if utils.StringSliceContains(defaultRoleSplit, trimVal) { defaultRoles = append(defaultRoles, trimVal) } } diff --git a/server/handlers/oauthCallback.go b/server/handlers/oauthCallback.go index 3abc0af..b920d8a 100644 --- a/server/handlers/oauthCallback.go +++ b/server/handlers/oauthCallback.go @@ -19,7 +19,7 @@ import ( "golang.org/x/oauth2" ) -func processGoogleUserInfo(code string, roles []string, c *gin.Context) (db.User, error) { +func processGoogleUserInfo(code string) (db.User, error) { user := db.User{} token, err := oauth.OAuthProvider.GoogleConfig.Exchange(oauth2.NoContext, code) if err != nil { @@ -40,7 +40,6 @@ func processGoogleUserInfo(code string, roles []string, c *gin.Context) (db.User userRawData := make(map[string]string) json.Unmarshal(body, &userRawData) - existingUser, err := db.Mgr.GetUserByEmail(userRawData["email"]) user = db.User{ FirstName: userRawData["given_name"], LastName: userRawData["family_name"], @@ -48,30 +47,11 @@ func processGoogleUserInfo(code string, roles []string, c *gin.Context) (db.User Email: userRawData["email"], EmailVerifiedAt: time.Now().Unix(), } - if err != nil { - // user not registered, register user and generate session token - user.SignupMethod = enum.Google.String() - user.Roles = strings.Join(roles, ",") - } else { - // user exists in db, check if method was google - // if not append google to existing signup method and save it - signupMethod := existingUser.SignupMethod - if !strings.Contains(signupMethod, enum.Google.String()) { - signupMethod = signupMethod + "," + enum.Google.String() - } - user.SignupMethod = signupMethod - user.Password = existingUser.Password - if !utils.IsValidRoles(strings.Split(existingUser.Roles, ","), roles) { - return user, fmt.Errorf("invalid role") - } - - user.Roles = existingUser.Roles - } return user, nil } -func processGithubUserInfo(code string, roles []string, c *gin.Context) (db.User, error) { +func processGithubUserInfo(code string) (db.User, error) { user := db.User{} token, err := oauth.OAuthProvider.GithubConfig.Exchange(oauth2.NoContext, code) if err != nil { @@ -100,7 +80,6 @@ func processGithubUserInfo(code string, roles []string, c *gin.Context) (db.User userRawData := make(map[string]string) json.Unmarshal(body, &userRawData) - existingUser, err := db.Mgr.GetUserByEmail(userRawData["email"]) name := strings.Split(userRawData["name"], " ") firstName := "" lastName := "" @@ -117,32 +96,11 @@ func processGithubUserInfo(code string, roles []string, c *gin.Context) (db.User Email: userRawData["email"], EmailVerifiedAt: time.Now().Unix(), } - if err != nil { - // user not registered, register user and generate session token - user.SignupMethod = enum.Github.String() - user.Roles = strings.Join(roles, ",") - } else { - // user exists in db, check if method was google - // if not append google to existing signup method and save it - - signupMethod := existingUser.SignupMethod - if !strings.Contains(signupMethod, enum.Github.String()) { - signupMethod = signupMethod + "," + enum.Github.String() - } - user.SignupMethod = signupMethod - user.Password = existingUser.Password - - if !utils.IsValidRoles(strings.Split(existingUser.Roles, ","), roles) { - return user, fmt.Errorf("invalid role") - } - - user.Roles = existingUser.Roles - } return user, nil } -func processFacebookUserInfo(code string, roles []string, c *gin.Context) (db.User, error) { +func processFacebookUserInfo(code string) (db.User, error) { user := db.User{} token, err := oauth.OAuthProvider.FacebookConfig.Exchange(oauth2.NoContext, code) if err != nil { @@ -170,7 +128,6 @@ func processFacebookUserInfo(code string, roles []string, c *gin.Context) (db.Us json.Unmarshal(body, &userRawData) email := fmt.Sprintf("%v", userRawData["email"]) - existingUser, err := db.Mgr.GetUserByEmail(email) picObject := userRawData["picture"].(map[string]interface{})["data"] picDataObject := picObject.(map[string]interface{}) @@ -182,28 +139,6 @@ func processFacebookUserInfo(code string, roles []string, c *gin.Context) (db.Us EmailVerifiedAt: time.Now().Unix(), } - if err != nil { - // user not registered, register user and generate session token - user.SignupMethod = enum.Github.String() - user.Roles = strings.Join(roles, ",") - } else { - // user exists in db, check if method was google - // if not append google to existing signup method and save it - - signupMethod := existingUser.SignupMethod - if !strings.Contains(signupMethod, enum.Github.String()) { - signupMethod = signupMethod + "," + enum.Github.String() - } - user.SignupMethod = signupMethod - user.Password = existingUser.Password - - if !utils.IsValidRoles(strings.Split(existingUser.Roles, ","), roles) { - return user, fmt.Errorf("invalid role") - } - - user.Roles = existingUser.Roles - } - return user, nil } @@ -226,7 +161,7 @@ func OAuthCallbackHandler() gin.HandlerFunc { return } - roles := strings.Split(sessionSplit[2], ",") + inputRoles := strings.Split(sessionSplit[2], ",") redirectURL := sessionSplit[1] var err error @@ -234,11 +169,11 @@ func OAuthCallbackHandler() gin.HandlerFunc { code := c.Request.FormValue("code") switch provider { case enum.Google.String(): - user, err = processGoogleUserInfo(code, roles, c) + user, err = processGoogleUserInfo(code) case enum.Github.String(): - user, err = processGithubUserInfo(code, roles, c) + user, err = processGithubUserInfo(code) case enum.Facebook.String(): - user, err = processFacebookUserInfo(code, roles, c) + user, err = processFacebookUserInfo(code) default: err = fmt.Errorf(`invalid oauth provider`) } @@ -248,12 +183,76 @@ func OAuthCallbackHandler() gin.HandlerFunc { return } + existingUser, err := db.Mgr.GetUserByEmail(user.Email) + + if err != nil { + // user not registered, register user and generate session token + user.SignupMethod = provider + // make sure inputRoles don't include protected roles + hasProtectedRole := false + for _, ir := range inputRoles { + if utils.StringSliceContains(constants.PROTECTED_ROLES, ir) { + hasProtectedRole = true + } + } + + if hasProtectedRole { + c.JSON(400, gin.H{"error": "invalid role"}) + return + } + + user.Roles = strings.Join(inputRoles, ",") + } else { + // user exists in db, check if method was google + // if not append google to existing signup method and save it + + signupMethod := existingUser.SignupMethod + if !strings.Contains(signupMethod, provider) { + signupMethod = signupMethod + "," + enum.Github.String() + } + user.SignupMethod = signupMethod + user.Password = existingUser.Password + + // There multiple scenarios with roles here in social login + // 1. user has access to protected roles + roles and trying to login + // 2. user has not signed up for one of the available role but trying to signup. + // Need to modify roles in this case + + // find the unassigned roles + existingRoles := strings.Split(existingUser.Roles, ",") + unasignedRoles := []string{} + for _, ir := range inputRoles { + if !utils.StringSliceContains(existingRoles, ir) { + unasignedRoles = append(unasignedRoles, ir) + } + } + + if len(unasignedRoles) > 0 { + // check if it contains protected unassigned role + hasProtectedRole := false + for _, ur := range unasignedRoles { + if utils.StringSliceContains(constants.PROTECTED_ROLES, ur) { + hasProtectedRole = true + } + } + + if hasProtectedRole { + c.JSON(400, gin.H{"error": "invalid role"}) + return + } else { + user.Roles = existingUser.Roles + "," + strings.Join(unasignedRoles, ",") + } + } else { + user.Roles = existingUser.Roles + } + } + user, _ = db.Mgr.SaveUser(user) user, _ = db.Mgr.GetUserByEmail(user.Email) userIdStr := fmt.Sprintf("%v", user.ID) - refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, roles) + refreshToken, _, _ := utils.CreateAuthToken(user, enum.RefreshToken, inputRoles) - accessToken, _, _ := utils.CreateAuthToken(user, enum.AccessToken, roles) + accessToken, _, _ := utils.CreateAuthToken(user, enum.AccessToken, inputRoles) utils.SetCookie(c, accessToken) session.SetToken(userIdStr, refreshToken) diff --git a/server/resolvers/token.go b/server/resolvers/token.go index 9345300..b00bf12 100644 --- a/server/resolvers/token.go +++ b/server/resolvers/token.go @@ -54,7 +54,7 @@ func Token(ctx context.Context, roles []string) (*model.AuthResponse, error) { if len(roles) > 0 { for _, v := range roles { - if !utils.StringContains(claimRoles, v) { + if !utils.StringSliceContains(claimRoles, v) { return res, fmt.Errorf(`unauthorized`) } } diff --git a/server/utils/common.go b/server/utils/common.go index 82f2570..0a58959 100644 --- a/server/utils/common.go +++ b/server/utils/common.go @@ -19,7 +19,7 @@ func WriteToFile(filename string, data string) error { return file.Sync() } -func StringContains(s []string, e string) bool { +func StringSliceContains(s []string, e string) bool { for _, a := range s { if a == e { return true diff --git a/server/utils/initServer.go b/server/utils/initServer.go index a15c08d..4634c4a 100644 --- a/server/utils/initServer.go +++ b/server/utils/initServer.go @@ -18,6 +18,11 @@ func InitServer() { Role: val, }) } + for _, val := range constants.PROTECTED_ROLES { + roles = append(roles, db.Role{ + Role: val, + }) + } err := db.Mgr.SaveRoles(roles) if err != nil { log.Println(`Error saving roles`, err) diff --git a/server/utils/validator.go b/server/utils/validator.go index 1a02376..e19ae58 100644 --- a/server/utils/validator.go +++ b/server/utils/validator.go @@ -43,7 +43,7 @@ func IsSuperAdmin(gc *gin.Context) bool { func IsValidRoles(userRoles []string, roles []string) bool { valid := true for _, role := range roles { - if !StringContains(userRoles, role) { + if !StringSliceContains(userRoles, role) { valid = false break }