diff --git a/Makefile b/Makefile index 0d7d38e..d661c3f 100644 --- a/Makefile +++ b/Makefile @@ -6,4 +6,4 @@ cmd: clean: rm -rf build test: - cd server && go clean --testcache && go test ./... \ No newline at end of file + cd server && go clean --testcache && go test -v ./... \ No newline at end of file diff --git a/server/env/env.go b/server/env/env.go index ffcd692..931cd72 100644 --- a/server/env/env.go +++ b/server/env/env.go @@ -87,15 +87,30 @@ func InitEnv() { allowedOriginsSplit := strings.Split(os.Getenv("ALLOWED_ORIGINS"), ",") allowedOrigins := []string{} + hasWildCard := false + for _, val := range allowedOriginsSplit { trimVal := strings.TrimSpace(val) if trimVal != "" { - allowedOrigins = append(allowedOrigins, trimVal) + if trimVal != "*" { + host, port := utils.GetHostParts(trimVal) + allowedOrigins = append(allowedOrigins, host+":"+port) + } else { + hasWildCard = true + allowedOrigins = append(allowedOrigins, trimVal) + break + } } } + + if len(allowedOrigins) > 1 && hasWildCard { + allowedOrigins = []string{"*"} + } + if len(allowedOrigins) == 0 { allowedOrigins = []string{"*"} } + constants.ALLOWED_ORIGINS = allowedOrigins if *ARG_AUTHORIZER_URL != "" { diff --git a/server/handlers/app.go b/server/handlers/app.go index 1ffeb14..4fb7070 100644 --- a/server/handlers/app.go +++ b/server/handlers/app.go @@ -49,7 +49,7 @@ func AppHandler() gin.HandlerFunc { stateObj.RedirectURL = strings.TrimSuffix(stateObj.RedirectURL, "/") // validate redirect url with allowed origins - if !utils.IsValidRedirectURL(stateObj.RedirectURL) { + if !utils.IsValidOrigin(stateObj.RedirectURL) { c.JSON(400, gin.H{"error": "invalid redirect url"}) return } diff --git a/server/integration_test/cors_test.go b/server/integration_test/cors_test.go new file mode 100644 index 0000000..226c1be --- /dev/null +++ b/server/integration_test/cors_test.go @@ -0,0 +1,44 @@ +package integration_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/authorizerdev/authorizer/server/constants" + "github.com/authorizerdev/authorizer/server/env" + "github.com/authorizerdev/authorizer/server/middlewares" + "github.com/gin-contrib/location" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" +) + +func TestCors(t *testing.T) { + constants.ENV_PATH = "../../.env.local" + env.InitEnv() + r := gin.Default() + r.Use(location.Default()) + r.Use(middlewares.GinContextToContextMiddleware()) + r.Use(middlewares.CORSMiddleware()) + allowedOrigin := "http://localhost:8080" // The allowed origin that you want to check + notAllowedOrigin := "http://myapp.com" + + server := httptest.NewServer(r) + defer server.Close() + + client := &http.Client{} + req, _ := http.NewRequest( + "GET", + "http://"+server.Listener.Addr().String()+"/api", + nil, + ) + req.Header.Add("Origin", allowedOrigin) + + get, _ := client.Do(req) + + // You should get your origin (or a * depending on your config) if the + // passed origin is allowed. + o := get.Header.Get("Access-Control-Allow-Origin") + assert.NotEqual(t, o, notAllowedOrigin, "Origins should not match") + assert.Equal(t, o, allowedOrigin, "Origins don't match") +} diff --git a/server/main.go b/server/main.go index a88c924..8fb5374 100644 --- a/server/main.go +++ b/server/main.go @@ -1,13 +1,10 @@ package main import ( - "context" - "log" - - "github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/env" "github.com/authorizerdev/authorizer/server/handlers" + "github.com/authorizerdev/authorizer/server/middlewares" "github.com/authorizerdev/authorizer/server/oauth" "github.com/authorizerdev/authorizer/server/session" "github.com/authorizerdev/authorizer/server/utils" @@ -15,39 +12,6 @@ import ( "github.com/gin-gonic/gin" ) -func GinContextToContextMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - if constants.AUTHORIZER_URL == "" { - url := location.Get(c) - constants.AUTHORIZER_URL = url.Scheme + "://" + c.Request.Host - log.Println("=> authorizer url:", constants.AUTHORIZER_URL) - } - ctx := context.WithValue(c.Request.Context(), "GinContextKey", c) - c.Request = c.Request.WithContext(ctx) - c.Next() - } -} - -// TODO use allowed origins for cors origin -// TODO throw error if url is not allowed -func CORSMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - origin := c.Request.Header.Get("Origin") - constants.APP_URL = origin - c.Writer.Header().Set("Access-Control-Allow-Origin", origin) - c.Writer.Header().Set("Access-Control-Allow-Credentials", "true") - c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With") - c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT") - - if c.Request.Method == "OPTIONS" { - c.AbortWithStatus(204) - return - } - - c.Next() - } -} - func main() { env.InitEnv() db.InitDB() @@ -57,8 +21,8 @@ func main() { r := gin.Default() r.Use(location.Default()) - r.Use(GinContextToContextMiddleware()) - r.Use(CORSMiddleware()) + r.Use(middlewares.GinContextToContextMiddleware()) + r.Use(middlewares.CORSMiddleware()) r.GET("/", handlers.PlaygroundHandler()) r.POST("/graphql", handlers.GraphqlHandler()) diff --git a/server/middlewares/context.go b/server/middlewares/context.go new file mode 100644 index 0000000..390a078 --- /dev/null +++ b/server/middlewares/context.go @@ -0,0 +1,23 @@ +package middlewares + +import ( + "context" + "log" + + "github.com/authorizerdev/authorizer/server/constants" + "github.com/gin-contrib/location" + "github.com/gin-gonic/gin" +) + +func GinContextToContextMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + if constants.AUTHORIZER_URL == "" { + url := location.Get(c) + constants.AUTHORIZER_URL = url.Scheme + "://" + c.Request.Host + log.Println("=> authorizer url:", constants.AUTHORIZER_URL) + } + ctx := context.WithValue(c.Request.Context(), "GinContextKey", c) + c.Request = c.Request.WithContext(ctx) + c.Next() + } +} diff --git a/server/middlewares/cors.go b/server/middlewares/cors.go new file mode 100644 index 0000000..e0bb4a9 --- /dev/null +++ b/server/middlewares/cors.go @@ -0,0 +1,29 @@ +package middlewares + +import ( + "github.com/authorizerdev/authorizer/server/constants" + "github.com/authorizerdev/authorizer/server/utils" + "github.com/gin-gonic/gin" +) + +func CORSMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + origin := c.Request.Header.Get("Origin") + constants.APP_URL = origin + + if utils.IsValidOrigin(origin) { + c.Writer.Header().Set("Access-Control-Allow-Origin", origin) + } + + c.Writer.Header().Set("Access-Control-Allow-Credentials", "true") + c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With") + c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT") + + if c.Request.Method == "OPTIONS" { + c.AbortWithStatus(204) + return + } + + c.Next() + } +} diff --git a/server/utils/cookie.go b/server/utils/cookie.go index a7598b4..38cf327 100644 --- a/server/utils/cookie.go +++ b/server/utils/cookie.go @@ -10,7 +10,7 @@ import ( func SetCookie(gc *gin.Context, token string) { secure := true httpOnly := true - host := GetHostName(constants.AUTHORIZER_URL) + host, _ := GetHostParts(constants.AUTHORIZER_URL) domain := GetDomainName(constants.AUTHORIZER_URL) if domain != "localhost" { domain = "." + domain @@ -37,7 +37,7 @@ func DeleteCookie(gc *gin.Context) { secure := true httpOnly := true - host := GetDomainName(constants.AUTHORIZER_URL) + host, _ := GetHostParts(constants.AUTHORIZER_URL) domain := GetDomainName(constants.AUTHORIZER_URL) if domain != "localhost" { domain = "." + domain diff --git a/server/utils/urls.go b/server/utils/urls.go index 3a850f1..25acac3 100644 --- a/server/utils/urls.go +++ b/server/utils/urls.go @@ -5,21 +5,32 @@ import ( "strings" ) -// GetHostName function to get hostname -func GetHostName(auth_url string) string { - u, err := url.Parse(auth_url) +// GetHostName function returns hostname and port +func GetHostParts(uri string) (string, string) { + tempURI := uri + if !strings.HasPrefix(tempURI, "http") && strings.HasPrefix(tempURI, "https") { + tempURI = "https://" + tempURI + } + + u, err := url.Parse(tempURI) if err != nil { - return `localhost` + return "localhost", "8080" } host := u.Hostname() + port := u.Port() - return host + return host, port } // GetDomainName function to get domain name -func GetDomainName(auth_url string) string { - u, err := url.Parse(auth_url) +func GetDomainName(uri string) string { + tempURI := uri + if !strings.HasPrefix(tempURI, "http") && strings.HasPrefix(tempURI, "https") { + tempURI = "https://" + tempURI + } + + u, err := url.Parse(tempURI) if err != nil { return `localhost` } diff --git a/server/utils/urls_test.go b/server/utils/urls_test.go index b54d386..19e53a9 100644 --- a/server/utils/urls_test.go +++ b/server/utils/urls_test.go @@ -7,12 +7,13 @@ import ( ) func TestGetHostName(t *testing.T) { - authorizer_url := "http://test.herokuapp.com" + authorizer_url := "http://test.herokuapp.com:80" - got := GetHostName(authorizer_url) - want := "test.herokuapp.com" + host, port := GetHostParts(authorizer_url) + expectedHost := "test.herokuapp.com" - assert.Equal(t, got, want, "hostname should be equal") + assert.Equal(t, host, expectedHost, "hostname should be equal") + assert.Equal(t, port, "80", "port should be 80") } func TestGetDomainName(t *testing.T) { diff --git a/server/utils/validator.go b/server/utils/validator.go index e19ae58..1d3ffb0 100644 --- a/server/utils/validator.go +++ b/server/utils/validator.go @@ -2,6 +2,7 @@ package utils import ( "net/mail" + "regexp" "strings" "github.com/authorizerdev/authorizer/server/constants" @@ -13,16 +14,32 @@ func IsValidEmail(email string) bool { return err == nil } -func IsValidRedirectURL(url string) bool { +func IsValidOrigin(url string) bool { if len(constants.ALLOWED_ORIGINS) == 1 && constants.ALLOWED_ORIGINS[0] == "*" { return true } hasValidURL := false - urlDomain := GetDomainName(url) + hostName, port := GetHostParts(url) + currentOrigin := hostName + ":" + port - for _, val := range constants.ALLOWED_ORIGINS { - if strings.Contains(val, urlDomain) { + for _, origin := range constants.ALLOWED_ORIGINS { + replacedString := origin + // if has regex whitelisted domains + if strings.Contains(origin, "*") { + replacedString = strings.Replace(origin, ".", "\\.", -1) + replacedString = strings.Replace(replacedString, "*", ".*", -1) + + if strings.HasPrefix(replacedString, ".*") { + replacedString += "\\b" + } + + if strings.HasSuffix(replacedString, ".*") { + replacedString = "\\b" + replacedString + } + } + + if matched, _ := regexp.MatchString(replacedString, currentOrigin); matched { hasValidURL = true break } diff --git a/server/utils/validator_test.go b/server/utils/validator_test.go index 1ecd269..f342ae5 100644 --- a/server/utils/validator_test.go +++ b/server/utils/validator_test.go @@ -3,6 +3,7 @@ package utils import ( "testing" + "github.com/authorizerdev/authorizer/server/constants" "github.com/stretchr/testify/assert" ) @@ -15,3 +16,19 @@ func TestIsValidEmail(t *testing.T) { assert.False(t, IsValidEmail(invalidEmail1), "it should be invalid email") assert.False(t, IsValidEmail(invalidEmail2), "it should be invalid email") } + +func TestIsValidOrigin(t *testing.T) { + // don't use portocal(http/https) for ALLOWED_ORIGINS while testing, + // as we trim them off while running the main function + constants.ALLOWED_ORIGINS = []string{"localhost:8080", "*.google.com", "*.google.in", "*abc.*"} + + assert.False(t, IsValidOrigin("http://myapp.com"), "it should be invalid origin") + assert.False(t, IsValidOrigin("http://appgoogle.com"), "it should be invalid origin") + assert.True(t, IsValidOrigin("http://app.google.com"), "it should be valid origin") + assert.False(t, IsValidOrigin("http://app.google.ind"), "it should be invalid origin") + assert.True(t, IsValidOrigin("http://app.google.in"), "it should be valid origin") + assert.True(t, IsValidOrigin("http://xyx.abc.com"), "it should be valid origin") + assert.True(t, IsValidOrigin("http://xyx.abc.in"), "it should be valid origin") + assert.True(t, IsValidOrigin("http://xyxabc.in"), "it should be valid origin") + assert.True(t, IsValidOrigin("http://localhost:8080"), "it should be valid origin") +}