diff --git a/server/db/db.go b/server/db/db.go index ca14e16..9890878 100644 --- a/server/db/db.go +++ b/server/db/db.go @@ -1,8 +1,6 @@ package db import ( - "log" - "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db/providers" "github.com/authorizerdev/authorizer/server/db/providers/arangodb" @@ -14,7 +12,7 @@ import ( // Provider returns the current database provider var Provider providers.Provider -func InitDB() { +func InitDB() error { var err error isSQL := envstore.EnvInMemoryStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseType) != constants.DbTypeArangodb && envstore.EnvInMemoryStoreObj.GetStringStoreEnvVariable(constants.EnvKeyDatabaseType) != constants.DbTypeMongodb @@ -24,21 +22,23 @@ func InitDB() { if isSQL { Provider, err = sql.NewProvider() if err != nil { - log.Fatal("=> error setting sql provider:", err) + return err } } if isArangoDB { Provider, err = arangodb.NewProvider() if err != nil { - log.Fatal("=> error setting arangodb provider:", err) + return err } } if isMongoDB { Provider, err = mongodb.NewProvider() if err != nil { - log.Fatal("=> error setting arangodb provider:", err) + return err } } + + return nil } diff --git a/server/env/env.go b/server/env/env.go index 03ad1de..576f438 100644 --- a/server/env/env.go +++ b/server/env/env.go @@ -1,7 +1,7 @@ package env import ( - "fmt" + "errors" "log" "os" "strings" @@ -15,7 +15,7 @@ import ( ) // InitRequiredEnv to initialize EnvData and through error if required env are not present -func InitRequiredEnv() { +func InitRequiredEnv() error { envPath := os.Getenv(constants.EnvKeyEnvPath) if envPath == "" { @@ -35,23 +35,23 @@ func InitRequiredEnv() { dbType := os.Getenv(constants.EnvKeyDatabaseType) dbName := os.Getenv(constants.EnvKeyDatabaseName) - if dbType == "" { + if strings.TrimSpace(dbType) == "" { if envstore.ARG_DB_TYPE != nil && *envstore.ARG_DB_TYPE != "" { - dbType = *envstore.ARG_DB_TYPE + dbType = strings.TrimSpace(*envstore.ARG_DB_TYPE) } if dbType == "" { - panic("DATABASE_TYPE is required") + return errors.New("invalid database type. DATABASE_TYPE is empty") } } - if dbURL == "" { + if strings.TrimSpace(dbURL) == "" { if envstore.ARG_DB_URL != nil && *envstore.ARG_DB_URL != "" { - dbURL = *envstore.ARG_DB_URL + dbURL = strings.TrimSpace(*envstore.ARG_DB_URL) } if dbURL == "" { - panic("DATABASE_URL is required") + return errors.New("invalid database url. DATABASE_URL is required") } } @@ -65,10 +65,11 @@ func InitRequiredEnv() { envstore.EnvInMemoryStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyDatabaseURL, dbURL) envstore.EnvInMemoryStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyDatabaseType, dbType) envstore.EnvInMemoryStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyDatabaseName, dbName) + return nil } // InitEnv to initialize EnvData and through error if required env are not present -func InitAllEnv() { +func InitAllEnv() error { envData, err := GetEnvData() if err != nil { log.Println("No env data found in db, using local clone of env data") @@ -134,7 +135,7 @@ func InitAllEnv() { } else { algo = envData.StringEnv[constants.EnvKeyJwtType] if !crypto.IsHMACA(algo) && !crypto.IsRSA(algo) && !crypto.IsECDSA(algo) { - panic("JWT_TYPE is invalid") + return errors.New("invalid JWT_TYPE") } } } @@ -163,12 +164,12 @@ func InitAllEnv() { if crypto.IsRSA(algo) { _, privateKey, publicKey, err = crypto.NewRSAKey() if err != nil { - panic(err) + return err } } else if crypto.IsECDSA(algo) { _, privateKey, publicKey, err = crypto.NewECDSAKey() if err != nil { - panic(err) + return err } } } else { @@ -176,28 +177,26 @@ func InitAllEnv() { if crypto.IsRSA(algo) { _, err = crypto.ParseRsaPrivateKeyFromPemStr(privateKey) if err != nil { - panic(err) + return err } _, err = crypto.ParseRsaPublicKeyFromPemStr(publicKey) if err != nil { - panic(err) + return err } } else if crypto.IsECDSA(algo) { _, err = crypto.ParseEcdsaPrivateKeyFromPemStr(privateKey) if err != nil { - panic(err) + return err } _, err = crypto.ParseEcdsaPublicKeyFromPemStr(publicKey) if err != nil { - panic(err) + return err } } - fmt.Println("=> keys parsed successfully") } - fmt.Println(privateKey) - fmt.Println(publicKey) + envData.StringEnv[constants.EnvKeyJwtPrivateKey] = privateKey envData.StringEnv[constants.EnvKeyJwtPublicKey] = publicKey } @@ -333,7 +332,7 @@ func InitAllEnv() { } if len(roles) > 0 && len(defaultRoles) == 0 && len(defaultRolesEnv) > 0 { - panic(`Invalid DEFAULT_ROLE environment variable. It can be one from give ROLES environment variable value`) + return errors.New(`invalid DEFAULT_ROLE environment variable. It can be one from give ROLES environment variable value`) } envData.SliceEnv[constants.EnvKeyRoles] = roles @@ -349,4 +348,5 @@ func InitAllEnv() { } envstore.EnvInMemoryStoreObj.UpdateEnvStore(envData) + return nil } diff --git a/server/main.go b/server/main.go index f41ccd3..89b8fc0 100644 --- a/server/main.go +++ b/server/main.go @@ -23,21 +23,43 @@ func main() { envstore.EnvInMemoryStoreObj.UpdateEnvVariable(constants.StringStoreIdentifier, constants.EnvKeyVersion, VERSION) - // initialize required envs (mainly db env & env file path) - env.InitRequiredEnv() - // initialize db provider - db.InitDB() - // initialize all envs - env.InitAllEnv() - // persist all envs - err := env.PersistEnv() + // initialize required envs (mainly db & env file path) + err := env.InitRequiredEnv() if err != nil { - log.Println("Error persisting env:", err) + log.Fatal("Error while initializing required envs:", err) } - sessionstore.InitSession() - oauth.InitOAuth() - router := routes.InitRouter() + // initialize db provider + err = db.InitDB() + if err != nil { + log.Fatalln("Error while initializing db:", err) + } + // initialize all envs + // (get if present from db else construct from os env + defaults) + err = env.InitAllEnv() + if err != nil { + log.Fatalln("Error while initializing env: ", err) + } + + // persist all envs + err = env.PersistEnv() + if err != nil { + log.Fatalln("Error while persisting env:", err) + } + + // initialize session store (redis or in-memory based on env) + err = sessionstore.InitSession() + if err != nil { + log.Fatalln("Error while initializing session store:", err) + } + + // initialize oauth providers based on env + err = oauth.InitOAuth() + if err != nil { + log.Fatalln("Error while initializing oauth:", err) + } + + router := routes.InitRouter() router.Run(":" + envstore.EnvInMemoryStoreObj.GetStringStoreEnvVariable(constants.EnvKeyPort)) } diff --git a/server/oauth/oauth.go b/server/oauth/oauth.go index c0ff694..5d79125 100644 --- a/server/oauth/oauth.go +++ b/server/oauth/oauth.go @@ -2,7 +2,6 @@ package oauth import ( "context" - "log" "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/envstore" @@ -32,12 +31,12 @@ var ( ) // InitOAuth initializes the OAuth providers based on EnvData -func InitOAuth() { +func InitOAuth() error { ctx := context.Background() if envstore.EnvInMemoryStoreObj.GetStringStoreEnvVariable(constants.EnvKeyGoogleClientID) != "" && envstore.EnvInMemoryStoreObj.GetStringStoreEnvVariable(constants.EnvKeyGoogleClientSecret) != "" { p, err := oidc.NewProvider(ctx, "https://accounts.google.com") if err != nil { - log.Fatalln("error creating oidc provider for google:", err) + return err } OIDCProviders.GoogleOIDC = p OAuthProviders.GoogleConfig = &oauth2.Config{ @@ -65,4 +64,6 @@ func InitOAuth() { Scopes: []string{"public_profile", "email"}, } } + + return nil } diff --git a/server/sessionstore/session.go b/server/sessionstore/session.go index 523dee6..0287a46 100644 --- a/server/sessionstore/session.go +++ b/server/sessionstore/session.go @@ -119,62 +119,64 @@ func RemoveSocialLoginState(key string) { } // InitializeSessionStore initializes the SessionStoreObj based on environment variables -func InitSession() { +func InitSession() error { if envstore.EnvInMemoryStoreObj.GetStringStoreEnvVariable(constants.EnvKeyRedisURL) != "" { log.Println("using redis store to save sessions") - if isCluster(envstore.EnvInMemoryStoreObj.GetStringStoreEnvVariable(constants.EnvKeyRedisURL)) { - clusterOpt, err := getClusterOptions(envstore.EnvInMemoryStoreObj.GetStringStoreEnvVariable(constants.EnvKeyRedisURL)) + + redisURL := envstore.EnvInMemoryStoreObj.GetStringStoreEnvVariable(constants.EnvKeyRedisURL) + redisURLHostPortsList := strings.Split(redisURL, ",") + + if len(redisURLHostPortsList) > 1 { + opt, err := redis.ParseURL(redisURLHostPortsList[0]) if err != nil { - log.Fatalln("Error parsing redis url:", err) + return err } + urls := []string{opt.Addr} + urlList := redisURLHostPortsList[1:] + urls = append(urls, urlList...) + clusterOpt := &redis.ClusterOptions{Addrs: urls} + rdb := redis.NewClusterClient(clusterOpt) ctx := context.Background() _, err = rdb.Ping(ctx).Result() if err != nil { - log.Fatalln("Error connecting to redis cluster server", err) + return err } SessionStoreObj.RedisMemoryStoreObj = &RedisStore{ ctx: ctx, store: rdb, } - return + + // return on successful initialization + return nil } + opt, err := redis.ParseURL(envstore.EnvInMemoryStoreObj.GetStringStoreEnvVariable(constants.EnvKeyRedisURL)) if err != nil { - log.Fatalln("Error parsing redis url:", err) + return err } + rdb := redis.NewClient(opt) ctx := context.Background() _, err = rdb.Ping(ctx).Result() - if err != nil { - log.Fatalln("Error connecting to redis server", err) + return err } + SessionStoreObj.RedisMemoryStoreObj = &RedisStore{ ctx: ctx, store: rdb, } - } else { - SessionStoreObj.InMemoryStoreObj = &InMemoryStore{ - store: map[string]map[string]string{}, - socialLoginState: map[string]string{}, - } + // return on successful initialization + return nil } -} -func isCluster(url string) bool { - return len(strings.Split(url, ",")) > 1 -} - -func getClusterOptions(url string) (*redis.ClusterOptions, error) { - hostPortsList := strings.Split(url, ",") - opt, err := redis.ParseURL(hostPortsList[0]) - if err != nil { - return nil, err + // if redis url is not set use in memory store + SessionStoreObj.InMemoryStoreObj = &InMemoryStore{ + store: map[string]map[string]string{}, + socialLoginState: map[string]string{}, } - urls := []string{opt.Addr} - urlList := hostPortsList[1:] - urls = append(urls, urlList...) - return &redis.ClusterOptions{Addrs: urls}, nil + + return nil }