@@ -10,7 +10,7 @@ import (
|
||||
func SetCookie(gc *gin.Context, token string) {
|
||||
secure := true
|
||||
httpOnly := true
|
||||
host := GetHostName(constants.AUTHORIZER_URL)
|
||||
host, _ := GetHostParts(constants.AUTHORIZER_URL)
|
||||
domain := GetDomainName(constants.AUTHORIZER_URL)
|
||||
if domain != "localhost" {
|
||||
domain = "." + domain
|
||||
@@ -37,7 +37,7 @@ func DeleteCookie(gc *gin.Context) {
|
||||
secure := true
|
||||
httpOnly := true
|
||||
|
||||
host := GetDomainName(constants.AUTHORIZER_URL)
|
||||
host, _ := GetHostParts(constants.AUTHORIZER_URL)
|
||||
domain := GetDomainName(constants.AUTHORIZER_URL)
|
||||
if domain != "localhost" {
|
||||
domain = "." + domain
|
||||
|
@@ -5,21 +5,32 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// GetHostName function to get hostname
|
||||
func GetHostName(auth_url string) string {
|
||||
u, err := url.Parse(auth_url)
|
||||
// GetHostName function returns hostname and port
|
||||
func GetHostParts(uri string) (string, string) {
|
||||
tempURI := uri
|
||||
if !strings.HasPrefix(tempURI, "http") && strings.HasPrefix(tempURI, "https") {
|
||||
tempURI = "https://" + tempURI
|
||||
}
|
||||
|
||||
u, err := url.Parse(tempURI)
|
||||
if err != nil {
|
||||
return `localhost`
|
||||
return "localhost", "8080"
|
||||
}
|
||||
|
||||
host := u.Hostname()
|
||||
port := u.Port()
|
||||
|
||||
return host
|
||||
return host, port
|
||||
}
|
||||
|
||||
// GetDomainName function to get domain name
|
||||
func GetDomainName(auth_url string) string {
|
||||
u, err := url.Parse(auth_url)
|
||||
func GetDomainName(uri string) string {
|
||||
tempURI := uri
|
||||
if !strings.HasPrefix(tempURI, "http") && strings.HasPrefix(tempURI, "https") {
|
||||
tempURI = "https://" + tempURI
|
||||
}
|
||||
|
||||
u, err := url.Parse(tempURI)
|
||||
if err != nil {
|
||||
return `localhost`
|
||||
}
|
||||
|
@@ -7,12 +7,13 @@ import (
|
||||
)
|
||||
|
||||
func TestGetHostName(t *testing.T) {
|
||||
authorizer_url := "http://test.herokuapp.com"
|
||||
authorizer_url := "http://test.herokuapp.com:80"
|
||||
|
||||
got := GetHostName(authorizer_url)
|
||||
want := "test.herokuapp.com"
|
||||
host, port := GetHostParts(authorizer_url)
|
||||
expectedHost := "test.herokuapp.com"
|
||||
|
||||
assert.Equal(t, got, want, "hostname should be equal")
|
||||
assert.Equal(t, host, expectedHost, "hostname should be equal")
|
||||
assert.Equal(t, port, "80", "port should be 80")
|
||||
}
|
||||
|
||||
func TestGetDomainName(t *testing.T) {
|
||||
|
@@ -2,6 +2,7 @@ package utils
|
||||
|
||||
import (
|
||||
"net/mail"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/authorizerdev/authorizer/server/constants"
|
||||
@@ -13,16 +14,32 @@ func IsValidEmail(email string) bool {
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func IsValidRedirectURL(url string) bool {
|
||||
func IsValidOrigin(url string) bool {
|
||||
if len(constants.ALLOWED_ORIGINS) == 1 && constants.ALLOWED_ORIGINS[0] == "*" {
|
||||
return true
|
||||
}
|
||||
|
||||
hasValidURL := false
|
||||
urlDomain := GetDomainName(url)
|
||||
hostName, port := GetHostParts(url)
|
||||
currentOrigin := hostName + ":" + port
|
||||
|
||||
for _, val := range constants.ALLOWED_ORIGINS {
|
||||
if strings.Contains(val, urlDomain) {
|
||||
for _, origin := range constants.ALLOWED_ORIGINS {
|
||||
replacedString := origin
|
||||
// if has regex whitelisted domains
|
||||
if strings.Contains(origin, "*") {
|
||||
replacedString = strings.Replace(origin, ".", "\\.", -1)
|
||||
replacedString = strings.Replace(replacedString, "*", ".*", -1)
|
||||
|
||||
if strings.HasPrefix(replacedString, ".*") {
|
||||
replacedString += "\\b"
|
||||
}
|
||||
|
||||
if strings.HasSuffix(replacedString, ".*") {
|
||||
replacedString = "\\b" + replacedString
|
||||
}
|
||||
}
|
||||
|
||||
if matched, _ := regexp.MatchString(replacedString, currentOrigin); matched {
|
||||
hasValidURL = true
|
||||
break
|
||||
}
|
||||
|
@@ -3,6 +3,7 @@ package utils
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/authorizerdev/authorizer/server/constants"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@@ -15,3 +16,19 @@ func TestIsValidEmail(t *testing.T) {
|
||||
assert.False(t, IsValidEmail(invalidEmail1), "it should be invalid email")
|
||||
assert.False(t, IsValidEmail(invalidEmail2), "it should be invalid email")
|
||||
}
|
||||
|
||||
func TestIsValidOrigin(t *testing.T) {
|
||||
// don't use portocal(http/https) for ALLOWED_ORIGINS while testing,
|
||||
// as we trim them off while running the main function
|
||||
constants.ALLOWED_ORIGINS = []string{"localhost:8080", "*.google.com", "*.google.in", "*abc.*"}
|
||||
|
||||
assert.False(t, IsValidOrigin("http://myapp.com"), "it should be invalid origin")
|
||||
assert.False(t, IsValidOrigin("http://appgoogle.com"), "it should be invalid origin")
|
||||
assert.True(t, IsValidOrigin("http://app.google.com"), "it should be valid origin")
|
||||
assert.False(t, IsValidOrigin("http://app.google.ind"), "it should be invalid origin")
|
||||
assert.True(t, IsValidOrigin("http://app.google.in"), "it should be valid origin")
|
||||
assert.True(t, IsValidOrigin("http://xyx.abc.com"), "it should be valid origin")
|
||||
assert.True(t, IsValidOrigin("http://xyx.abc.in"), "it should be valid origin")
|
||||
assert.True(t, IsValidOrigin("http://xyxabc.in"), "it should be valid origin")
|
||||
assert.True(t, IsValidOrigin("http://localhost:8080"), "it should be valid origin")
|
||||
}
|
||||
|
Reference in New Issue
Block a user