Compare commits

...

15 Commits

Author SHA1 Message Date
Lakhan Samani
1ebba7f2b7 Merge pull request #343 from authorizerdev/fix/session-storage
fix: session storage
2023-04-08 18:07:52 +05:30
Lakhan Samani
428a0be3db feat: add cache clear 2023-04-08 18:02:53 +05:30
Lakhan Samani
02c0ebb9c4 fix: session storage 2023-04-08 13:06:15 +05:30
Lakhan Samani
9a284c03ca fix: redis session 2023-04-03 10:26:27 +05:30
Lakhan Samani
c8fe05eabc Merge pull request #342 from authorizerdev/feat/default-oauth-configs
feat: add support for default response mode & type env
2023-04-01 17:42:02 +05:30
Lakhan Samani
48344ffd4c feat: add support for default response mode & type env
Resolves #341
2023-04-01 17:36:07 +05:30
Lakhan Samani
77f34e1149 Merge pull request #339 from authorizerdev/fix/webhooks
fix: allow multiple hooks for same event
2023-03-29 07:31:46 +05:30
Lakhan Samani
16136931a9 fix: add event description to webhook res 2023-03-29 07:31:07 +05:30
Lakhan Samani
c908ac94da fix: continue in case of error for register events 2023-03-29 07:29:44 +05:30
Lakhan Samani
6604b6bbdd fix: update dashboard ui for webhooks 2023-03-29 07:27:56 +05:30
Lakhan Samani
2c227b5518 chore: delete couchbase container after tests 2023-03-29 07:10:36 +05:30
Lakhan Samani
e822b6f31a fix: queries for webhooks + improve tests 2023-03-29 07:06:33 +05:30
Lakhan Samani
a38e9d4e6c fix: rename title -> event_description 2023-03-26 07:48:06 +05:30
Lakhan Samani
deaf1e2ff7 fix: allow multiple hooks for same event 2023-03-26 07:20:45 +05:30
Lakhan Samani
f324976801 Merge pull request #338 from authorizerdev/fix/accessibility
fix: accessibility
2023-03-25 10:48:49 +05:30
81 changed files with 1237 additions and 566 deletions

View File

@@ -51,6 +51,6 @@ test-all-db:
docker rm -vf authorizer_mongodb_db docker rm -vf authorizer_mongodb_db
docker rm -vf authorizer_arangodb docker rm -vf authorizer_arangodb
docker rm -vf dynamodb-local-test docker rm -vf dynamodb-local-test
# docker rm -vf couchbase-local-test docker rm -vf couchbase-local-test
generate: generate:
cd server && go run github.com/99designs/gqlgen generate && go mod tidy cd server && go run github.com/99designs/gqlgen generate && go mod tidy

View File

@@ -61,7 +61,6 @@ const JSTConfigurations = ({
return ( return (
<div> <div>
{' '}
<Flex <Flex
borderRadius={5} borderRadius={5}
width="100%" width="100%"

View File

@@ -18,7 +18,13 @@ import {
FaTwitter, FaTwitter,
FaMicrosoft, FaMicrosoft,
} from 'react-icons/fa'; } from 'react-icons/fa';
import { TextInputType, HiddenInputType } from '../../constants'; import {
TextInputType,
HiddenInputType,
ResponseModes,
ResponseTypes,
SelectInputType,
} from '../../constants';
const OAuthConfig = ({ const OAuthConfig = ({
envVariables, envVariables,
@@ -70,6 +76,42 @@ const OAuthConfig = ({
/> />
</Center> </Center>
</Flex> </Flex>
<Flex direction={isNotSmallerScreen ? 'row' : 'column'}>
<Flex w="30%" justifyContent="start" alignItems="center">
<Text fontSize="sm">Default Response Type:</Text>
</Flex>
<Flex
w={isNotSmallerScreen ? '70%' : '100%'}
mt={isNotSmallerScreen ? '0' : '2'}
>
<InputField
borderRadius={5}
variables={envVariables}
setVariables={setVariables}
inputType={SelectInputType.DEFAULT_AUTHORIZE_RESPONSE_TYPE}
value={SelectInputType}
options={ResponseTypes}
/>
</Flex>
</Flex>
<Flex direction={isNotSmallerScreen ? 'row' : 'column'}>
<Flex w="30%" justifyContent="start" alignItems="center">
<Text fontSize="sm">Default Response Mode:</Text>
</Flex>
<Flex
w={isNotSmallerScreen ? '70%' : '100%'}
mt={isNotSmallerScreen ? '0' : '2'}
>
<InputField
borderRadius={5}
variables={envVariables}
setVariables={setVariables}
inputType={SelectInputType.DEFAULT_AUTHORIZE_RESPONSE_MODE}
value={SelectInputType}
options={ResponseModes}
/>
</Flex>
</Flex>
</Stack> </Stack>
<Divider mt={5} mb={2} color="blackAlpha.700" /> <Divider mt={5} mb={2} color="blackAlpha.700" />
<Text fontSize="md" paddingTop="2%" fontWeight="bold" mb={4}> <Text fontSize="md" paddingTop="2%" fontWeight="bold" mb={4}>

View File

@@ -63,6 +63,7 @@ interface headersValidatorDataType {
interface selecetdWebhookDataTypes { interface selecetdWebhookDataTypes {
[WebhookInputDataFields.ID]: string; [WebhookInputDataFields.ID]: string;
[WebhookInputDataFields.EVENT_NAME]: string; [WebhookInputDataFields.EVENT_NAME]: string;
[WebhookInputDataFields.EVENT_DESCRIPTION]?: string;
[WebhookInputDataFields.ENDPOINT]: string; [WebhookInputDataFields.ENDPOINT]: string;
[WebhookInputDataFields.ENABLED]: boolean; [WebhookInputDataFields.ENABLED]: boolean;
[WebhookInputDataFields.HEADERS]?: Record<string, string>; [WebhookInputDataFields.HEADERS]?: Record<string, string>;
@@ -86,6 +87,7 @@ const initHeadersValidatorData: headersValidatorDataType = {
interface webhookDataType { interface webhookDataType {
[WebhookInputDataFields.EVENT_NAME]: string; [WebhookInputDataFields.EVENT_NAME]: string;
[WebhookInputDataFields.EVENT_DESCRIPTION]?: string;
[WebhookInputDataFields.ENDPOINT]: string; [WebhookInputDataFields.ENDPOINT]: string;
[WebhookInputDataFields.ENABLED]: boolean; [WebhookInputDataFields.ENABLED]: boolean;
[WebhookInputDataFields.HEADERS]: headersDataType[]; [WebhookInputDataFields.HEADERS]: headersDataType[];
@@ -98,6 +100,7 @@ interface validatorDataType {
const initWebhookData: webhookDataType = { const initWebhookData: webhookDataType = {
[WebhookInputDataFields.EVENT_NAME]: webhookEventNames['User login'], [WebhookInputDataFields.EVENT_NAME]: webhookEventNames['User login'],
[WebhookInputDataFields.EVENT_DESCRIPTION]: '',
[WebhookInputDataFields.ENDPOINT]: '', [WebhookInputDataFields.ENDPOINT]: '',
[WebhookInputDataFields.ENABLED]: true, [WebhookInputDataFields.ENABLED]: true,
[WebhookInputDataFields.HEADERS]: [{ ...initHeadersData }], [WebhookInputDataFields.HEADERS]: [{ ...initHeadersData }],
@@ -144,6 +147,9 @@ const UpdateWebhookModal = ({
case WebhookInputDataFields.EVENT_NAME: case WebhookInputDataFields.EVENT_NAME:
setWebhook({ ...webhook, [inputType]: value }); setWebhook({ ...webhook, [inputType]: value });
break; break;
case WebhookInputDataFields.EVENT_DESCRIPTION:
setWebhook({ ...webhook, [inputType]: value });
break;
case WebhookInputDataFields.ENDPOINT: case WebhookInputDataFields.ENDPOINT:
setWebhook({ ...webhook, [inputType]: value }); setWebhook({ ...webhook, [inputType]: value });
setValidator({ setValidator({
@@ -246,6 +252,8 @@ const UpdateWebhookModal = ({
let params: any = { let params: any = {
[WebhookInputDataFields.EVENT_NAME]: [WebhookInputDataFields.EVENT_NAME]:
webhook[WebhookInputDataFields.EVENT_NAME], webhook[WebhookInputDataFields.EVENT_NAME],
[WebhookInputDataFields.EVENT_DESCRIPTION]:
webhook[WebhookInputDataFields.EVENT_DESCRIPTION],
[WebhookInputDataFields.ENDPOINT]: [WebhookInputDataFields.ENDPOINT]:
webhook[WebhookInputDataFields.ENDPOINT], webhook[WebhookInputDataFields.ENDPOINT],
[WebhookInputDataFields.ENABLED]: webhook[WebhookInputDataFields.ENABLED], [WebhookInputDataFields.ENABLED]: webhook[WebhookInputDataFields.ENABLED],
@@ -402,7 +410,9 @@ const UpdateWebhookModal = ({
<Flex flex="3"> <Flex flex="3">
<Select <Select
size="md" size="md"
value={webhook[WebhookInputDataFields.EVENT_NAME]} value={
webhook[WebhookInputDataFields.EVENT_NAME].split('-')[0]
}
onChange={(e) => onChange={(e) =>
inputChangehandler( inputChangehandler(
WebhookInputDataFields.EVENT_NAME, WebhookInputDataFields.EVENT_NAME,
@@ -420,6 +430,30 @@ const UpdateWebhookModal = ({
</Select> </Select>
</Flex> </Flex>
</Flex> </Flex>
<Flex
width="100%"
justifyContent="start"
alignItems="center"
marginBottom="5%"
>
<Flex flex="1">Event Description</Flex>
<Flex flex="3">
<InputGroup size="md">
<Input
pr="4.5rem"
type="text"
placeholder="User event"
value={webhook[WebhookInputDataFields.EVENT_DESCRIPTION]}
onChange={(e) =>
inputChangehandler(
WebhookInputDataFields.EVENT_DESCRIPTION,
e.currentTarget.value,
)
}
/>
</InputGroup>
</Flex>
</Flex>
<Flex <Flex
width="100%" width="100%"
justifyContent="start" justifyContent="start"

View File

@@ -57,6 +57,8 @@ export const ArrayInputType = {
export const SelectInputType = { export const SelectInputType = {
JWT_TYPE: 'JWT_TYPE', JWT_TYPE: 'JWT_TYPE',
GENDER: 'gender', GENDER: 'gender',
DEFAULT_AUTHORIZE_RESPONSE_TYPE: 'DEFAULT_AUTHORIZE_RESPONSE_TYPE',
DEFAULT_AUTHORIZE_RESPONSE_MODE: 'DEFAULT_AUTHORIZE_RESPONSE_MODE',
}; };
export const MultiSelectInputType = { export const MultiSelectInputType = {
@@ -161,6 +163,8 @@ export interface envVarTypes {
ACCESS_TOKEN_EXPIRY_TIME: string; ACCESS_TOKEN_EXPIRY_TIME: string;
DISABLE_MULTI_FACTOR_AUTHENTICATION: boolean; DISABLE_MULTI_FACTOR_AUTHENTICATION: boolean;
ENFORCE_MULTI_FACTOR_AUTHENTICATION: boolean; ENFORCE_MULTI_FACTOR_AUTHENTICATION: boolean;
DEFAULT_AUTHORIZE_RESPONSE_TYPE: string;
DEFAULT_AUTHORIZE_RESPONSE_MODE: string;
} }
export const envSubViews = { export const envSubViews = {
@@ -179,6 +183,7 @@ export const envSubViews = {
export enum WebhookInputDataFields { export enum WebhookInputDataFields {
ID = 'id', ID = 'id',
EVENT_DESCRIPTION = 'event_description',
EVENT_NAME = 'event_name', EVENT_NAME = 'event_name',
ENDPOINT = 'endpoint', ENDPOINT = 'endpoint',
ENABLED = 'enabled', ENABLED = 'enabled',
@@ -348,3 +353,16 @@ export enum EmailTemplateEditors {
UNLAYER_EDITOR = 'unlayer_editor', UNLAYER_EDITOR = 'unlayer_editor',
PLAIN_HTML_EDITOR = 'plain_html_editor', PLAIN_HTML_EDITOR = 'plain_html_editor',
} }
export const ResponseTypes = {
token: 'token',
code: 'code',
id_token: 'id_token',
};
export const ResponseModes = {
query: 'query',
form_post: 'form_post',
fragment: 'fragment',
web_message: 'web_message',
};

View File

@@ -70,6 +70,8 @@ export const EnvVariablesQuery = `
ACCESS_TOKEN_EXPIRY_TIME ACCESS_TOKEN_EXPIRY_TIME
DISABLE_MULTI_FACTOR_AUTHENTICATION DISABLE_MULTI_FACTOR_AUTHENTICATION
ENFORCE_MULTI_FACTOR_AUTHENTICATION ENFORCE_MULTI_FACTOR_AUTHENTICATION
DEFAULT_AUTHORIZE_RESPONSE_TYPE
DEFAULT_AUTHORIZE_RESPONSE_MODE
} }
} }
`; `;
@@ -118,6 +120,7 @@ export const WebhooksDataQuery = `
_webhooks(params: $params){ _webhooks(params: $params){
webhooks{ webhooks{
id id
event_description
event_name event_name
endpoint endpoint
enabled enabled

View File

@@ -90,6 +90,8 @@ const Environment = () => {
ACCESS_TOKEN_EXPIRY_TIME: '', ACCESS_TOKEN_EXPIRY_TIME: '',
DISABLE_MULTI_FACTOR_AUTHENTICATION: false, DISABLE_MULTI_FACTOR_AUTHENTICATION: false,
ENFORCE_MULTI_FACTOR_AUTHENTICATION: false, ENFORCE_MULTI_FACTOR_AUTHENTICATION: false,
DEFAULT_AUTHORIZE_RESPONSE_TYPE: '',
DEFAULT_AUTHORIZE_RESPONSE_MODE: '',
}); });
const [fieldVisibility, setFieldVisibility] = React.useState< const [fieldVisibility, setFieldVisibility] = React.useState<

View File

@@ -56,6 +56,7 @@ interface paginationPropTypes {
interface webhookDataTypes { interface webhookDataTypes {
[WebhookInputDataFields.ID]: string; [WebhookInputDataFields.ID]: string;
[WebhookInputDataFields.EVENT_NAME]: string; [WebhookInputDataFields.EVENT_NAME]: string;
[WebhookInputDataFields.EVENT_DESCRIPTION]?: string;
[WebhookInputDataFields.ENDPOINT]: string; [WebhookInputDataFields.ENDPOINT]: string;
[WebhookInputDataFields.ENABLED]: boolean; [WebhookInputDataFields.ENABLED]: boolean;
[WebhookInputDataFields.HEADERS]?: Record<string, string>; [WebhookInputDataFields.HEADERS]?: Record<string, string>;
@@ -117,6 +118,7 @@ const Webhooks = () => {
useEffect(() => { useEffect(() => {
fetchWebookData(); fetchWebookData();
}, [paginationProps.page, paginationProps.limit]); }, [paginationProps.page, paginationProps.limit]);
console.log({ webhookData });
return ( return (
<Box m="5" py="5" px="10" bg="white" rounded="md"> <Box m="5" py="5" px="10" bg="white" rounded="md">
<Flex margin="2% 0" justifyContent="space-between" alignItems="center"> <Flex margin="2% 0" justifyContent="space-between" alignItems="center">
@@ -134,6 +136,7 @@ const Webhooks = () => {
<Thead> <Thead>
<Tr> <Tr>
<Th>Event Name</Th> <Th>Event Name</Th>
<Th>Event Description</Th>
<Th>Endpoint</Th> <Th>Endpoint</Th>
<Th>Enabled</Th> <Th>Enabled</Th>
<Th>Headers</Th> <Th>Headers</Th>
@@ -147,7 +150,10 @@ const Webhooks = () => {
style={{ fontSize: 14 }} style={{ fontSize: 14 }}
> >
<Td maxW="300"> <Td maxW="300">
{webhook[WebhookInputDataFields.EVENT_NAME]} {webhook[WebhookInputDataFields.EVENT_NAME].split('-')[0]}
</Td>
<Td maxW="300">
{webhook[WebhookInputDataFields.EVENT_DESCRIPTION]}
</Td> </Td>
<Td>{webhook[WebhookInputDataFields.ENDPOINT]}</Td> <Td>{webhook[WebhookInputDataFields.ENDPOINT]}</Td>
<Td> <Td>
@@ -264,7 +270,7 @@ const Webhooks = () => {
</Text> </Text>
</Text> </Text>
<Flex alignItems="center"> <Flex alignItems="center">
<Text flexShrink="0">Go to page:</Text>{' '} <Text>Go to page:</Text>{' '}
<NumberInput <NumberInput
ml={2} ml={2}
mr={8} mr={8}

View File

@@ -166,4 +166,12 @@ const (
EnvKeyDefaultRoles = "DEFAULT_ROLES" EnvKeyDefaultRoles = "DEFAULT_ROLES"
// EnvKeyAllowedOrigins key for env variable ALLOWED_ORIGINS // EnvKeyAllowedOrigins key for env variable ALLOWED_ORIGINS
EnvKeyAllowedOrigins = "ALLOWED_ORIGINS" EnvKeyAllowedOrigins = "ALLOWED_ORIGINS"
// For oauth/openid/authorize
// EnvKeyDefaultAuthorizeResponseType key for env variable DEFAULT_AUTHORIZE_RESPONSE_TYPE
// This env is used for setting default response type in authorize handler
EnvKeyDefaultAuthorizeResponseType = "DEFAULT_AUTHORIZE_RESPONSE_TYPE"
// EnvKeyDefaultAuthorizeResponseMode key for env variable DEFAULT_AUTHORIZE_RESPONSE_MODE
// This env is used for setting default response mode in authorize handler
EnvKeyDefaultAuthorizeResponseMode = "DEFAULT_AUTHORIZE_RESPONSE_MODE"
) )

View File

@@ -10,35 +10,42 @@ import (
// Note: any change here should be reflected in providers/casandra/provider.go as it does not have model support in collection creation // Note: any change here should be reflected in providers/casandra/provider.go as it does not have model support in collection creation
// Event name has been kept unique as per initial design. But later on decided that we can have
// multiple hooks for same event so will be in a pattern `event_name-TIMESTAMP`
// Webhook model for db // Webhook model for db
type Webhook struct { type Webhook struct {
Key string `json:"_key,omitempty" bson:"_key,omitempty" cql:"_key,omitempty" dynamo:"key,omitempty"` // for arangodb Key string `json:"_key,omitempty" bson:"_key,omitempty" cql:"_key,omitempty" dynamo:"key,omitempty"` // for arangodb
ID string `gorm:"primaryKey;type:char(36)" json:"_id" bson:"_id" cql:"id" dynamo:"id,hash"` ID string `gorm:"primaryKey;type:char(36)" json:"_id" bson:"_id" cql:"id" dynamo:"id,hash"`
EventName string `gorm:"unique" json:"event_name" bson:"event_name" cql:"event_name" dynamo:"event_name" index:"event_name,hash"` EventName string `gorm:"unique" json:"event_name" bson:"event_name" cql:"event_name" dynamo:"event_name" index:"event_name,hash"`
EndPoint string `json:"endpoint" bson:"endpoint" cql:"endpoint" dynamo:"endpoint"` EventDescription string `json:"event_description" bson:"event_description" cql:"event_description" dynamo:"event_description"`
Headers string `json:"headers" bson:"headers" cql:"headers" dynamo:"headers"` EndPoint string `json:"endpoint" bson:"endpoint" cql:"endpoint" dynamo:"endpoint"`
Enabled bool `json:"enabled" bson:"enabled" cql:"enabled" dynamo:"enabled"` Headers string `json:"headers" bson:"headers" cql:"headers" dynamo:"headers"`
CreatedAt int64 `json:"created_at" bson:"created_at" cql:"created_at" dynamo:"created_at"` Enabled bool `json:"enabled" bson:"enabled" cql:"enabled" dynamo:"enabled"`
UpdatedAt int64 `json:"updated_at" bson:"updated_at" cql:"updated_at" dynamo:"updated_at"` CreatedAt int64 `json:"created_at" bson:"created_at" cql:"created_at" dynamo:"created_at"`
UpdatedAt int64 `json:"updated_at" bson:"updated_at" cql:"updated_at" dynamo:"updated_at"`
} }
// AsAPIWebhook to return webhook as graphql response object // AsAPIWebhook to return webhook as graphql response object
func (w *Webhook) AsAPIWebhook() *model.Webhook { func (w *Webhook) AsAPIWebhook() *model.Webhook {
headersMap := make(map[string]interface{}) headersMap := make(map[string]interface{})
json.Unmarshal([]byte(w.Headers), &headersMap) json.Unmarshal([]byte(w.Headers), &headersMap)
id := w.ID id := w.ID
if strings.Contains(id, Collections.Webhook+"/") { if strings.Contains(id, Collections.Webhook+"/") {
id = strings.TrimPrefix(id, Collections.Webhook+"/") id = strings.TrimPrefix(id, Collections.Webhook+"/")
} }
// set default title to event name without dot(.)
if w.EventDescription == "" {
w.EventDescription = strings.Join(strings.Split(w.EventName, "."), " ")
}
return &model.Webhook{ return &model.Webhook{
ID: id, ID: id,
EventName: refs.NewStringRef(w.EventName), EventName: refs.NewStringRef(w.EventName),
Endpoint: refs.NewStringRef(w.EndPoint), EventDescription: refs.NewStringRef(w.EventDescription),
Headers: headersMap, Endpoint: refs.NewStringRef(w.EndPoint),
Enabled: refs.NewBoolRef(w.Enabled), Headers: headersMap,
CreatedAt: refs.NewInt64Ref(w.CreatedAt), Enabled: refs.NewBoolRef(w.Enabled),
UpdatedAt: refs.NewInt64Ref(w.UpdatedAt), CreatedAt: refs.NewInt64Ref(w.CreatedAt),
UpdatedAt: refs.NewInt64Ref(w.UpdatedAt),
} }
} }

View File

@@ -3,8 +3,10 @@ package arangodb
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
"time" "time"
"github.com/arangodb/go-driver"
arangoDriver "github.com/arangodb/go-driver" arangoDriver "github.com/arangodb/go-driver"
"github.com/authorizerdev/authorizer/server/db/models" "github.com/authorizerdev/authorizer/server/db/models"
"github.com/authorizerdev/authorizer/server/graph/model" "github.com/authorizerdev/authorizer/server/graph/model"
@@ -17,8 +19,9 @@ func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*mod
webhook.ID = uuid.New().String() webhook.ID = uuid.New().String()
webhook.Key = webhook.ID webhook.Key = webhook.ID
} }
webhook.Key = webhook.ID webhook.Key = webhook.ID
// Add timestamp to make event name unique for legacy version
webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix())
webhook.CreatedAt = time.Now().Unix() webhook.CreatedAt = time.Now().Unix()
webhook.UpdatedAt = time.Now().Unix() webhook.UpdatedAt = time.Now().Unix()
webhookCollection, _ := p.db.Collection(ctx, models.Collections.Webhook) webhookCollection, _ := p.db.Collection(ctx, models.Collections.Webhook)
@@ -32,12 +35,15 @@ func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*mod
// UpdateWebhook to update webhook // UpdateWebhook to update webhook
func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) { func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) {
webhook.UpdatedAt = time.Now().Unix() webhook.UpdatedAt = time.Now().Unix()
// Event is changed
if !strings.Contains(webhook.EventName, "-") {
webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix())
}
webhookCollection, _ := p.db.Collection(ctx, models.Collections.Webhook) webhookCollection, _ := p.db.Collection(ctx, models.Collections.Webhook)
meta, err := webhookCollection.UpdateDocument(ctx, webhook.Key, webhook) meta, err := webhookCollection.UpdateDocument(ctx, webhook.Key, webhook)
if err != nil { if err != nil {
return nil, err return nil, err
} }
webhook.Key = meta.Key webhook.Key = meta.Key
webhook.ID = meta.ID.String() webhook.ID = meta.ID.String()
return webhook.AsAPIWebhook(), nil return webhook.AsAPIWebhook(), nil
@@ -55,10 +61,8 @@ func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination)
return nil, err return nil, err
} }
defer cursor.Close() defer cursor.Close()
paginationClone := pagination paginationClone := pagination
paginationClone.Total = cursor.Statistics().FullCount() paginationClone.Total = cursor.Statistics().FullCount()
for { for {
var webhook models.Webhook var webhook models.Webhook
meta, err := cursor.ReadDocument(ctx, &webhook) meta, err := cursor.ReadDocument(ctx, &webhook)
@@ -87,13 +91,11 @@ func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model
bindVars := map[string]interface{}{ bindVars := map[string]interface{}{
"webhook_id": webhookID, "webhook_id": webhookID,
} }
cursor, err := p.db.Query(ctx, query, bindVars) cursor, err := p.db.Query(ctx, query, bindVars)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer cursor.Close() defer cursor.Close()
for { for {
if !cursor.HasMore() { if !cursor.HasMore() {
if webhook.Key == "" { if webhook.Key == "" {
@@ -110,32 +112,28 @@ func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model
} }
// GetWebhookByEventName to get webhook by event_name // GetWebhookByEventName to get webhook by event_name
func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) (*model.Webhook, error) { func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) ([]*model.Webhook, error) {
var webhook models.Webhook query := fmt.Sprintf("FOR d in %s FILTER d.event_name LIKE @event_name RETURN d", models.Collections.Webhook)
query := fmt.Sprintf("FOR d in %s FILTER d.event_name == @event_name RETURN d", models.Collections.Webhook)
bindVars := map[string]interface{}{ bindVars := map[string]interface{}{
"event_name": eventName, "event_name": eventName + "%",
} }
cursor, err := p.db.Query(ctx, query, bindVars) cursor, err := p.db.Query(ctx, query, bindVars)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer cursor.Close() defer cursor.Close()
webhooks := []*model.Webhook{}
for { for {
if !cursor.HasMore() { var webhook models.Webhook
if webhook.Key == "" { if _, err := cursor.ReadDocument(ctx, &webhook); driver.IsNoMoreDocuments(err) {
return nil, fmt.Errorf("webhook not found") // We're done
}
break break
} } else if err != nil {
_, err := cursor.ReadDocument(ctx, &webhook)
if err != nil {
return nil, err return nil, err
} }
webhooks = append(webhooks, webhook.AsAPIWebhook())
} }
return webhook.AsAPIWebhook(), nil return webhooks, nil
} }
// DeleteWebhook to delete webhook // DeleteWebhook to delete webhook
@@ -145,17 +143,14 @@ func (p *provider) DeleteWebhook(ctx context.Context, webhook *model.Webhook) er
if err != nil { if err != nil {
return err return err
} }
query := fmt.Sprintf("FOR d IN %s FILTER d.webhook_id == @webhook_id REMOVE { _key: d._key } IN %s", models.Collections.WebhookLog, models.Collections.WebhookLog) query := fmt.Sprintf("FOR d IN %s FILTER d.webhook_id == @webhook_id REMOVE { _key: d._key } IN %s", models.Collections.WebhookLog, models.Collections.WebhookLog)
bindVars := map[string]interface{}{ bindVars := map[string]interface{}{
"webhook_id": webhook.ID, "webhook_id": webhook.ID,
} }
cursor, err := p.db.Query(ctx, query, bindVars) cursor, err := p.db.Query(ctx, query, bindVars)
if err != nil { if err != nil {
return err return err
} }
defer cursor.Close() defer cursor.Close()
return nil return nil
} }

View File

@@ -207,6 +207,13 @@ func NewProvider() (*provider, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
// add event_description to webhook table
webhookAlterQuery := fmt.Sprintf(`ALTER TABLE %s.%s ADD (event_description text);`, KeySpace, models.Collections.Webhook)
err = session.Query(webhookAlterQuery).Exec()
if err != nil {
log.Debug("Failed to alter table as column exists: ", err)
// continue
}
webhookLogCollectionQuery := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (id text, http_status bigint, response text, request text, webhook_id text,updated_at bigint, created_at bigint, PRIMARY KEY (id))", KeySpace, models.Collections.WebhookLog) webhookLogCollectionQuery := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (id text, http_status bigint, response text, request text, webhook_id text,updated_at bigint, created_at bigint, PRIMARY KEY (id))", KeySpace, models.Collections.WebhookLog)
err = session.Query(webhookLogCollectionQuery).Exec() err = session.Query(webhookLogCollectionQuery).Exec()

View File

@@ -19,29 +19,26 @@ func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*mod
if webhook.ID == "" { if webhook.ID == "" {
webhook.ID = uuid.New().String() webhook.ID = uuid.New().String()
} }
webhook.Key = webhook.ID webhook.Key = webhook.ID
webhook.CreatedAt = time.Now().Unix() webhook.CreatedAt = time.Now().Unix()
webhook.UpdatedAt = time.Now().Unix() webhook.UpdatedAt = time.Now().Unix()
// Add timestamp to make event name unique for legacy version
existingHook, _ := p.GetWebhookByEventName(ctx, webhook.EventName) webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix())
if existingHook != nil { insertQuery := fmt.Sprintf("INSERT INTO %s (id, event_description, event_name, endpoint, headers, enabled, created_at, updated_at) VALUES ('%s', '%s', '%s', '%s', '%s', %t, %d, %d)", KeySpace+"."+models.Collections.Webhook, webhook.ID, webhook.EventDescription, webhook.EventName, webhook.EndPoint, webhook.Headers, webhook.Enabled, webhook.CreatedAt, webhook.UpdatedAt)
return nil, fmt.Errorf("Webhook with %s event_name already exists", webhook.EventName)
}
insertQuery := fmt.Sprintf("INSERT INTO %s (id, event_name, endpoint, headers, enabled, created_at, updated_at) VALUES ('%s', '%s', '%s', '%s', %t, %d, %d)", KeySpace+"."+models.Collections.Webhook, webhook.ID, webhook.EventName, webhook.EndPoint, webhook.Headers, webhook.Enabled, webhook.CreatedAt, webhook.UpdatedAt)
err := p.db.Query(insertQuery).Exec() err := p.db.Query(insertQuery).Exec()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return webhook.AsAPIWebhook(), nil return webhook.AsAPIWebhook(), nil
} }
// UpdateWebhook to update webhook // UpdateWebhook to update webhook
func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) { func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) {
webhook.UpdatedAt = time.Now().Unix() webhook.UpdatedAt = time.Now().Unix()
// Event is changed
if !strings.Contains(webhook.EventName, "-") {
webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix())
}
bytes, err := json.Marshal(webhook) bytes, err := json.Marshal(webhook)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -54,22 +51,18 @@ func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*
if err != nil { if err != nil {
return nil, err return nil, err
} }
updateFields := "" updateFields := ""
for key, value := range webhookMap { for key, value := range webhookMap {
if key == "_id" { if key == "_id" {
continue continue
} }
if key == "_key" { if key == "_key" {
continue continue
} }
if value == nil { if value == nil {
updateFields += fmt.Sprintf("%s = null,", key) updateFields += fmt.Sprintf("%s = null,", key)
continue continue
} }
valueType := reflect.TypeOf(value) valueType := reflect.TypeOf(value)
if valueType.Name() == "string" { if valueType.Name() == "string" {
updateFields += fmt.Sprintf("%s = '%s', ", key, value.(string)) updateFields += fmt.Sprintf("%s = '%s', ", key, value.(string))
@@ -79,7 +72,6 @@ func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*
} }
updateFields = strings.Trim(updateFields, " ") updateFields = strings.Trim(updateFields, " ")
updateFields = strings.TrimSuffix(updateFields, ",") updateFields = strings.TrimSuffix(updateFields, ",")
query := fmt.Sprintf("UPDATE %s SET %s WHERE id = '%s'", KeySpace+"."+models.Collections.Webhook, updateFields, webhook.ID) query := fmt.Sprintf("UPDATE %s SET %s WHERE id = '%s'", KeySpace+"."+models.Collections.Webhook, updateFields, webhook.ID)
err = p.db.Query(query).Exec() err = p.db.Query(query).Exec()
if err != nil { if err != nil {
@@ -92,24 +84,21 @@ func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*
func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) (*model.Webhooks, error) { func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) (*model.Webhooks, error) {
webhooks := []*model.Webhook{} webhooks := []*model.Webhook{}
paginationClone := pagination paginationClone := pagination
totalCountQuery := fmt.Sprintf(`SELECT COUNT(*) FROM %s`, KeySpace+"."+models.Collections.Webhook) totalCountQuery := fmt.Sprintf(`SELECT COUNT(*) FROM %s`, KeySpace+"."+models.Collections.Webhook)
err := p.db.Query(totalCountQuery).Consistency(gocql.One).Scan(&paginationClone.Total) err := p.db.Query(totalCountQuery).Consistency(gocql.One).Scan(&paginationClone.Total)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// there is no offset in cassandra // there is no offset in cassandra
// so we fetch till limit + offset // so we fetch till limit + offset
// and return the results from offset to limit // and return the results from offset to limit
query := fmt.Sprintf("SELECT id, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s LIMIT %d", KeySpace+"."+models.Collections.Webhook, pagination.Limit+pagination.Offset) query := fmt.Sprintf("SELECT id, event_description, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s LIMIT %d", KeySpace+"."+models.Collections.Webhook, pagination.Limit+pagination.Offset)
scanner := p.db.Query(query).Iter().Scanner() scanner := p.db.Query(query).Iter().Scanner()
counter := int64(0) counter := int64(0)
for scanner.Next() { for scanner.Next() {
if counter >= pagination.Offset { if counter >= pagination.Offset {
var webhook models.Webhook var webhook models.Webhook
err := scanner.Scan(&webhook.ID, &webhook.EventName, &webhook.EndPoint, &webhook.Headers, &webhook.Enabled, &webhook.CreatedAt, &webhook.UpdatedAt) err := scanner.Scan(&webhook.ID, &webhook.EventDescription, &webhook.EventName, &webhook.EndPoint, &webhook.Headers, &webhook.Enabled, &webhook.CreatedAt, &webhook.UpdatedAt)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -127,8 +116,8 @@ func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination)
// GetWebhookByID to get webhook by id // GetWebhookByID to get webhook by id
func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model.Webhook, error) { func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model.Webhook, error) {
var webhook models.Webhook var webhook models.Webhook
query := fmt.Sprintf(`SELECT id, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s WHERE id = '%s' LIMIT 1`, KeySpace+"."+models.Collections.Webhook, webhookID) query := fmt.Sprintf(`SELECT id, event_description, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s WHERE id = '%s' LIMIT 1`, KeySpace+"."+models.Collections.Webhook, webhookID)
err := p.db.Query(query).Consistency(gocql.One).Scan(&webhook.ID, &webhook.EventName, &webhook.EndPoint, &webhook.Headers, &webhook.Enabled, &webhook.CreatedAt, &webhook.UpdatedAt) err := p.db.Query(query).Consistency(gocql.One).Scan(&webhook.ID, &webhook.EventDescription, &webhook.EventName, &webhook.EndPoint, &webhook.Headers, &webhook.Enabled, &webhook.CreatedAt, &webhook.UpdatedAt)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -136,14 +125,19 @@ func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model
} }
// GetWebhookByEventName to get webhook by event_name // GetWebhookByEventName to get webhook by event_name
func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) (*model.Webhook, error) { func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) ([]*model.Webhook, error) {
var webhook models.Webhook query := fmt.Sprintf(`SELECT id, event_description, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s WHERE event_name LIKE '%s' ALLOW FILTERING`, KeySpace+"."+models.Collections.Webhook, eventName+"%")
query := fmt.Sprintf(`SELECT id, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s WHERE event_name = '%s' LIMIT 1 ALLOW FILTERING`, KeySpace+"."+models.Collections.Webhook, eventName) scanner := p.db.Query(query).Iter().Scanner()
err := p.db.Query(query).Consistency(gocql.One).Scan(&webhook.ID, &webhook.EventName, &webhook.EndPoint, &webhook.Headers, &webhook.Enabled, &webhook.CreatedAt, &webhook.UpdatedAt) webhooks := []*model.Webhook{}
if err != nil { for scanner.Next() {
return nil, err var webhook models.Webhook
err := scanner.Scan(&webhook.ID, &webhook.EventDescription, &webhook.EventName, &webhook.EndPoint, &webhook.Headers, &webhook.Enabled, &webhook.CreatedAt, &webhook.UpdatedAt)
if err != nil {
return nil, err
}
webhooks = append(webhooks, webhook.AsAPIWebhook())
} }
return webhook.AsAPIWebhook(), nil return webhooks, nil
} }
// DeleteWebhook to delete webhook // DeleteWebhook to delete webhook

View File

@@ -19,11 +19,11 @@ func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*mod
if webhook.ID == "" { if webhook.ID == "" {
webhook.ID = uuid.New().String() webhook.ID = uuid.New().String()
} }
webhook.Key = webhook.ID webhook.Key = webhook.ID
webhook.CreatedAt = time.Now().Unix() webhook.CreatedAt = time.Now().Unix()
webhook.UpdatedAt = time.Now().Unix() webhook.UpdatedAt = time.Now().Unix()
// Add timestamp to make event name unique for legacy version
webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix())
insertOpt := gocb.InsertOptions{ insertOpt := gocb.InsertOptions{
Context: ctx, Context: ctx,
} }
@@ -37,7 +37,10 @@ func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*mod
// UpdateWebhook to update webhook // UpdateWebhook to update webhook
func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) { func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) {
webhook.UpdatedAt = time.Now().Unix() webhook.UpdatedAt = time.Now().Unix()
// Event is changed
if !strings.Contains(webhook.EventName, "-") {
webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix())
}
bytes, err := json.Marshal(webhook) bytes, err := json.Marshal(webhook)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -50,17 +53,13 @@ func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*
if err != nil { if err != nil {
return nil, err return nil, err
} }
updateFields, params := GetSetFields(webhookMap) updateFields, params := GetSetFields(webhookMap)
query := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE _id='%s'`, p.scopeName, models.Collections.Webhook, updateFields, webhook.ID) query := fmt.Sprintf(`UPDATE %s.%s SET %s WHERE _id='%s'`, p.scopeName, models.Collections.Webhook, updateFields, webhook.ID)
_, err = p.db.Query(query, &gocb.QueryOptions{ _, err = p.db.Query(query, &gocb.QueryOptions{
Context: ctx, Context: ctx,
ScanConsistency: gocb.QueryScanConsistencyRequestPlus, ScanConsistency: gocb.QueryScanConsistencyRequestPlus,
NamedParameters: params, NamedParameters: params,
}) })
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -72,7 +71,6 @@ func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*
func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) (*model.Webhooks, error) { func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) (*model.Webhooks, error) {
webhooks := []*model.Webhook{} webhooks := []*model.Webhook{}
paginationClone := pagination paginationClone := pagination
params := make(map[string]interface{}, 1) params := make(map[string]interface{}, 1)
params["offset"] = paginationClone.Offset params["offset"] = paginationClone.Offset
params["limit"] = paginationClone.Limit params["limit"] = paginationClone.Limit
@@ -81,7 +79,7 @@ func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination)
return nil, err return nil, err
} }
paginationClone.Total = total paginationClone.Total = total
query := fmt.Sprintf("SELECT _id, env, created_at, updated_at FROM %s.%s OFFSET $offset LIMIT $limit", p.scopeName, models.Collections.Webhook) query := fmt.Sprintf("SELECT _id, event_description, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s.%s OFFSET $offset LIMIT $limit", p.scopeName, models.Collections.Webhook)
queryResult, err := p.db.Query(query, &gocb.QueryOptions{ queryResult, err := p.db.Query(query, &gocb.QueryOptions{
Context: ctx, Context: ctx,
ScanConsistency: gocb.QueryScanConsistencyRequestPlus, ScanConsistency: gocb.QueryScanConsistencyRequestPlus,
@@ -110,11 +108,9 @@ func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination)
// GetWebhookByID to get webhook by id // GetWebhookByID to get webhook by id
func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model.Webhook, error) { func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model.Webhook, error) {
var webhook models.Webhook var webhook models.Webhook
params := make(map[string]interface{}, 1) params := make(map[string]interface{}, 1)
params["_id"] = webhookID params["_id"] = webhookID
query := fmt.Sprintf(`SELECT _id, event_description, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s.%s WHERE _id=$_id LIMIT 1`, p.scopeName, models.Collections.Webhook)
query := fmt.Sprintf(`SELECT _id, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s.%s WHERE _id=$_id LIMIT 1`, p.scopeName, models.Collections.Webhook)
q, err := p.db.Query(query, &gocb.QueryOptions{ q, err := p.db.Query(query, &gocb.QueryOptions{
Context: ctx, Context: ctx,
ScanConsistency: gocb.QueryScanConsistencyRequestPlus, ScanConsistency: gocb.QueryScanConsistencyRequestPlus,
@@ -124,42 +120,42 @@ func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model
return nil, err return nil, err
} }
err = q.One(&webhook) err = q.One(&webhook)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return webhook.AsAPIWebhook(), nil return webhook.AsAPIWebhook(), nil
} }
// GetWebhookByEventName to get webhook by event_name // GetWebhookByEventName to get webhook by event_name
func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) (*model.Webhook, error) { func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) ([]*model.Webhook, error) {
var webhook models.Webhook
params := make(map[string]interface{}, 1) params := make(map[string]interface{}, 1)
params["event_name"] = eventName // params["event_name"] = eventName + "%"
query := fmt.Sprintf(`SELECT _id, event_description, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s.%s WHERE event_name LIKE '%s'`, p.scopeName, models.Collections.Webhook, eventName+"%")
query := fmt.Sprintf(`SELECT _id, event_name, endpoint, headers, enabled, created_at, updated_at FROM %s.%s WHERE event_name=$event_name LIMIT 1`, p.scopeName, models.Collections.Webhook) queryResult, err := p.db.Query(query, &gocb.QueryOptions{
q, err := p.db.Query(query, &gocb.QueryOptions{
Context: ctx, Context: ctx,
ScanConsistency: gocb.QueryScanConsistencyRequestPlus, ScanConsistency: gocb.QueryScanConsistencyRequestPlus,
NamedParameters: params, NamedParameters: params,
}) })
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = q.One(&webhook) webhooks := []*model.Webhook{}
for queryResult.Next() {
if err != nil { var webhook models.Webhook
err := queryResult.Row(&webhook)
if err != nil {
log.Fatal(err)
}
webhooks = append(webhooks, webhook.AsAPIWebhook())
}
if err := queryResult.Err(); err != nil {
return nil, err return nil, err
} }
return webhooks, nil
return webhook.AsAPIWebhook(), nil
} }
// DeleteWebhook to delete webhook // DeleteWebhook to delete webhook
func (p *provider) DeleteWebhook(ctx context.Context, webhook *model.Webhook) error { func (p *provider) DeleteWebhook(ctx context.Context, webhook *model.Webhook) error {
params := make(map[string]interface{}, 1) params := make(map[string]interface{}, 1)
params["webhook_id"] = webhook.ID params["webhook_id"] = webhook.ID
removeOpt := gocb.RemoveOptions{ removeOpt := gocb.RemoveOptions{

View File

@@ -3,28 +3,29 @@ package dynamodb
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"strings"
"time" "time"
"github.com/google/uuid"
"github.com/guregu/dynamo"
"github.com/authorizerdev/authorizer/server/db/models" "github.com/authorizerdev/authorizer/server/db/models"
"github.com/authorizerdev/authorizer/server/graph/model" "github.com/authorizerdev/authorizer/server/graph/model"
"github.com/google/uuid"
"github.com/guregu/dynamo"
) )
// AddWebhook to add webhook // AddWebhook to add webhook
func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) { func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) {
collection := p.db.Table(models.Collections.Webhook) collection := p.db.Table(models.Collections.Webhook)
if webhook.ID == "" { if webhook.ID == "" {
webhook.ID = uuid.New().String() webhook.ID = uuid.New().String()
} }
webhook.Key = webhook.ID webhook.Key = webhook.ID
webhook.CreatedAt = time.Now().Unix() webhook.CreatedAt = time.Now().Unix()
webhook.UpdatedAt = time.Now().Unix() webhook.UpdatedAt = time.Now().Unix()
// Add timestamp to make event name unique for legacy version
webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix())
err := collection.Put(webhook).RunWithContext(ctx) err := collection.Put(webhook).RunWithContext(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -33,11 +34,13 @@ func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*mod
// UpdateWebhook to update webhook // UpdateWebhook to update webhook
func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) { func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) {
collection := p.db.Table(models.Collections.Webhook)
webhook.UpdatedAt = time.Now().Unix() webhook.UpdatedAt = time.Now().Unix()
// Event is changed
if !strings.Contains(webhook.EventName, "-") {
webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix())
}
collection := p.db.Table(models.Collections.Webhook)
err := UpdateByHashKey(collection, "id", webhook.ID, webhook) err := UpdateByHashKey(collection, "id", webhook.ID, webhook)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -51,16 +54,13 @@ func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination)
var lastEval dynamo.PagingKey var lastEval dynamo.PagingKey
var iter dynamo.PagingIter var iter dynamo.PagingIter
var iteration int64 = 0 var iteration int64 = 0
collection := p.db.Table(models.Collections.Webhook) collection := p.db.Table(models.Collections.Webhook)
paginationClone := pagination paginationClone := pagination
scanner := collection.Scan() scanner := collection.Scan()
count, err := scanner.Count() count, err := scanner.Count()
if err != nil { if err != nil {
return nil, err return nil, err
} }
for (paginationClone.Offset + paginationClone.Limit) > iteration { for (paginationClone.Offset + paginationClone.Limit) > iteration {
iter = scanner.StartFrom(lastEval).Limit(paginationClone.Limit).Iter() iter = scanner.StartFrom(lastEval).Limit(paginationClone.Limit).Iter()
for iter.NextWithContext(ctx, &webhook) { for iter.NextWithContext(ctx, &webhook) {
@@ -75,9 +75,7 @@ func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination)
lastEval = iter.LastEvaluatedKey() lastEval = iter.LastEvaluatedKey()
iteration += paginationClone.Limit iteration += paginationClone.Limit
} }
paginationClone.Total = count paginationClone.Total = count
return &model.Webhooks{ return &model.Webhooks{
Pagination: &paginationClone, Pagination: &paginationClone,
Webhooks: webhooks, Webhooks: webhooks,
@@ -88,37 +86,29 @@ func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination)
func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model.Webhook, error) { func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model.Webhook, error) {
collection := p.db.Table(models.Collections.Webhook) collection := p.db.Table(models.Collections.Webhook)
var webhook models.Webhook var webhook models.Webhook
err := collection.Get("id", webhookID).OneWithContext(ctx, &webhook) err := collection.Get("id", webhookID).OneWithContext(ctx, &webhook)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if webhook.ID == "" { if webhook.ID == "" {
return webhook.AsAPIWebhook(), errors.New("no documets found") return webhook.AsAPIWebhook(), errors.New("no documets found")
} }
return webhook.AsAPIWebhook(), nil return webhook.AsAPIWebhook(), nil
} }
// GetWebhookByEventName to get webhook by event_name // GetWebhookByEventName to get webhook by event_name
func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) (*model.Webhook, error) { func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) ([]*model.Webhook, error) {
var webhook models.Webhook webhooks := []models.Webhook{}
collection := p.db.Table(models.Collections.Webhook) collection := p.db.Table(models.Collections.Webhook)
err := collection.Scan().Index("event_name").Filter("contains(event_name, ?)", eventName).AllWithContext(ctx, &webhooks)
iter := collection.Scan().Index("event_name").Filter("'event_name' = ?", eventName).Iter()
for iter.NextWithContext(ctx, &webhook) {
return webhook.AsAPIWebhook(), nil
}
err := iter.Err()
if err != nil { if err != nil {
return webhook.AsAPIWebhook(), err return nil, err
} }
return webhook.AsAPIWebhook(), nil resWebhooks := []*model.Webhook{}
for _, w := range webhooks {
resWebhooks = append(resWebhooks, w.AsAPIWebhook())
}
return resWebhooks, nil
} }
// DeleteWebhook to delete webhook // DeleteWebhook to delete webhook
@@ -133,7 +123,6 @@ func (p *provider) DeleteWebhook(ctx context.Context, webhook *model.Webhook) er
return err return err
} }
webhookLogs, errIs := p.ListWebhookLogs(ctx, pagination, webhook.ID) webhookLogs, errIs := p.ListWebhookLogs(ctx, pagination, webhook.ID)
for _, webhookLog := range webhookLogs.WebhookLogs { for _, webhookLog := range webhookLogs.WebhookLogs {
err = webhookLogCollection.Delete("id", webhookLog.ID).RunWithContext(ctx) err = webhookLogCollection.Delete("id", webhookLog.ID).RunWithContext(ctx)
if err != nil { if err != nil {

View File

@@ -2,6 +2,8 @@ package mongodb
import ( import (
"context" "context"
"fmt"
"strings"
"time" "time"
"github.com/authorizerdev/authorizer/server/db/models" "github.com/authorizerdev/authorizer/server/db/models"
@@ -16,11 +18,11 @@ func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*mod
if webhook.ID == "" { if webhook.ID == "" {
webhook.ID = uuid.New().String() webhook.ID = uuid.New().String()
} }
webhook.Key = webhook.ID webhook.Key = webhook.ID
webhook.CreatedAt = time.Now().Unix() webhook.CreatedAt = time.Now().Unix()
webhook.UpdatedAt = time.Now().Unix() webhook.UpdatedAt = time.Now().Unix()
// Add timestamp to make event name unique for legacy version
webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix())
webhookCollection := p.db.Collection(models.Collections.Webhook, options.Collection()) webhookCollection := p.db.Collection(models.Collections.Webhook, options.Collection())
_, err := webhookCollection.InsertOne(ctx, webhook) _, err := webhookCollection.InsertOne(ctx, webhook)
if err != nil { if err != nil {
@@ -32,39 +34,37 @@ func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*mod
// UpdateWebhook to update webhook // UpdateWebhook to update webhook
func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) { func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) {
webhook.UpdatedAt = time.Now().Unix() webhook.UpdatedAt = time.Now().Unix()
// Event is changed
if !strings.Contains(webhook.EventName, "-") {
webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix())
}
webhookCollection := p.db.Collection(models.Collections.Webhook, options.Collection()) webhookCollection := p.db.Collection(models.Collections.Webhook, options.Collection())
_, err := webhookCollection.UpdateOne(ctx, bson.M{"_id": bson.M{"$eq": webhook.ID}}, bson.M{"$set": webhook}, options.MergeUpdateOptions()) _, err := webhookCollection.UpdateOne(ctx, bson.M{"_id": bson.M{"$eq": webhook.ID}}, bson.M{"$set": webhook}, options.MergeUpdateOptions())
if err != nil { if err != nil {
return nil, err return nil, err
} }
return webhook.AsAPIWebhook(), nil return webhook.AsAPIWebhook(), nil
} }
// ListWebhooks to list webhook // ListWebhooks to list webhook
func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) (*model.Webhooks, error) { func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) (*model.Webhooks, error) {
var webhooks []*model.Webhook webhooks := []*model.Webhook{}
opts := options.Find() opts := options.Find()
opts.SetLimit(pagination.Limit) opts.SetLimit(pagination.Limit)
opts.SetSkip(pagination.Offset) opts.SetSkip(pagination.Offset)
opts.SetSort(bson.M{"created_at": -1}) opts.SetSort(bson.M{"created_at": -1})
paginationClone := pagination paginationClone := pagination
webhookCollection := p.db.Collection(models.Collections.Webhook, options.Collection()) webhookCollection := p.db.Collection(models.Collections.Webhook, options.Collection())
count, err := webhookCollection.CountDocuments(ctx, bson.M{}, options.Count()) count, err := webhookCollection.CountDocuments(ctx, bson.M{}, options.Count())
if err != nil { if err != nil {
return nil, err return nil, err
} }
paginationClone.Total = count paginationClone.Total = count
cursor, err := webhookCollection.Find(ctx, bson.M{}, opts) cursor, err := webhookCollection.Find(ctx, bson.M{}, opts)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer cursor.Close(ctx) defer cursor.Close(ctx)
for cursor.Next(ctx) { for cursor.Next(ctx) {
var webhook models.Webhook var webhook models.Webhook
err := cursor.Decode(&webhook) err := cursor.Decode(&webhook)
@@ -73,7 +73,6 @@ func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination)
} }
webhooks = append(webhooks, webhook.AsAPIWebhook()) webhooks = append(webhooks, webhook.AsAPIWebhook())
} }
return &model.Webhooks{ return &model.Webhooks{
Pagination: &paginationClone, Pagination: &paginationClone,
Webhooks: webhooks, Webhooks: webhooks,
@@ -92,14 +91,27 @@ func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model
} }
// GetWebhookByEventName to get webhook by event_name // GetWebhookByEventName to get webhook by event_name
func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) (*model.Webhook, error) { func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) ([]*model.Webhook, error) {
var webhook models.Webhook webhooks := []*model.Webhook{}
webhookCollection := p.db.Collection(models.Collections.Webhook, options.Collection()) webhookCollection := p.db.Collection(models.Collections.Webhook, options.Collection())
err := webhookCollection.FindOne(ctx, bson.M{"event_name": eventName}).Decode(&webhook) opts := options.Find()
opts.SetSort(bson.M{"created_at": -1})
cursor, err := webhookCollection.Find(ctx, bson.M{"event_name": bson.M{
"$regex": fmt.Sprintf("^%s", eventName),
}}, opts)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return webhook.AsAPIWebhook(), nil defer cursor.Close(ctx)
for cursor.Next(ctx) {
var webhook models.Webhook
err := cursor.Decode(&webhook)
if err != nil {
return nil, err
}
webhooks = append(webhooks, webhook.AsAPIWebhook())
}
return webhooks, nil
} }
// DeleteWebhook to delete webhook // DeleteWebhook to delete webhook
@@ -109,12 +121,10 @@ func (p *provider) DeleteWebhook(ctx context.Context, webhook *model.Webhook) er
if err != nil { if err != nil {
return err return err
} }
webhookLogCollection := p.db.Collection(models.Collections.WebhookLog, options.Collection()) webhookLogCollection := p.db.Collection(models.Collections.WebhookLog, options.Collection())
_, err = webhookLogCollection.DeleteMany(nil, bson.M{"webhook_id": webhook.ID}, options.Delete()) _, err = webhookLogCollection.DeleteMany(nil, bson.M{"webhook_id": webhook.ID}, options.Delete())
if err != nil { if err != nil {
return err return err
} }
return nil return nil
} }

View File

@@ -2,6 +2,8 @@ package provider_template
import ( import (
"context" "context"
"fmt"
"strings"
"time" "time"
"github.com/authorizerdev/authorizer/server/db/models" "github.com/authorizerdev/authorizer/server/db/models"
@@ -14,16 +16,21 @@ func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*mod
if webhook.ID == "" { if webhook.ID == "" {
webhook.ID = uuid.New().String() webhook.ID = uuid.New().String()
} }
webhook.Key = webhook.ID webhook.Key = webhook.ID
webhook.CreatedAt = time.Now().Unix() webhook.CreatedAt = time.Now().Unix()
webhook.UpdatedAt = time.Now().Unix() webhook.UpdatedAt = time.Now().Unix()
// Add timestamp to make event name unique for legacy version
webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix())
return webhook.AsAPIWebhook(), nil return webhook.AsAPIWebhook(), nil
} }
// UpdateWebhook to update webhook // UpdateWebhook to update webhook
func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) { func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) {
webhook.UpdatedAt = time.Now().Unix() webhook.UpdatedAt = time.Now().Unix()
// Event is changed
if !strings.Contains(webhook.EventName, "-") {
webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix())
}
return webhook.AsAPIWebhook(), nil return webhook.AsAPIWebhook(), nil
} }
@@ -38,7 +45,7 @@ func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model
} }
// GetWebhookByEventName to get webhook by event_name // GetWebhookByEventName to get webhook by event_name
func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) (*model.Webhook, error) { func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) ([]*model.Webhook, error) {
return nil, nil return nil, nil
} }

View File

@@ -56,7 +56,7 @@ type Provider interface {
// GetWebhookByID to get webhook by id // GetWebhookByID to get webhook by id
GetWebhookByID(ctx context.Context, webhookID string) (*model.Webhook, error) GetWebhookByID(ctx context.Context, webhookID string) (*model.Webhook, error)
// GetWebhookByEventName to get webhook by event_name // GetWebhookByEventName to get webhook by event_name
GetWebhookByEventName(ctx context.Context, eventName string) (*model.Webhook, error) GetWebhookByEventName(ctx context.Context, eventName string) ([]*model.Webhook, error)
// DeleteWebhook to delete webhook // DeleteWebhook to delete webhook
DeleteWebhook(ctx context.Context, webhook *model.Webhook) error DeleteWebhook(ctx context.Context, webhook *model.Webhook) error

View File

@@ -2,6 +2,8 @@ package sql
import ( import (
"context" "context"
"fmt"
"strings"
"time" "time"
"github.com/authorizerdev/authorizer/server/db/models" "github.com/authorizerdev/authorizer/server/db/models"
@@ -14,10 +16,11 @@ func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*mod
if webhook.ID == "" { if webhook.ID == "" {
webhook.ID = uuid.New().String() webhook.ID = uuid.New().String()
} }
webhook.Key = webhook.ID webhook.Key = webhook.ID
webhook.CreatedAt = time.Now().Unix() webhook.CreatedAt = time.Now().Unix()
webhook.UpdatedAt = time.Now().Unix() webhook.UpdatedAt = time.Now().Unix()
// Add timestamp to make event name unique for legacy version
webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix())
res := p.db.Create(&webhook) res := p.db.Create(&webhook)
if res.Error != nil { if res.Error != nil {
return nil, res.Error return nil, res.Error
@@ -28,33 +31,31 @@ func (p *provider) AddWebhook(ctx context.Context, webhook models.Webhook) (*mod
// UpdateWebhook to update webhook // UpdateWebhook to update webhook
func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) { func (p *provider) UpdateWebhook(ctx context.Context, webhook models.Webhook) (*model.Webhook, error) {
webhook.UpdatedAt = time.Now().Unix() webhook.UpdatedAt = time.Now().Unix()
// Event is changed
if !strings.Contains(webhook.EventName, "-") {
webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix())
}
result := p.db.Save(&webhook) result := p.db.Save(&webhook)
if result.Error != nil { if result.Error != nil {
return nil, result.Error return nil, result.Error
} }
return webhook.AsAPIWebhook(), nil return webhook.AsAPIWebhook(), nil
} }
// ListWebhooks to list webhook // ListWebhooks to list webhook
func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) (*model.Webhooks, error) { func (p *provider) ListWebhook(ctx context.Context, pagination model.Pagination) (*model.Webhooks, error) {
var webhooks []models.Webhook var webhooks []models.Webhook
result := p.db.Limit(int(pagination.Limit)).Offset(int(pagination.Offset)).Order("created_at DESC").Find(&webhooks) result := p.db.Limit(int(pagination.Limit)).Offset(int(pagination.Offset)).Order("created_at DESC").Find(&webhooks)
if result.Error != nil { if result.Error != nil {
return nil, result.Error return nil, result.Error
} }
var total int64 var total int64
totalRes := p.db.Model(&models.Webhook{}).Count(&total) totalRes := p.db.Model(&models.Webhook{}).Count(&total)
if totalRes.Error != nil { if totalRes.Error != nil {
return nil, totalRes.Error return nil, totalRes.Error
} }
paginationClone := pagination paginationClone := pagination
paginationClone.Total = total paginationClone.Total = total
responseWebhooks := []*model.Webhook{} responseWebhooks := []*model.Webhook{}
for _, w := range webhooks { for _, w := range webhooks {
responseWebhooks = append(responseWebhooks, w.AsAPIWebhook()) responseWebhooks = append(responseWebhooks, w.AsAPIWebhook())
@@ -77,14 +78,17 @@ func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*model
} }
// GetWebhookByEventName to get webhook by event_name // GetWebhookByEventName to get webhook by event_name
func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) (*model.Webhook, error) { func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) ([]*model.Webhook, error) {
var webhook models.Webhook var webhooks []models.Webhook
result := p.db.Where("event_name LIKE ?", eventName+"%").Find(&webhooks)
result := p.db.Where("event_name = ?", eventName).First(&webhook)
if result.Error != nil { if result.Error != nil {
return nil, result.Error return nil, result.Error
} }
return webhook.AsAPIWebhook(), nil responseWebhooks := []*model.Webhook{}
for _, w := range webhooks {
responseWebhooks = append(responseWebhooks, w.AsAPIWebhook())
}
return responseWebhooks, nil
} }
// DeleteWebhook to delete webhook // DeleteWebhook to delete webhook

24
server/env/env.go vendored
View File

@@ -87,6 +87,8 @@ func InitAllEnv() error {
osCouchbaseBucket := os.Getenv(constants.EnvCouchbaseBucket) osCouchbaseBucket := os.Getenv(constants.EnvCouchbaseBucket)
osCouchbaseScope := os.Getenv(constants.EnvCouchbaseScope) osCouchbaseScope := os.Getenv(constants.EnvCouchbaseScope)
osCouchbaseBucketRAMQuotaMB := os.Getenv(constants.EnvCouchbaseBucketRAMQuotaMB) osCouchbaseBucketRAMQuotaMB := os.Getenv(constants.EnvCouchbaseBucketRAMQuotaMB)
osAuthorizeResponseType := os.Getenv(constants.EnvKeyDefaultAuthorizeResponseType)
osAuthorizeResponseMode := os.Getenv(constants.EnvKeyDefaultAuthorizeResponseMode)
// os bool vars // os bool vars
osAppCookieSecure := os.Getenv(constants.EnvKeyAppCookieSecure) osAppCookieSecure := os.Getenv(constants.EnvKeyAppCookieSecure)
@@ -735,6 +737,28 @@ func InitAllEnv() error {
envData[constants.EnvKeyProtectedRoles] = osProtectedRoles envData[constants.EnvKeyProtectedRoles] = osProtectedRoles
} }
if val, ok := envData[constants.EnvKeyDefaultAuthorizeResponseType]; !ok || val == "" {
envData[constants.EnvKeyDefaultAuthorizeResponseType] = osAuthorizeResponseType
// Set the default value to token type
if envData[constants.EnvKeyDefaultAuthorizeResponseType] == "" {
envData[constants.EnvKeyDefaultAuthorizeResponseType] = constants.ResponseTypeToken
}
}
if osAuthorizeResponseType != "" && envData[constants.EnvKeyDefaultAuthorizeResponseType] != osAuthorizeResponseType {
envData[constants.EnvKeyDefaultAuthorizeResponseType] = osAuthorizeResponseType
}
if val, ok := envData[constants.EnvKeyDefaultAuthorizeResponseMode]; !ok || val == "" {
envData[constants.EnvKeyDefaultAuthorizeResponseMode] = osAuthorizeResponseMode
// Set the default value to token type
if envData[constants.EnvKeyDefaultAuthorizeResponseMode] == "" {
envData[constants.EnvKeyDefaultAuthorizeResponseMode] = constants.ResponseModeQuery
}
}
if osAuthorizeResponseMode != "" && envData[constants.EnvKeyDefaultAuthorizeResponseMode] != osAuthorizeResponseMode {
envData[constants.EnvKeyDefaultAuthorizeResponseMode] = osAuthorizeResponseMode
}
err = memorystore.Provider.UpdateEnvStore(envData) err = memorystore.Provider.UpdateEnvStore(envData)
if err != nil { if err != nil {
log.Debug("Error while updating env store: ", err) log.Debug("Error while updating env store: ", err)

View File

@@ -21,6 +21,7 @@ require (
github.com/joho/godotenv v1.3.0 github.com/joho/godotenv v1.3.0
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/pelletier/go-toml/v2 v2.0.5 // indirect github.com/pelletier/go-toml/v2 v2.0.5 // indirect
github.com/redis/go-redis/v9 v9.0.3 // indirect
github.com/robertkrimen/otto v0.0.0-20211024170158-b87d35c0b86f github.com/robertkrimen/otto v0.0.0-20211024170158-b87d35c0b86f
github.com/sirupsen/logrus v1.8.1 github.com/sirupsen/logrus v1.8.1
github.com/stretchr/testify v1.8.0 github.com/stretchr/testify v1.8.0

View File

@@ -58,11 +58,15 @@ github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYE
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k= github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k=
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY= github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY=
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4=
github.com/bsm/ginkgo/v2 v2.7.0/go.mod h1:AiKlXPm7ItEHNc/2+OkrNG4E0ITzojb9/xWzvQ9XZ9w=
github.com/bsm/gomega v1.26.0/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
github.com/cenkalti/backoff/v4 v4.1.2 h1:6Yo7N8UP2K6LWZnW94DLVSSrbobcWdVzAYOisuDPIFo= github.com/cenkalti/backoff/v4 v4.1.2 h1:6Yo7N8UP2K6LWZnW94DLVSSrbobcWdVzAYOisuDPIFo=
github.com/cenkalti/backoff/v4 v4.1.2/go.mod h1:scbssz8iZGpm3xbr14ovlUdkxfGXNInqkPWOWmG2CLw= github.com/cenkalti/backoff/v4 v4.1.2/go.mod h1:scbssz8iZGpm3xbr14ovlUdkxfGXNInqkPWOWmG2CLw=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY= github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY=
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
@@ -295,6 +299,8 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/redis/go-redis/v9 v9.0.3 h1:+7mmR26M0IvyLxGZUHxu4GiBkJkVDid0Un+j4ScYu4k=
github.com/redis/go-redis/v9 v9.0.3/go.mod h1:WqMKv5vnQbRuZstUwxQI195wHy+t4PuXDOjzMvcuQHk=
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 h1:OdAsTTz6OkFY5QxjkYwrChwuRruF69c169dPK26NUlk= github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 h1:OdAsTTz6OkFY5QxjkYwrChwuRruF69c169dPK26NUlk=
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/robertkrimen/otto v0.0.0-20211024170158-b87d35c0b86f h1:a7clxaGmmqtdNTXyvrp/lVO/Gnkzlhc/+dLs5v965GM= github.com/robertkrimen/otto v0.0.0-20211024170158-b87d35c0b86f h1:a7clxaGmmqtdNTXyvrp/lVO/Gnkzlhc/+dLs5v965GM=

View File

@@ -88,6 +88,8 @@ type ComplexityRoot struct {
DatabaseType func(childComplexity int) int DatabaseType func(childComplexity int) int
DatabaseURL func(childComplexity int) int DatabaseURL func(childComplexity int) int
DatabaseUsername func(childComplexity int) int DatabaseUsername func(childComplexity int) int
DefaultAuthorizeResponseMode func(childComplexity int) int
DefaultAuthorizeResponseType func(childComplexity int) int
DefaultRoles func(childComplexity int) int DefaultRoles func(childComplexity int) int
DisableBasicAuthentication func(childComplexity int) int DisableBasicAuthentication func(childComplexity int) int
DisableEmailVerification func(childComplexity int) int DisableEmailVerification func(childComplexity int) int
@@ -275,13 +277,14 @@ type ComplexityRoot struct {
} }
Webhook struct { Webhook struct {
CreatedAt func(childComplexity int) int CreatedAt func(childComplexity int) int
Enabled func(childComplexity int) int Enabled func(childComplexity int) int
Endpoint func(childComplexity int) int Endpoint func(childComplexity int) int
EventName func(childComplexity int) int EventDescription func(childComplexity int) int
Headers func(childComplexity int) int EventName func(childComplexity int) int
ID func(childComplexity int) int Headers func(childComplexity int) int
UpdatedAt func(childComplexity int) int ID func(childComplexity int) int
UpdatedAt func(childComplexity int) int
} }
WebhookLog struct { WebhookLog struct {
@@ -607,6 +610,20 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in
return e.complexity.Env.DatabaseUsername(childComplexity), true return e.complexity.Env.DatabaseUsername(childComplexity), true
case "Env.DEFAULT_AUTHORIZE_RESPONSE_MODE":
if e.complexity.Env.DefaultAuthorizeResponseMode == nil {
break
}
return e.complexity.Env.DefaultAuthorizeResponseMode(childComplexity), true
case "Env.DEFAULT_AUTHORIZE_RESPONSE_TYPE":
if e.complexity.Env.DefaultAuthorizeResponseType == nil {
break
}
return e.complexity.Env.DefaultAuthorizeResponseType(childComplexity), true
case "Env.DEFAULT_ROLES": case "Env.DEFAULT_ROLES":
if e.complexity.Env.DefaultRoles == nil { if e.complexity.Env.DefaultRoles == nil {
break break
@@ -1833,6 +1850,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in
return e.complexity.Webhook.Endpoint(childComplexity), true return e.complexity.Webhook.Endpoint(childComplexity), true
case "Webhook.event_description":
if e.complexity.Webhook.EventDescription == nil {
break
}
return e.complexity.Webhook.EventDescription(childComplexity), true
case "Webhook.event_name": case "Webhook.event_name":
if e.complexity.Webhook.EventName == nil { if e.complexity.Webhook.EventName == nil {
break break
@@ -2195,6 +2219,8 @@ type Env {
ORGANIZATION_LOGO: String ORGANIZATION_LOGO: String
APP_COOKIE_SECURE: Boolean! APP_COOKIE_SECURE: Boolean!
ADMIN_COOKIE_SECURE: Boolean! ADMIN_COOKIE_SECURE: Boolean!
DEFAULT_AUTHORIZE_RESPONSE_TYPE: String
DEFAULT_AUTHORIZE_RESPONSE_MODE: String
} }
type ValidateJWTTokenResponse { type ValidateJWTTokenResponse {
@@ -2210,7 +2236,8 @@ type GenerateJWTKeysResponse {
type Webhook { type Webhook {
id: ID! id: ID!
event_name: String event_name: String # this is unique string
event_description: String
endpoint: String endpoint: String
enabled: Boolean enabled: Boolean
headers: Map headers: Map
@@ -2308,6 +2335,8 @@ input UpdateEnvInput {
MICROSOFT_ACTIVE_DIRECTORY_TENANT_ID: String MICROSOFT_ACTIVE_DIRECTORY_TENANT_ID: String
ORGANIZATION_NAME: String ORGANIZATION_NAME: String
ORGANIZATION_LOGO: String ORGANIZATION_LOGO: String
DEFAULT_AUTHORIZE_RESPONSE_TYPE: String
DEFAULT_AUTHORIZE_RESPONSE_MODE: String
} }
input AdminLoginInput { input AdminLoginInput {
@@ -2501,6 +2530,7 @@ input ListWebhookLogRequest {
input AddWebhookRequest { input AddWebhookRequest {
event_name: String! event_name: String!
event_description: String
endpoint: String! endpoint: String!
enabled: Boolean! enabled: Boolean!
headers: Map headers: Map
@@ -2509,6 +2539,7 @@ input AddWebhookRequest {
input UpdateWebhookRequest { input UpdateWebhookRequest {
id: ID! id: ID!
event_name: String event_name: String
event_description: String
endpoint: String endpoint: String
enabled: Boolean enabled: Boolean
headers: Map headers: Map
@@ -6413,6 +6444,88 @@ func (ec *executionContext) fieldContext_Env_ADMIN_COOKIE_SECURE(ctx context.Con
return fc, nil return fc, nil
} }
func (ec *executionContext) _Env_DEFAULT_AUTHORIZE_RESPONSE_TYPE(ctx context.Context, field graphql.CollectedField, obj *model.Env) (ret graphql.Marshaler) {
fc, err := ec.fieldContext_Env_DEFAULT_AUTHORIZE_RESPONSE_TYPE(ctx, field)
if err != nil {
return graphql.Null
}
ctx = graphql.WithFieldContext(ctx, fc)
defer func() {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
ret = graphql.Null
}
}()
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
ctx = rctx // use context from middleware stack in children
return obj.DefaultAuthorizeResponseType, nil
})
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
if resTmp == nil {
return graphql.Null
}
res := resTmp.(*string)
fc.Result = res
return ec.marshalOString2ᚖstring(ctx, field.Selections, res)
}
func (ec *executionContext) fieldContext_Env_DEFAULT_AUTHORIZE_RESPONSE_TYPE(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) {
fc = &graphql.FieldContext{
Object: "Env",
Field: field,
IsMethod: false,
IsResolver: false,
Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) {
return nil, errors.New("field of type String does not have child fields")
},
}
return fc, nil
}
func (ec *executionContext) _Env_DEFAULT_AUTHORIZE_RESPONSE_MODE(ctx context.Context, field graphql.CollectedField, obj *model.Env) (ret graphql.Marshaler) {
fc, err := ec.fieldContext_Env_DEFAULT_AUTHORIZE_RESPONSE_MODE(ctx, field)
if err != nil {
return graphql.Null
}
ctx = graphql.WithFieldContext(ctx, fc)
defer func() {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
ret = graphql.Null
}
}()
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
ctx = rctx // use context from middleware stack in children
return obj.DefaultAuthorizeResponseMode, nil
})
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
if resTmp == nil {
return graphql.Null
}
res := resTmp.(*string)
fc.Result = res
return ec.marshalOString2ᚖstring(ctx, field.Selections, res)
}
func (ec *executionContext) fieldContext_Env_DEFAULT_AUTHORIZE_RESPONSE_MODE(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) {
fc = &graphql.FieldContext{
Object: "Env",
Field: field,
IsMethod: false,
IsResolver: false,
Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) {
return nil, errors.New("field of type String does not have child fields")
},
}
return fc, nil
}
func (ec *executionContext) _Error_message(ctx context.Context, field graphql.CollectedField, obj *model.Error) (ret graphql.Marshaler) { func (ec *executionContext) _Error_message(ctx context.Context, field graphql.CollectedField, obj *model.Error) (ret graphql.Marshaler) {
fc, err := ec.fieldContext_Error_message(ctx, field) fc, err := ec.fieldContext_Error_message(ctx, field)
if err != nil { if err != nil {
@@ -10093,6 +10206,10 @@ func (ec *executionContext) fieldContext_Query__env(ctx context.Context, field g
return ec.fieldContext_Env_APP_COOKIE_SECURE(ctx, field) return ec.fieldContext_Env_APP_COOKIE_SECURE(ctx, field)
case "ADMIN_COOKIE_SECURE": case "ADMIN_COOKIE_SECURE":
return ec.fieldContext_Env_ADMIN_COOKIE_SECURE(ctx, field) return ec.fieldContext_Env_ADMIN_COOKIE_SECURE(ctx, field)
case "DEFAULT_AUTHORIZE_RESPONSE_TYPE":
return ec.fieldContext_Env_DEFAULT_AUTHORIZE_RESPONSE_TYPE(ctx, field)
case "DEFAULT_AUTHORIZE_RESPONSE_MODE":
return ec.fieldContext_Env_DEFAULT_AUTHORIZE_RESPONSE_MODE(ctx, field)
} }
return nil, fmt.Errorf("no field named %q was found under type Env", field.Name) return nil, fmt.Errorf("no field named %q was found under type Env", field.Name)
}, },
@@ -10143,6 +10260,8 @@ func (ec *executionContext) fieldContext_Query__webhook(ctx context.Context, fie
return ec.fieldContext_Webhook_id(ctx, field) return ec.fieldContext_Webhook_id(ctx, field)
case "event_name": case "event_name":
return ec.fieldContext_Webhook_event_name(ctx, field) return ec.fieldContext_Webhook_event_name(ctx, field)
case "event_description":
return ec.fieldContext_Webhook_event_description(ctx, field)
case "endpoint": case "endpoint":
return ec.fieldContext_Webhook_endpoint(ctx, field) return ec.fieldContext_Webhook_endpoint(ctx, field)
case "enabled": case "enabled":
@@ -12201,6 +12320,47 @@ func (ec *executionContext) fieldContext_Webhook_event_name(ctx context.Context,
return fc, nil return fc, nil
} }
func (ec *executionContext) _Webhook_event_description(ctx context.Context, field graphql.CollectedField, obj *model.Webhook) (ret graphql.Marshaler) {
fc, err := ec.fieldContext_Webhook_event_description(ctx, field)
if err != nil {
return graphql.Null
}
ctx = graphql.WithFieldContext(ctx, fc)
defer func() {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
ret = graphql.Null
}
}()
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
ctx = rctx // use context from middleware stack in children
return obj.EventDescription, nil
})
if err != nil {
ec.Error(ctx, err)
return graphql.Null
}
if resTmp == nil {
return graphql.Null
}
res := resTmp.(*string)
fc.Result = res
return ec.marshalOString2ᚖstring(ctx, field.Selections, res)
}
func (ec *executionContext) fieldContext_Webhook_event_description(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) {
fc = &graphql.FieldContext{
Object: "Webhook",
Field: field,
IsMethod: false,
IsResolver: false,
Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) {
return nil, errors.New("field of type String does not have child fields")
},
}
return fc, nil
}
func (ec *executionContext) _Webhook_endpoint(ctx context.Context, field graphql.CollectedField, obj *model.Webhook) (ret graphql.Marshaler) { func (ec *executionContext) _Webhook_endpoint(ctx context.Context, field graphql.CollectedField, obj *model.Webhook) (ret graphql.Marshaler) {
fc, err := ec.fieldContext_Webhook_endpoint(ctx, field) fc, err := ec.fieldContext_Webhook_endpoint(ctx, field)
if err != nil { if err != nil {
@@ -12907,6 +13067,8 @@ func (ec *executionContext) fieldContext_Webhooks_webhooks(ctx context.Context,
return ec.fieldContext_Webhook_id(ctx, field) return ec.fieldContext_Webhook_id(ctx, field)
case "event_name": case "event_name":
return ec.fieldContext_Webhook_event_name(ctx, field) return ec.fieldContext_Webhook_event_name(ctx, field)
case "event_description":
return ec.fieldContext_Webhook_event_description(ctx, field)
case "endpoint": case "endpoint":
return ec.fieldContext_Webhook_endpoint(ctx, field) return ec.fieldContext_Webhook_endpoint(ctx, field)
case "enabled": case "enabled":
@@ -14756,7 +14918,7 @@ func (ec *executionContext) unmarshalInputAddWebhookRequest(ctx context.Context,
asMap[k] = v asMap[k] = v
} }
fieldsInOrder := [...]string{"event_name", "endpoint", "enabled", "headers"} fieldsInOrder := [...]string{"event_name", "event_description", "endpoint", "enabled", "headers"}
for _, k := range fieldsInOrder { for _, k := range fieldsInOrder {
v, ok := asMap[k] v, ok := asMap[k]
if !ok { if !ok {
@@ -14771,6 +14933,14 @@ func (ec *executionContext) unmarshalInputAddWebhookRequest(ctx context.Context,
if err != nil { if err != nil {
return it, err return it, err
} }
case "event_description":
var err error
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("event_description"))
it.EventDescription, err = ec.unmarshalOString2ᚖstring(ctx, v)
if err != nil {
return it, err
}
case "endpoint": case "endpoint":
var err error var err error
@@ -15952,7 +16122,7 @@ func (ec *executionContext) unmarshalInputUpdateEnvInput(ctx context.Context, ob
asMap[k] = v asMap[k] = v
} }
fieldsInOrder := [...]string{"ACCESS_TOKEN_EXPIRY_TIME", "ADMIN_SECRET", "CUSTOM_ACCESS_TOKEN_SCRIPT", "OLD_ADMIN_SECRET", "SMTP_HOST", "SMTP_PORT", "SMTP_USERNAME", "SMTP_PASSWORD", "SMTP_LOCAL_NAME", "SENDER_EMAIL", "JWT_TYPE", "JWT_SECRET", "JWT_PRIVATE_KEY", "JWT_PUBLIC_KEY", "ALLOWED_ORIGINS", "APP_URL", "RESET_PASSWORD_URL", "APP_COOKIE_SECURE", "ADMIN_COOKIE_SECURE", "DISABLE_EMAIL_VERIFICATION", "DISABLE_BASIC_AUTHENTICATION", "DISABLE_MAGIC_LINK_LOGIN", "DISABLE_LOGIN_PAGE", "DISABLE_SIGN_UP", "DISABLE_REDIS_FOR_ENV", "DISABLE_STRONG_PASSWORD", "DISABLE_MULTI_FACTOR_AUTHENTICATION", "ENFORCE_MULTI_FACTOR_AUTHENTICATION", "ROLES", "PROTECTED_ROLES", "DEFAULT_ROLES", "JWT_ROLE_CLAIM", "GOOGLE_CLIENT_ID", "GOOGLE_CLIENT_SECRET", "GITHUB_CLIENT_ID", "GITHUB_CLIENT_SECRET", "FACEBOOK_CLIENT_ID", "FACEBOOK_CLIENT_SECRET", "LINKEDIN_CLIENT_ID", "LINKEDIN_CLIENT_SECRET", "APPLE_CLIENT_ID", "APPLE_CLIENT_SECRET", "TWITTER_CLIENT_ID", "TWITTER_CLIENT_SECRET", "MICROSOFT_CLIENT_ID", "MICROSOFT_CLIENT_SECRET", "MICROSOFT_ACTIVE_DIRECTORY_TENANT_ID", "ORGANIZATION_NAME", "ORGANIZATION_LOGO"} fieldsInOrder := [...]string{"ACCESS_TOKEN_EXPIRY_TIME", "ADMIN_SECRET", "CUSTOM_ACCESS_TOKEN_SCRIPT", "OLD_ADMIN_SECRET", "SMTP_HOST", "SMTP_PORT", "SMTP_USERNAME", "SMTP_PASSWORD", "SMTP_LOCAL_NAME", "SENDER_EMAIL", "JWT_TYPE", "JWT_SECRET", "JWT_PRIVATE_KEY", "JWT_PUBLIC_KEY", "ALLOWED_ORIGINS", "APP_URL", "RESET_PASSWORD_URL", "APP_COOKIE_SECURE", "ADMIN_COOKIE_SECURE", "DISABLE_EMAIL_VERIFICATION", "DISABLE_BASIC_AUTHENTICATION", "DISABLE_MAGIC_LINK_LOGIN", "DISABLE_LOGIN_PAGE", "DISABLE_SIGN_UP", "DISABLE_REDIS_FOR_ENV", "DISABLE_STRONG_PASSWORD", "DISABLE_MULTI_FACTOR_AUTHENTICATION", "ENFORCE_MULTI_FACTOR_AUTHENTICATION", "ROLES", "PROTECTED_ROLES", "DEFAULT_ROLES", "JWT_ROLE_CLAIM", "GOOGLE_CLIENT_ID", "GOOGLE_CLIENT_SECRET", "GITHUB_CLIENT_ID", "GITHUB_CLIENT_SECRET", "FACEBOOK_CLIENT_ID", "FACEBOOK_CLIENT_SECRET", "LINKEDIN_CLIENT_ID", "LINKEDIN_CLIENT_SECRET", "APPLE_CLIENT_ID", "APPLE_CLIENT_SECRET", "TWITTER_CLIENT_ID", "TWITTER_CLIENT_SECRET", "MICROSOFT_CLIENT_ID", "MICROSOFT_CLIENT_SECRET", "MICROSOFT_ACTIVE_DIRECTORY_TENANT_ID", "ORGANIZATION_NAME", "ORGANIZATION_LOGO", "DEFAULT_AUTHORIZE_RESPONSE_TYPE", "DEFAULT_AUTHORIZE_RESPONSE_MODE"}
for _, k := range fieldsInOrder { for _, k := range fieldsInOrder {
v, ok := asMap[k] v, ok := asMap[k]
if !ok { if !ok {
@@ -16351,6 +16521,22 @@ func (ec *executionContext) unmarshalInputUpdateEnvInput(ctx context.Context, ob
if err != nil { if err != nil {
return it, err return it, err
} }
case "DEFAULT_AUTHORIZE_RESPONSE_TYPE":
var err error
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("DEFAULT_AUTHORIZE_RESPONSE_TYPE"))
it.DefaultAuthorizeResponseType, err = ec.unmarshalOString2ᚖstring(ctx, v)
if err != nil {
return it, err
}
case "DEFAULT_AUTHORIZE_RESPONSE_MODE":
var err error
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("DEFAULT_AUTHORIZE_RESPONSE_MODE"))
it.DefaultAuthorizeResponseMode, err = ec.unmarshalOString2ᚖstring(ctx, v)
if err != nil {
return it, err
}
} }
} }
@@ -16612,7 +16798,7 @@ func (ec *executionContext) unmarshalInputUpdateWebhookRequest(ctx context.Conte
asMap[k] = v asMap[k] = v
} }
fieldsInOrder := [...]string{"id", "event_name", "endpoint", "enabled", "headers"} fieldsInOrder := [...]string{"id", "event_name", "event_description", "endpoint", "enabled", "headers"}
for _, k := range fieldsInOrder { for _, k := range fieldsInOrder {
v, ok := asMap[k] v, ok := asMap[k]
if !ok { if !ok {
@@ -16635,6 +16821,14 @@ func (ec *executionContext) unmarshalInputUpdateWebhookRequest(ctx context.Conte
if err != nil { if err != nil {
return it, err return it, err
} }
case "event_description":
var err error
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("event_description"))
it.EventDescription, err = ec.unmarshalOString2ᚖstring(ctx, v)
if err != nil {
return it, err
}
case "endpoint": case "endpoint":
var err error var err error
@@ -17257,6 +17451,14 @@ func (ec *executionContext) _Env(ctx context.Context, sel ast.SelectionSet, obj
if out.Values[i] == graphql.Null { if out.Values[i] == graphql.Null {
invalids++ invalids++
} }
case "DEFAULT_AUTHORIZE_RESPONSE_TYPE":
out.Values[i] = ec._Env_DEFAULT_AUTHORIZE_RESPONSE_TYPE(ctx, field, obj)
case "DEFAULT_AUTHORIZE_RESPONSE_MODE":
out.Values[i] = ec._Env_DEFAULT_AUTHORIZE_RESPONSE_MODE(ctx, field, obj)
default: default:
panic("unknown field " + strconv.Quote(field.Name)) panic("unknown field " + strconv.Quote(field.Name))
} }
@@ -18513,6 +18715,10 @@ func (ec *executionContext) _Webhook(ctx context.Context, sel ast.SelectionSet,
out.Values[i] = ec._Webhook_event_name(ctx, field, obj) out.Values[i] = ec._Webhook_event_name(ctx, field, obj)
case "event_description":
out.Values[i] = ec._Webhook_event_description(ctx, field, obj)
case "endpoint": case "endpoint":
out.Values[i] = ec._Webhook_endpoint(ctx, field, obj) out.Values[i] = ec._Webhook_endpoint(ctx, field, obj)

View File

@@ -10,10 +10,11 @@ type AddEmailTemplateRequest struct {
} }
type AddWebhookRequest struct { type AddWebhookRequest struct {
EventName string `json:"event_name"` EventName string `json:"event_name"`
Endpoint string `json:"endpoint"` EventDescription *string `json:"event_description"`
Enabled bool `json:"enabled"` Endpoint string `json:"endpoint"`
Headers map[string]interface{} `json:"headers"` Enabled bool `json:"enabled"`
Headers map[string]interface{} `json:"headers"`
} }
type AdminLoginInput struct { type AdminLoginInput struct {
@@ -116,6 +117,8 @@ type Env struct {
OrganizationLogo *string `json:"ORGANIZATION_LOGO"` OrganizationLogo *string `json:"ORGANIZATION_LOGO"`
AppCookieSecure bool `json:"APP_COOKIE_SECURE"` AppCookieSecure bool `json:"APP_COOKIE_SECURE"`
AdminCookieSecure bool `json:"ADMIN_COOKIE_SECURE"` AdminCookieSecure bool `json:"ADMIN_COOKIE_SECURE"`
DefaultAuthorizeResponseType *string `json:"DEFAULT_AUTHORIZE_RESPONSE_TYPE"`
DefaultAuthorizeResponseMode *string `json:"DEFAULT_AUTHORIZE_RESPONSE_MODE"`
} }
type Error struct { type Error struct {
@@ -352,6 +355,8 @@ type UpdateEnvInput struct {
MicrosoftActiveDirectoryTenantID *string `json:"MICROSOFT_ACTIVE_DIRECTORY_TENANT_ID"` MicrosoftActiveDirectoryTenantID *string `json:"MICROSOFT_ACTIVE_DIRECTORY_TENANT_ID"`
OrganizationName *string `json:"ORGANIZATION_NAME"` OrganizationName *string `json:"ORGANIZATION_NAME"`
OrganizationLogo *string `json:"ORGANIZATION_LOGO"` OrganizationLogo *string `json:"ORGANIZATION_LOGO"`
DefaultAuthorizeResponseType *string `json:"DEFAULT_AUTHORIZE_RESPONSE_TYPE"`
DefaultAuthorizeResponseMode *string `json:"DEFAULT_AUTHORIZE_RESPONSE_MODE"`
} }
type UpdateProfileInput struct { type UpdateProfileInput struct {
@@ -387,11 +392,12 @@ type UpdateUserInput struct {
} }
type UpdateWebhookRequest struct { type UpdateWebhookRequest struct {
ID string `json:"id"` ID string `json:"id"`
EventName *string `json:"event_name"` EventName *string `json:"event_name"`
Endpoint *string `json:"endpoint"` EventDescription *string `json:"event_description"`
Enabled *bool `json:"enabled"` Endpoint *string `json:"endpoint"`
Headers map[string]interface{} `json:"headers"` Enabled *bool `json:"enabled"`
Headers map[string]interface{} `json:"headers"`
} }
type User struct { type User struct {
@@ -461,13 +467,14 @@ type VerifyOTPRequest struct {
} }
type Webhook struct { type Webhook struct {
ID string `json:"id"` ID string `json:"id"`
EventName *string `json:"event_name"` EventName *string `json:"event_name"`
Endpoint *string `json:"endpoint"` EventDescription *string `json:"event_description"`
Enabled *bool `json:"enabled"` Endpoint *string `json:"endpoint"`
Headers map[string]interface{} `json:"headers"` Enabled *bool `json:"enabled"`
CreatedAt *int64 `json:"created_at"` Headers map[string]interface{} `json:"headers"`
UpdatedAt *int64 `json:"updated_at"` CreatedAt *int64 `json:"created_at"`
UpdatedAt *int64 `json:"updated_at"`
} }
type WebhookLog struct { type WebhookLog struct {

View File

@@ -153,6 +153,8 @@ type Env {
ORGANIZATION_LOGO: String ORGANIZATION_LOGO: String
APP_COOKIE_SECURE: Boolean! APP_COOKIE_SECURE: Boolean!
ADMIN_COOKIE_SECURE: Boolean! ADMIN_COOKIE_SECURE: Boolean!
DEFAULT_AUTHORIZE_RESPONSE_TYPE: String
DEFAULT_AUTHORIZE_RESPONSE_MODE: String
} }
type ValidateJWTTokenResponse { type ValidateJWTTokenResponse {
@@ -168,7 +170,8 @@ type GenerateJWTKeysResponse {
type Webhook { type Webhook {
id: ID! id: ID!
event_name: String event_name: String # this is unique string
event_description: String
endpoint: String endpoint: String
enabled: Boolean enabled: Boolean
headers: Map headers: Map
@@ -266,6 +269,8 @@ input UpdateEnvInput {
MICROSOFT_ACTIVE_DIRECTORY_TENANT_ID: String MICROSOFT_ACTIVE_DIRECTORY_TENANT_ID: String
ORGANIZATION_NAME: String ORGANIZATION_NAME: String
ORGANIZATION_LOGO: String ORGANIZATION_LOGO: String
DEFAULT_AUTHORIZE_RESPONSE_TYPE: String
DEFAULT_AUTHORIZE_RESPONSE_MODE: String
} }
input AdminLoginInput { input AdminLoginInput {
@@ -459,6 +464,7 @@ input ListWebhookLogRequest {
input AddWebhookRequest { input AddWebhookRequest {
event_name: String! event_name: String!
event_description: String
endpoint: String! endpoint: String!
enabled: Boolean! enabled: Boolean!
headers: Map headers: Map
@@ -467,6 +473,7 @@ input AddWebhookRequest {
input UpdateWebhookRequest { input UpdateWebhookRequest {
id: ID! id: ID!
event_name: String event_name: String
event_description: String
endpoint: String endpoint: String
enabled: Boolean enabled: Boolean
headers: Map headers: Map

View File

@@ -83,7 +83,11 @@ func AuthorizeHandler() gin.HandlerFunc {
} }
if responseMode == "" { if responseMode == "" {
responseMode = constants.ResponseModeQuery if val, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyDefaultAuthorizeResponseMode); err == nil {
responseType = val
} else {
responseType = constants.ResponseModeQuery
}
} }
if redirectURI == "" { if redirectURI == "" {
@@ -91,7 +95,11 @@ func AuthorizeHandler() gin.HandlerFunc {
} }
if responseType == "" { if responseType == "" {
responseType = "token" if val, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyDefaultAuthorizeResponseType); err == nil {
responseType = val
} else {
responseType = constants.ResponseTypeToken
}
} }
if err := validateAuthorizeRequest(responseType, responseMode, clientID, state, codeChallenge); err != nil { if err := validateAuthorizeRequest(responseType, responseMode, clientID, state, codeChallenge); err != nil {
@@ -186,7 +194,7 @@ func AuthorizeHandler() gin.HandlerFunc {
// rollover the session for security // rollover the session for security
go memorystore.Provider.DeleteUserSession(sessionKey, claims.Nonce) go memorystore.Provider.DeleteUserSession(sessionKey, claims.Nonce)
if responseType == constants.ResponseTypeCode { if responseType == constants.ResponseTypeCode {
newSessionTokenData, newSessionToken, err := token.CreateSessionToken(user, nonce, claims.Roles, scope, claims.LoginMethod) newSessionTokenData, newSessionToken, newSessionExpiresAt, err := token.CreateSessionToken(user, nonce, claims.Roles, scope, claims.LoginMethod)
if err != nil { if err != nil {
log.Debug("CreateSessionToken failed: ", err) log.Debug("CreateSessionToken failed: ", err)
handleResponse(gc, responseMode, loginURL, redirectURI, loginError, http.StatusOK) handleResponse(gc, responseMode, loginURL, redirectURI, loginError, http.StatusOK)
@@ -207,7 +215,7 @@ func AuthorizeHandler() gin.HandlerFunc {
return return
} }
if err := memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+newSessionTokenData.Nonce, newSessionToken); err != nil { if err := memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+newSessionTokenData.Nonce, newSessionToken, newSessionExpiresAt); err != nil {
log.Debug("SetUserSession failed: ", err) log.Debug("SetUserSession failed: ", err)
handleResponse(gc, responseMode, loginURL, redirectURI, loginError, http.StatusOK) handleResponse(gc, responseMode, loginURL, redirectURI, loginError, http.StatusOK)
return return
@@ -263,13 +271,13 @@ func AuthorizeHandler() gin.HandlerFunc {
return return
} }
if err := memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+nonce, authToken.FingerPrintHash); err != nil { if err := memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+nonce, authToken.FingerPrintHash, authToken.SessionTokenExpiresAt); err != nil {
log.Debug("SetUserSession failed: ", err) log.Debug("SetUserSession failed: ", err)
handleResponse(gc, responseMode, loginURL, redirectURI, loginError, http.StatusOK) handleResponse(gc, responseMode, loginURL, redirectURI, loginError, http.StatusOK)
return return
} }
if err := memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+nonce, authToken.AccessToken.Token); err != nil { if err := memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+nonce, authToken.AccessToken.Token, authToken.AccessToken.ExpiresAt); err != nil {
log.Debug("SetUserSession failed: ", err) log.Debug("SetUserSession failed: ", err)
handleResponse(gc, responseMode, loginURL, redirectURI, loginError, http.StatusOK) handleResponse(gc, responseMode, loginURL, redirectURI, loginError, http.StatusOK)
return return
@@ -297,7 +305,7 @@ func AuthorizeHandler() gin.HandlerFunc {
if authToken.RefreshToken != nil { if authToken.RefreshToken != nil {
res["refresh_token"] = authToken.RefreshToken.Token res["refresh_token"] = authToken.RefreshToken.Token
params += "&refresh_token=" + authToken.RefreshToken.Token params += "&refresh_token=" + authToken.RefreshToken.Token
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token, authToken.RefreshToken.ExpiresAt)
} }
if responseMode == constants.ResponseModeQuery { if responseMode == constants.ResponseModeQuery {

View File

@@ -2,7 +2,7 @@ package handlers
import ( import (
"github.com/99designs/gqlgen/graphql/handler" "github.com/99designs/gqlgen/graphql/handler"
graph "github.com/authorizerdev/authorizer/server/graph" "github.com/authorizerdev/authorizer/server/graph"
"github.com/authorizerdev/authorizer/server/graph/generated" "github.com/authorizerdev/authorizer/server/graph/generated"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )

View File

@@ -47,7 +47,14 @@ func LogoutHandler() gin.HandlerFunc {
return return
} }
memorystore.Provider.DeleteUserSession(sessionData.Subject, sessionData.Nonce) userID := sessionData.Subject
loginMethod := sessionData.LoginMethod
sessionToken := userID
if loginMethod != "" {
sessionToken = loginMethod + ":" + userID
}
memorystore.Provider.DeleteUserSession(sessionToken, sessionData.Nonce)
cookie.DeleteSession(gc) cookie.DeleteSession(gc)
if redirectURL != "" { if redirectURL != "" {

View File

@@ -249,12 +249,12 @@ func OAuthCallbackHandler() gin.HandlerFunc {
sessionKey := provider + ":" + user.ID sessionKey := provider + ":" + user.ID
cookie.SetSession(ctx, authToken.FingerPrintHash) cookie.SetSession(ctx, authToken.FingerPrintHash)
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash, authToken.SessionTokenExpiresAt)
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token, authToken.AccessToken.ExpiresAt)
if authToken.RefreshToken != nil { if authToken.RefreshToken != nil {
params += `&refresh_token=` + authToken.RefreshToken.Token params += `&refresh_token=` + authToken.RefreshToken.Token
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token, authToken.RefreshToken.ExpiresAt)
} }
go func() { go func() {

View File

@@ -247,8 +247,8 @@ func TokenHandler() gin.HandlerFunc {
return return
} }
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash, authToken.SessionTokenExpiresAt)
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token, authToken.AccessToken.ExpiresAt)
cookie.SetSession(gc, authToken.FingerPrintHash) cookie.SetSession(gc, authToken.FingerPrintHash)
expiresIn := authToken.AccessToken.ExpiresAt - time.Now().Unix() expiresIn := authToken.AccessToken.ExpiresAt - time.Now().Unix()
@@ -266,7 +266,7 @@ func TokenHandler() gin.HandlerFunc {
if authToken.RefreshToken != nil { if authToken.RefreshToken != nil {
res["refresh_token"] = authToken.RefreshToken.Token res["refresh_token"] = authToken.RefreshToken.Token
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token, authToken.RefreshToken.ExpiresAt)
} }
gc.JSON(http.StatusOK, res) gc.JSON(http.StatusOK, res)

View File

@@ -154,12 +154,12 @@ func VerifyEmailHandler() gin.HandlerFunc {
sessionKey := loginMethod + ":" + user.ID sessionKey := loginMethod + ":" + user.ID
cookie.SetSession(c, authToken.FingerPrintHash) cookie.SetSession(c, authToken.FingerPrintHash)
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash, authToken.SessionTokenExpiresAt)
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token, authToken.AccessToken.ExpiresAt)
if authToken.RefreshToken != nil { if authToken.RefreshToken != nil {
params = params + `&refresh_token=` + authToken.RefreshToken.Token params = params + `&refresh_token=` + authToken.RefreshToken.Token
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token, authToken.RefreshToken.ExpiresAt)
} }
if redirectURL == "" { if redirectURL == "" {

View File

@@ -0,0 +1,14 @@
package inmemory
import (
"testing"
"github.com/authorizerdev/authorizer/server/memorystore/providers"
"github.com/stretchr/testify/assert"
)
func TestInMemoryProvider(t *testing.T) {
p, err := NewInMemoryProvider()
assert.NoError(t, err)
providers.ProviderTests(t, p)
}

View File

@@ -8,45 +8,31 @@ import (
) )
// SetUserSession sets the user session for given user identifier in form recipe:user_id // SetUserSession sets the user session for given user identifier in form recipe:user_id
func (c *provider) SetUserSession(userId, key, token string) error { func (c *provider) SetUserSession(userId, key, token string, expiration int64) error {
c.sessionStore.Set(userId, key, token) c.sessionStore.Set(userId, key, token, expiration)
return nil return nil
} }
// GetAllUserSessions returns all the user sessions token from the in-memory store.
func (c *provider) GetAllUserSessions(userId string) (map[string]string, error) {
data := c.sessionStore.GetAll(userId)
return data, nil
}
// GetUserSession returns value for given session token // GetUserSession returns value for given session token
func (c *provider) GetUserSession(userId, sessionToken string) (string, error) { func (c *provider) GetUserSession(userId, sessionToken string) (string, error) {
return c.sessionStore.Get(userId, sessionToken), nil val := c.sessionStore.Get(userId, sessionToken)
if val == "" {
return "", fmt.Errorf("Not found")
}
return val, nil
} }
// DeleteAllUserSessions deletes all the user sessions from in-memory store. // DeleteAllUserSessions deletes all the user sessions from in-memory store.
func (c *provider) DeleteAllUserSessions(userId string) error { func (c *provider) DeleteAllUserSessions(userId string) error {
namespaces := []string{ c.sessionStore.RemoveAll(userId)
constants.AuthRecipeMethodBasicAuth,
constants.AuthRecipeMethodMagicLinkLogin,
constants.AuthRecipeMethodApple,
constants.AuthRecipeMethodFacebook,
constants.AuthRecipeMethodGithub,
constants.AuthRecipeMethodGoogle,
constants.AuthRecipeMethodLinkedIn,
constants.AuthRecipeMethodTwitter,
constants.AuthRecipeMethodMicrosoft,
}
for _, namespace := range namespaces {
c.sessionStore.RemoveAll(namespace + ":" + userId)
}
return nil return nil
} }
// DeleteUserSession deletes the user session from the in-memory store. // DeleteUserSession deletes the user session from the in-memory store.
func (c *provider) DeleteUserSession(userId, sessionToken string) error { func (c *provider) DeleteUserSession(userId, sessionToken string) error {
c.sessionStore.Remove(userId, sessionToken) c.sessionStore.Remove(userId, constants.TokenTypeSessionToken+"_"+sessionToken)
c.sessionStore.Remove(userId, constants.TokenTypeAccessToken+"_"+sessionToken)
c.sessionStore.Remove(userId, constants.TokenTypeRefreshToken+"_"+sessionToken)
return nil return nil
} }

View File

@@ -31,11 +31,15 @@ func (e *EnvStore) UpdateStore(store map[string]interface{}) {
// GetStore returns the env store // GetStore returns the env store
func (e *EnvStore) GetStore() map[string]interface{} { func (e *EnvStore) GetStore() map[string]interface{} {
e.mutex.Lock()
defer e.mutex.Unlock()
return e.store return e.store
} }
// Get returns the value of the key in evn store // Get returns the value of the key in evn store
func (e *EnvStore) Get(key string) interface{} { func (e *EnvStore) Get(key string) interface{} {
e.mutex.Lock()
defer e.mutex.Unlock()
return e.store[key] return e.store[key]
} }

View File

@@ -1,73 +1,140 @@
package stores package stores
import ( import (
"fmt"
"sort"
"strings" "strings"
"sync" "sync"
"time"
) )
const (
// Maximum entries to keep in session storage
maxCacheSize = 1000
// Cache clear interval
clearInterval = 10 * time.Minute
)
// SessionEntry is the struct for entry stored in store
type SessionEntry struct {
Value string
ExpiresAt int64
}
// SessionStore struct to store the env variables // SessionStore struct to store the env variables
type SessionStore struct { type SessionStore struct {
mutex sync.Mutex wg sync.WaitGroup
store map[string]map[string]string mutex sync.RWMutex
store map[string]*SessionEntry
// stores expireTime: key to remove data when cache is full
// map is sorted by key so older most entry can be deleted first
keyIndex map[int64]string
stop chan struct{}
} }
// NewSessionStore create a new session store // NewSessionStore create a new session store
func NewSessionStore() *SessionStore { func NewSessionStore() *SessionStore {
return &SessionStore{ store := &SessionStore{
mutex: sync.Mutex{}, mutex: sync.RWMutex{},
store: make(map[string]map[string]string), store: make(map[string]*SessionEntry),
keyIndex: make(map[int64]string),
stop: make(chan struct{}),
}
store.wg.Add(1)
go func() {
defer store.wg.Done()
store.clean()
}()
return store
}
func (s *SessionStore) clean() {
t := time.NewTicker(clearInterval)
defer t.Stop()
for {
select {
case <-s.stop:
return
case <-t.C:
s.mutex.Lock()
currentTime := time.Now().Unix()
for k, v := range s.store {
if v.ExpiresAt < currentTime {
delete(s.store, k)
delete(s.keyIndex, v.ExpiresAt)
}
}
s.mutex.Unlock()
}
} }
} }
// Get returns the value of the key in state store // Get returns the value of the key in state store
func (s *SessionStore) Get(key, subKey string) string { func (s *SessionStore) Get(key, subKey string) string {
return s.store[key][subKey] s.mutex.RLock()
defer s.mutex.RUnlock()
currentTime := time.Now().Unix()
k := fmt.Sprintf("%s:%s", key, subKey)
if v, ok := s.store[k]; ok {
if v.ExpiresAt > currentTime {
return v.Value
}
// Delete expired items
delete(s.store, k)
}
return ""
} }
// Set sets the value of the key in state store // Set sets the value of the key in state store
func (s *SessionStore) Set(key string, subKey, value string) { func (s *SessionStore) Set(key string, subKey, value string, expiration int64) {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
k := fmt.Sprintf("%s:%s", key, subKey)
if _, ok := s.store[key]; !ok { if _, ok := s.store[k]; !ok {
s.store[key] = make(map[string]string) // check if there is enough space in cache
// else delete entries based on FIFO
if len(s.store) == maxCacheSize {
// remove older most entry
sortedKeys := []int64{}
for ik := range s.keyIndex {
sortedKeys = append(sortedKeys, ik)
}
sort.Slice(sortedKeys, func(i, j int) bool { return sortedKeys[i] < sortedKeys[j] })
itemToRemove := sortedKeys[0]
delete(s.store, s.keyIndex[itemToRemove])
delete(s.keyIndex, itemToRemove)
}
} }
s.store[key][subKey] = value s.store[k] = &SessionEntry{
Value: value,
ExpiresAt: expiration,
}
s.keyIndex[expiration] = k
} }
// RemoveAll all values for given key // RemoveAll all values for given key
func (s *SessionStore) RemoveAll(key string) { func (s *SessionStore) RemoveAll(key string) {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
for k := range s.store {
delete(s.store, key) if strings.Contains(k, key) {
delete(s.store, k)
}
}
} }
// Remove value for given key and subkey // Remove value for given key and subkey
func (s *SessionStore) Remove(key, subKey string) { func (s *SessionStore) Remove(key, subKey string) {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
if _, ok := s.store[key]; ok { k := fmt.Sprintf("%s:%s", key, subKey)
delete(s.store[key], subKey) delete(s.store, k)
}
}
// Get all the values for given key
func (s *SessionStore) GetAll(key string) map[string]string {
s.mutex.Lock()
defer s.mutex.Unlock()
if _, ok := s.store[key]; !ok {
s.store[key] = make(map[string]string)
}
return s.store[key]
} }
// RemoveByNamespace to delete session for a given namespace example google,github // RemoveByNamespace to delete session for a given namespace example google,github
func (s *SessionStore) RemoveByNamespace(namespace string) error { func (s *SessionStore) RemoveByNamespace(namespace string) error {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
for key := range s.store { for key := range s.store {
if strings.Contains(key, namespace+":") { if strings.Contains(key, namespace+":") {
delete(s.store, key) delete(s.store, key)

View File

@@ -20,6 +20,8 @@ func NewStateStore() *StateStore {
// Get returns the value of the key in state store // Get returns the value of the key in state store
func (s *StateStore) Get(key string) string { func (s *StateStore) Get(key string) string {
s.mutex.Lock()
defer s.mutex.Unlock()
return s.store[key] return s.store[key]
} }

View File

@@ -0,0 +1,115 @@
package providers
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
// ProviderTests runs all provider tests
func ProviderTests(t *testing.T, p Provider) {
err := p.SetUserSession("auth_provider:123", "session_token_key", "test_hash123", time.Now().Add(60*time.Second).Unix())
assert.NoError(t, err)
err = p.SetUserSession("auth_provider:123", "access_token_key", "test_jwt123", time.Now().Add(60*time.Second).Unix())
assert.NoError(t, err)
// Same user multiple session
err = p.SetUserSession("auth_provider:123", "session_token_key1", "test_hash1123", time.Now().Add(60*time.Second).Unix())
assert.NoError(t, err)
err = p.SetUserSession("auth_provider:123", "access_token_key1", "test_jwt1123", time.Now().Add(60*time.Second).Unix())
assert.NoError(t, err)
// Different user session
err = p.SetUserSession("auth_provider:124", "session_token_key", "test_hash124", time.Now().Add(5*time.Second).Unix())
assert.NoError(t, err)
err = p.SetUserSession("auth_provider:124", "access_token_key", "test_jwt124", time.Now().Add(5*time.Second).Unix())
assert.NoError(t, err)
// Different provider session
err = p.SetUserSession("auth_provider1:124", "session_token_key", "test_hash124", time.Now().Add(60*time.Second).Unix())
assert.NoError(t, err)
err = p.SetUserSession("auth_provider1:124", "access_token_key", "test_jwt124", time.Now().Add(60*time.Second).Unix())
assert.NoError(t, err)
// Different provider session
err = p.SetUserSession("auth_provider1:123", "session_token_key", "test_hash1123", time.Now().Add(60*time.Second).Unix())
assert.NoError(t, err)
err = p.SetUserSession("auth_provider1:123", "access_token_key", "test_jwt1123", time.Now().Add(60*time.Second).Unix())
assert.NoError(t, err)
// Get session
key, err := p.GetUserSession("auth_provider:123", "session_token_key")
assert.NoError(t, err)
assert.Equal(t, "test_hash123", key)
key, err = p.GetUserSession("auth_provider:123", "access_token_key")
assert.NoError(t, err)
assert.Equal(t, "test_jwt123", key)
key, err = p.GetUserSession("auth_provider:124", "session_token_key")
assert.NoError(t, err)
assert.Equal(t, "test_hash124", key)
key, err = p.GetUserSession("auth_provider:124", "access_token_key")
assert.NoError(t, err)
assert.Equal(t, "test_jwt124", key)
// Expire some tokens and make sure they are empty
time.Sleep(5 * time.Second)
key, err = p.GetUserSession("auth_provider:124", "session_token_key")
assert.Empty(t, key)
assert.Error(t, err)
key, err = p.GetUserSession("auth_provider:124", "access_token_key")
assert.Empty(t, key)
assert.Error(t, err)
// Delete user session
err = p.DeleteUserSession("auth_provider:123", "key")
assert.NoError(t, err)
err = p.DeleteUserSession("auth_provider:123", "key")
assert.NoError(t, err)
key, err = p.GetUserSession("auth_provider:123", "key")
assert.Empty(t, key)
assert.Error(t, err)
key, err = p.GetUserSession("auth_provider:123", "access_token_key")
assert.Empty(t, key)
assert.Error(t, err)
// Delete all user session
err = p.DeleteAllUserSessions("123")
assert.NoError(t, err)
err = p.DeleteAllUserSessions("123")
assert.NoError(t, err)
key, err = p.GetUserSession("auth_provider:123", "session_token_key1")
assert.Empty(t, key)
assert.Error(t, err)
key, err = p.GetUserSession("auth_provider:123", "access_token_key1")
assert.Empty(t, key)
assert.Error(t, err)
key, err = p.GetUserSession("auth_provider1:123", "session_token_key")
assert.Empty(t, key)
assert.Error(t, err)
key, err = p.GetUserSession("auth_provider1:123", "access_token_key")
assert.Empty(t, key)
assert.Error(t, err)
// Delete namespace
err = p.DeleteSessionForNamespace("auth_provider")
assert.NoError(t, err)
err = p.DeleteSessionForNamespace("auth_provider1")
assert.NoError(t, err)
key, err = p.GetUserSession("auth_provider:123", "session_token_key1")
assert.Empty(t, key)
assert.Error(t, err)
key, err = p.GetUserSession("auth_provider:123", "access_token_key1")
assert.Empty(t, key)
assert.Error(t, err)
key, err = p.GetUserSession("auth_provider1:123", "session_token_key")
assert.Empty(t, key)
assert.Error(t, err)
key, err = p.GetUserSession("auth_provider1:123", "access_token_key")
assert.Empty(t, key)
assert.Error(t, err)
key, err = p.GetUserSession("auth_provider:124", "session_token_key1")
assert.Empty(t, key)
assert.Error(t, err)
key, err = p.GetUserSession("auth_provider:124", "access_token_key1")
assert.Empty(t, key)
assert.Error(t, err)
key, err = p.GetUserSession("auth_provider1:124", "session_token_key")
assert.Empty(t, key)
assert.Error(t, err)
key, err = p.GetUserSession("auth_provider1:124", "access_token_key")
assert.Empty(t, key)
assert.Error(t, err)
}

View File

@@ -3,9 +3,7 @@ package providers
// Provider defines current memory store provider // Provider defines current memory store provider
type Provider interface { type Provider interface {
// SetUserSession sets the user session for given user identifier in form recipe:user_id // SetUserSession sets the user session for given user identifier in form recipe:user_id
SetUserSession(userId, key, token string) error SetUserSession(userId, key, token string, expiration int64) error
// GetAllUserSessions returns all the user sessions from the session store
GetAllUserSessions(userId string) (map[string]string, error)
// GetUserSession returns the session token for given token // GetUserSession returns the session token for given token
GetUserSession(userId, key string) (string, error) GetUserSession(userId, key string) (string, error)
// DeleteUserSession deletes the user session // DeleteUserSession deletes the user session

View File

@@ -5,7 +5,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/go-redis/redis/v8" "github.com/redis/go-redis/v9"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@@ -17,10 +17,11 @@ type RedisClient interface {
HMGet(ctx context.Context, key string, fields ...string) *redis.SliceCmd HMGet(ctx context.Context, key string, fields ...string) *redis.SliceCmd
HSet(ctx context.Context, key string, values ...interface{}) *redis.IntCmd HSet(ctx context.Context, key string, values ...interface{}) *redis.IntCmd
HGet(ctx context.Context, key, field string) *redis.StringCmd HGet(ctx context.Context, key, field string) *redis.StringCmd
HGetAll(ctx context.Context, key string) *redis.StringStringMapCmd HGetAll(ctx context.Context, key string) *redis.MapStringStringCmd
Set(ctx context.Context, key string, value interface{}, expiration time.Duration) *redis.StatusCmd Set(ctx context.Context, key string, value interface{}, expiration time.Duration) *redis.StatusCmd
Get(ctx context.Context, key string) *redis.StringCmd Get(ctx context.Context, key string) *redis.StringCmd
Scan(ctx context.Context, cursor uint64, match string, count int64) *redis.ScanCmd Scan(ctx context.Context, cursor uint64, match string, count int64) *redis.ScanCmd
Keys(ctx context.Context, pattern string) *redis.StringSliceCmd
} }
type provider struct { type provider struct {
@@ -31,7 +32,6 @@ type provider struct {
// NewRedisProvider returns a new redis provider // NewRedisProvider returns a new redis provider
func NewRedisProvider(redisURL string) (*provider, error) { func NewRedisProvider(redisURL string) (*provider, error) {
redisURLHostPortsList := strings.Split(redisURL, ",") redisURLHostPortsList := strings.Split(redisURL, ",")
if len(redisURLHostPortsList) > 1 { if len(redisURLHostPortsList) > 1 {
opt, err := redis.ParseURL(redisURLHostPortsList[0]) opt, err := redis.ParseURL(redisURLHostPortsList[0])
if err != nil { if err != nil {
@@ -70,7 +70,6 @@ func NewRedisProvider(redisURL string) (*provider, error) {
log.Debug("error connecting to redis: ", err) log.Debug("error connecting to redis: ", err)
return nil, err return nil, err
} }
return &provider{ return &provider{
ctx: ctx, ctx: ctx,
store: rdb, store: rdb,

View File

@@ -0,0 +1,15 @@
package redis
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/authorizerdev/authorizer/server/memorystore/providers"
)
func TestRedisProvider(t *testing.T) {
p, err := NewRedisProvider("redis://127.0.0.1:6379")
assert.NoError(t, err)
providers.ProviderTests(t, p)
}

View File

@@ -1,7 +1,9 @@
package redis package redis
import ( import (
"fmt"
"strconv" "strconv"
"time"
"github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/constants"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -15,29 +17,21 @@ var (
) )
// SetUserSession sets the user session for given user identifier in form recipe:user_id // SetUserSession sets the user session for given user identifier in form recipe:user_id
func (c *provider) SetUserSession(userId, key, token string) error { func (c *provider) SetUserSession(userId, key, token string, expiration int64) error {
err := c.store.HSet(c.ctx, userId, key, token).Err() currentTime := time.Now()
expireTime := time.Unix(expiration, 0)
duration := expireTime.Sub(currentTime)
err := c.store.Set(c.ctx, fmt.Sprintf("%s:%s", userId, key), token, duration).Err()
if err != nil { if err != nil {
log.Debug("Error saving to redis: ", err) log.Debug("Error saving user session to redis: ", err)
return err return err
} }
return nil return nil
} }
// GetAllUserSessions returns all the user session token from the redis store.
func (c *provider) GetAllUserSessions(userID string) (map[string]string, error) {
data, err := c.store.HGetAll(c.ctx, userID).Result()
if err != nil {
log.Debug("error getting all user sessions from redis store: ", err)
return nil, err
}
return data, nil
}
// GetUserSession returns the user session from redis store. // GetUserSession returns the user session from redis store.
func (c *provider) GetUserSession(userId, key string) (string, error) { func (c *provider) GetUserSession(userId, key string) (string, error) {
data, err := c.store.HGet(c.ctx, userId, key).Result() data, err := c.store.Get(c.ctx, fmt.Sprintf("%s:%s", userId, key)).Result()
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -46,39 +40,34 @@ func (c *provider) GetUserSession(userId, key string) (string, error) {
// DeleteUserSession deletes the user session from redis store. // DeleteUserSession deletes the user session from redis store.
func (c *provider) DeleteUserSession(userId, key string) error { func (c *provider) DeleteUserSession(userId, key string) error {
if err := c.store.HDel(c.ctx, userId, constants.TokenTypeSessionToken+"_"+key).Err(); err != nil { if err := c.store.Del(c.ctx, fmt.Sprintf("%s:%s", userId, constants.TokenTypeSessionToken+"_"+key)).Err(); err != nil {
log.Debug("Error deleting user session from redis: ", err) log.Debug("Error deleting user session from redis: ", err)
return err // continue
} }
if err := c.store.HDel(c.ctx, userId, constants.TokenTypeAccessToken+"_"+key).Err(); err != nil { if err := c.store.Del(c.ctx, fmt.Sprintf("%s:%s", userId, constants.TokenTypeAccessToken+"_"+key)).Err(); err != nil {
log.Debug("Error deleting user session from redis: ", err) log.Debug("Error deleting user session from redis: ", err)
return err // continue
} }
if err := c.store.HDel(c.ctx, userId, constants.TokenTypeRefreshToken+"_"+key).Err(); err != nil { if err := c.store.Del(c.ctx, fmt.Sprintf("%s:%s", userId, constants.TokenTypeRefreshToken+"_"+key)).Err(); err != nil {
log.Debug("Error deleting user session from redis: ", err) log.Debug("Error deleting user session from redis: ", err)
return err // continue
} }
return nil return nil
} }
// DeleteAllUserSessions deletes all the user session from redis // DeleteAllUserSessions deletes all the user session from redis
func (c *provider) DeleteAllUserSessions(userID string) error { func (c *provider) DeleteAllUserSessions(userID string) error {
namespaces := []string{ res := c.store.Keys(c.ctx, fmt.Sprintf("*%s*", userID))
constants.AuthRecipeMethodBasicAuth, if res.Err() != nil {
constants.AuthRecipeMethodMagicLinkLogin, log.Debug("Error getting all user sessions from redis: ", res.Err())
constants.AuthRecipeMethodApple, return res.Err()
constants.AuthRecipeMethodFacebook,
constants.AuthRecipeMethodGithub,
constants.AuthRecipeMethodGoogle,
constants.AuthRecipeMethodLinkedIn,
constants.AuthRecipeMethodTwitter,
constants.AuthRecipeMethodMicrosoft,
} }
for _, namespace := range namespaces { keys := res.Val()
err := c.store.Del(c.ctx, namespace+":"+userID).Err() for _, key := range keys {
err := c.store.Del(c.ctx, key).Err()
if err != nil { if err != nil {
log.Debug("Error deleting all user sessions from redis: ", err) log.Debug("Error deleting all user sessions from redis: ", err)
return err continue
} }
} }
return nil return nil
@@ -86,27 +75,19 @@ func (c *provider) DeleteAllUserSessions(userID string) error {
// DeleteSessionForNamespace to delete session for a given namespace example google,github // DeleteSessionForNamespace to delete session for a given namespace example google,github
func (c *provider) DeleteSessionForNamespace(namespace string) error { func (c *provider) DeleteSessionForNamespace(namespace string) error {
var cursor uint64 res := c.store.Keys(c.ctx, fmt.Sprintf("%s:*", namespace))
for { if res.Err() != nil {
keys := []string{} log.Debug("Error getting all user sessions from redis: ", res.Err())
keys, cursor, err := c.store.Scan(c.ctx, cursor, namespace+":*", 0).Result() return res.Err()
}
keys := res.Val()
for _, key := range keys {
err := c.store.Del(c.ctx, key).Err()
if err != nil { if err != nil {
log.Debugf("Error scanning keys for %s namespace: %s", namespace, err.Error()) log.Debug("Error deleting all user sessions from redis: ", err)
return err continue
}
for _, key := range keys {
err := c.store.Del(c.ctx, key).Err()
if err != nil {
log.Debugf("Error deleting sessions for %s namespace: %s", namespace, err.Error())
return err
}
}
if cursor == 0 { // no more keys
break
} }
} }
return nil return nil
} }

View File

@@ -9,6 +9,7 @@ import (
"github.com/authorizerdev/authorizer/server/db" "github.com/authorizerdev/authorizer/server/db"
"github.com/authorizerdev/authorizer/server/db/models" "github.com/authorizerdev/authorizer/server/db/models"
"github.com/authorizerdev/authorizer/server/graph/model" "github.com/authorizerdev/authorizer/server/graph/model"
"github.com/authorizerdev/authorizer/server/refs"
"github.com/authorizerdev/authorizer/server/token" "github.com/authorizerdev/authorizer/server/token"
"github.com/authorizerdev/authorizer/server/utils" "github.com/authorizerdev/authorizer/server/utils"
"github.com/authorizerdev/authorizer/server/validators" "github.com/authorizerdev/authorizer/server/validators"
@@ -22,32 +23,32 @@ func AddWebhookResolver(ctx context.Context, params model.AddWebhookRequest) (*m
log.Debug("Failed to get GinContext: ", err) log.Debug("Failed to get GinContext: ", err)
return nil, err return nil, err
} }
if !token.IsSuperAdmin(gc) { if !token.IsSuperAdmin(gc) {
log.Debug("Not logged in as super admin") log.Debug("Not logged in as super admin")
return nil, fmt.Errorf("unauthorized") return nil, fmt.Errorf("unauthorized")
} }
if !validators.IsValidWebhookEventName(params.EventName) { if !validators.IsValidWebhookEventName(params.EventName) {
log.Debug("Invalid Event Name: ", params.EventName) log.Debug("Invalid Event Name: ", params.EventName)
return nil, fmt.Errorf("invalid event name %s", params.EventName) return nil, fmt.Errorf("invalid event name %s", params.EventName)
} }
if strings.TrimSpace(params.Endpoint) == "" { if strings.TrimSpace(params.Endpoint) == "" {
log.Debug("empty endpoint not allowed") log.Debug("empty endpoint not allowed")
return nil, fmt.Errorf("empty endpoint not allowed") return nil, fmt.Errorf("empty endpoint not allowed")
} }
headerBytes, err := json.Marshal(params.Headers) headerBytes, err := json.Marshal(params.Headers)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if params.EventDescription == nil {
params.EventDescription = refs.NewStringRef(strings.Join(strings.Split(params.EventName, "."), " "))
}
_, err = db.Provider.AddWebhook(ctx, models.Webhook{ _, err = db.Provider.AddWebhook(ctx, models.Webhook{
EventName: params.EventName, EventDescription: refs.StringValue(params.EventDescription),
EndPoint: params.Endpoint, EventName: params.EventName,
Enabled: params.Enabled, EndPoint: params.Endpoint,
Headers: string(headerBytes), Enabled: params.Enabled,
Headers: string(headerBytes),
}) })
if err != nil { if err != nil {
log.Debug("Failed to add webhook: ", err) log.Debug("Failed to add webhook: ", err)

View File

@@ -168,6 +168,12 @@ func EnvResolver(ctx context.Context) (*model.Env, error) {
if val, ok := store[constants.EnvKeyOrganizationLogo]; ok { if val, ok := store[constants.EnvKeyOrganizationLogo]; ok {
res.OrganizationLogo = refs.NewStringRef(val.(string)) res.OrganizationLogo = refs.NewStringRef(val.(string))
} }
if val, ok := store[constants.EnvKeyDefaultAuthorizeResponseType]; ok {
res.DefaultAuthorizeResponseType = refs.NewStringRef(val.(string))
}
if val, ok := store[constants.EnvKeyDefaultAuthorizeResponseMode]; ok {
res.DefaultAuthorizeResponseMode = refs.NewStringRef(val.(string))
}
// string slice vars // string slice vars
res.AllowedOrigins = strings.Split(store[constants.EnvKeyAllowedOrigins].(string), ",") res.AllowedOrigins = strings.Split(store[constants.EnvKeyAllowedOrigins].(string), ",")

View File

@@ -193,12 +193,12 @@ func LoginResolver(ctx context.Context, params model.LoginInput) (*model.AuthRes
cookie.SetSession(gc, authToken.FingerPrintHash) cookie.SetSession(gc, authToken.FingerPrintHash)
sessionStoreKey := constants.AuthRecipeMethodBasicAuth + ":" + user.ID sessionStoreKey := constants.AuthRecipeMethodBasicAuth + ":" + user.ID
memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash, authToken.SessionTokenExpiresAt)
memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token, authToken.AccessToken.ExpiresAt)
if authToken.RefreshToken != nil { if authToken.RefreshToken != nil {
res.RefreshToken = &authToken.RefreshToken.Token res.RefreshToken = &authToken.RefreshToken.Token
memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token, authToken.RefreshToken.ExpiresAt)
} }
go func() { go func() {

View File

@@ -195,12 +195,12 @@ func MobileLoginResolver(ctx context.Context, params model.MobileLoginInput) (*m
cookie.SetSession(gc, authToken.FingerPrintHash) cookie.SetSession(gc, authToken.FingerPrintHash)
sessionStoreKey := constants.AuthRecipeMethodMobileBasicAuth + ":" + user.ID sessionStoreKey := constants.AuthRecipeMethodMobileBasicAuth + ":" + user.ID
memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash, authToken.SessionTokenExpiresAt)
memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token, authToken.AccessToken.ExpiresAt)
if authToken.RefreshToken != nil { if authToken.RefreshToken != nil {
res.RefreshToken = &authToken.RefreshToken.Token res.RefreshToken = &authToken.RefreshToken.Token
memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) memorystore.Provider.SetUserSession(sessionStoreKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token, authToken.RefreshToken.ExpiresAt)
} }
go func() { go func() {

View File

@@ -249,12 +249,12 @@ func MobileSignupResolver(ctx context.Context, params *model.MobileSignUpInput)
sessionKey := constants.AuthRecipeMethodMobileBasicAuth + ":" + user.ID sessionKey := constants.AuthRecipeMethodMobileBasicAuth + ":" + user.ID
cookie.SetSession(gc, authToken.FingerPrintHash) cookie.SetSession(gc, authToken.FingerPrintHash)
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash, authToken.SessionTokenExpiresAt)
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token, authToken.AccessToken.ExpiresAt)
if authToken.RefreshToken != nil { if authToken.RefreshToken != nil {
res.RefreshToken = &authToken.RefreshToken.Token res.RefreshToken = &authToken.RefreshToken.Token
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token, authToken.RefreshToken.ExpiresAt)
} }
go func() { go func() {

View File

@@ -99,12 +99,12 @@ func SessionResolver(ctx context.Context, params *model.SessionQueryInput) (*mod
} }
cookie.SetSession(gc, authToken.FingerPrintHash) cookie.SetSession(gc, authToken.FingerPrintHash)
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash, authToken.SessionTokenExpiresAt)
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token, authToken.AccessToken.ExpiresAt)
if authToken.RefreshToken != nil { if authToken.RefreshToken != nil {
res.RefreshToken = &authToken.RefreshToken.Token res.RefreshToken = &authToken.RefreshToken.Token
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token, authToken.RefreshToken.ExpiresAt)
} }
return res, nil return res, nil
} }

View File

@@ -91,7 +91,6 @@ func SignupResolver(ctx context.Context, params model.SignUpInput) (*model.AuthR
} }
inputRoles := []string{} inputRoles := []string{}
if len(params.Roles) > 0 { if len(params.Roles) > 0 {
// check if roles exists // check if roles exists
rolesString, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyRoles) rolesString, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyRoles)
@@ -293,12 +292,12 @@ func SignupResolver(ctx context.Context, params model.SignUpInput) (*model.AuthR
sessionKey := constants.AuthRecipeMethodBasicAuth + ":" + user.ID sessionKey := constants.AuthRecipeMethodBasicAuth + ":" + user.ID
cookie.SetSession(gc, authToken.FingerPrintHash) cookie.SetSession(gc, authToken.FingerPrintHash)
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash, authToken.SessionTokenExpiresAt)
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token, authToken.AccessToken.ExpiresAt)
if authToken.RefreshToken != nil { if authToken.RefreshToken != nil {
res.RefreshToken = &authToken.RefreshToken.Token res.RefreshToken = &authToken.RefreshToken.Token
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token, authToken.RefreshToken.ExpiresAt)
} }
go func() { go func() {

View File

@@ -28,13 +28,11 @@ func UpdateWebhookResolver(ctx context.Context, params model.UpdateWebhookReques
log.Debug("Not logged in as super admin") log.Debug("Not logged in as super admin")
return nil, fmt.Errorf("unauthorized") return nil, fmt.Errorf("unauthorized")
} }
webhook, err := db.Provider.GetWebhookByID(ctx, params.ID) webhook, err := db.Provider.GetWebhookByID(ctx, params.ID)
if err != nil { if err != nil {
log.Debug("failed to get webhook: ", err) log.Debug("failed to get webhook: ", err)
return nil, err return nil, err
} }
headersString := "" headersString := ""
if webhook.Headers != nil { if webhook.Headers != nil {
headerBytes, err := json.Marshal(webhook.Headers) headerBytes, err := json.Marshal(webhook.Headers)
@@ -43,17 +41,16 @@ func UpdateWebhookResolver(ctx context.Context, params model.UpdateWebhookReques
} }
headersString = string(headerBytes) headersString = string(headerBytes)
} }
webhookDetails := models.Webhook{ webhookDetails := models.Webhook{
ID: webhook.ID, ID: webhook.ID,
Key: webhook.ID, Key: webhook.ID,
EventName: refs.StringValue(webhook.EventName), EventName: refs.StringValue(webhook.EventName),
EndPoint: refs.StringValue(webhook.Endpoint), EventDescription: refs.StringValue(webhook.EventDescription),
Enabled: refs.BoolValue(webhook.Enabled), EndPoint: refs.StringValue(webhook.Endpoint),
Headers: headersString, Enabled: refs.BoolValue(webhook.Enabled),
CreatedAt: refs.Int64Value(webhook.CreatedAt), Headers: headersString,
CreatedAt: refs.Int64Value(webhook.CreatedAt),
} }
if params.EventName != nil && webhookDetails.EventName != refs.StringValue(params.EventName) { if params.EventName != nil && webhookDetails.EventName != refs.StringValue(params.EventName) {
if isValid := validators.IsValidWebhookEventName(refs.StringValue(params.EventName)); !isValid { if isValid := validators.IsValidWebhookEventName(refs.StringValue(params.EventName)); !isValid {
log.Debug("invalid event name: ", refs.StringValue(params.EventName)) log.Debug("invalid event name: ", refs.StringValue(params.EventName))
@@ -61,7 +58,6 @@ func UpdateWebhookResolver(ctx context.Context, params model.UpdateWebhookReques
} }
webhookDetails.EventName = refs.StringValue(params.EventName) webhookDetails.EventName = refs.StringValue(params.EventName)
} }
if params.Endpoint != nil && webhookDetails.EndPoint != refs.StringValue(params.Endpoint) { if params.Endpoint != nil && webhookDetails.EndPoint != refs.StringValue(params.Endpoint) {
if strings.TrimSpace(refs.StringValue(params.Endpoint)) == "" { if strings.TrimSpace(refs.StringValue(params.Endpoint)) == "" {
log.Debug("empty endpoint not allowed") log.Debug("empty endpoint not allowed")
@@ -69,11 +65,12 @@ func UpdateWebhookResolver(ctx context.Context, params model.UpdateWebhookReques
} }
webhookDetails.EndPoint = refs.StringValue(params.Endpoint) webhookDetails.EndPoint = refs.StringValue(params.Endpoint)
} }
if params.Enabled != nil && webhookDetails.Enabled != refs.BoolValue(params.Enabled) { if params.Enabled != nil && webhookDetails.Enabled != refs.BoolValue(params.Enabled) {
webhookDetails.Enabled = refs.BoolValue(params.Enabled) webhookDetails.Enabled = refs.BoolValue(params.Enabled)
} }
if params.EventDescription != nil && webhookDetails.EventDescription != refs.StringValue(params.EventDescription) {
webhookDetails.EventDescription = refs.StringValue(params.EventDescription)
}
if params.Headers != nil { if params.Headers != nil {
headerBytes, err := json.Marshal(params.Headers) headerBytes, err := json.Marshal(params.Headers)
if err != nil { if err != nil {
@@ -83,12 +80,10 @@ func UpdateWebhookResolver(ctx context.Context, params model.UpdateWebhookReques
webhookDetails.Headers = string(headerBytes) webhookDetails.Headers = string(headerBytes)
} }
_, err = db.Provider.UpdateWebhook(ctx, webhookDetails) _, err = db.Provider.UpdateWebhook(ctx, webhookDetails)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &model.Response{ return &model.Response{
Message: `Webhook updated successfully.`, Message: `Webhook updated successfully.`,
}, nil }, nil

View File

@@ -150,12 +150,12 @@ func VerifyEmailResolver(ctx context.Context, params model.VerifyEmailInput) (*m
sessionKey := loginMethod + ":" + user.ID sessionKey := loginMethod + ":" + user.ID
cookie.SetSession(gc, authToken.FingerPrintHash) cookie.SetSession(gc, authToken.FingerPrintHash)
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash, authToken.SessionTokenExpiresAt)
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token, authToken.AccessToken.ExpiresAt)
if authToken.RefreshToken != nil { if authToken.RefreshToken != nil {
res.RefreshToken = &authToken.RefreshToken.Token res.RefreshToken = &authToken.RefreshToken.Token
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token, authToken.RefreshToken.ExpiresAt)
} }
return res, nil return res, nil
} }

View File

@@ -123,12 +123,12 @@ func VerifyOtpResolver(ctx context.Context, params model.VerifyOTPRequest) (*mod
sessionKey := loginMethod + ":" + user.ID sessionKey := loginMethod + ":" + user.ID
cookie.SetSession(gc, authToken.FingerPrintHash) cookie.SetSession(gc, authToken.FingerPrintHash)
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash, authToken.SessionTokenExpiresAt)
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token, authToken.AccessToken.ExpiresAt)
if authToken.RefreshToken != nil { if authToken.RefreshToken != nil {
res.RefreshToken = &authToken.RefreshToken.Token res.RefreshToken = &authToken.RefreshToken.Token
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token, authToken.RefreshToken.ExpiresAt)
} }
return res, nil return res, nil
} }

View File

@@ -3,11 +3,13 @@ package test
import ( import (
"fmt" "fmt"
"testing" "testing"
"time"
"github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/crypto" "github.com/authorizerdev/authorizer/server/crypto"
"github.com/authorizerdev/authorizer/server/graph/model" "github.com/authorizerdev/authorizer/server/graph/model"
"github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/memorystore"
"github.com/authorizerdev/authorizer/server/refs"
"github.com/authorizerdev/authorizer/server/resolvers" "github.com/authorizerdev/authorizer/server/resolvers"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@@ -21,7 +23,6 @@ func addWebhookTest(t *testing.T, s TestSetup) {
h, err := crypto.EncryptPassword(adminSecret) h, err := crypto.EncryptPassword(adminSecret)
assert.NoError(t, err) assert.NoError(t, err)
req.Header.Set("Cookie", fmt.Sprintf("%s=%s", constants.AdminCookieName, h)) req.Header.Set("Cookie", fmt.Sprintf("%s=%s", constants.AdminCookieName, h))
for _, eventType := range s.TestInfo.TestWebhookEventTypes { for _, eventType := range s.TestInfo.TestWebhookEventTypes {
webhook, err := resolvers.AddWebhookResolver(ctx, model.AddWebhookRequest{ webhook, err := resolvers.AddWebhookResolver(ctx, model.AddWebhookRequest{
EventName: eventType, EventName: eventType,
@@ -35,5 +36,21 @@ func addWebhookTest(t *testing.T, s TestSetup) {
assert.NotNil(t, webhook) assert.NotNil(t, webhook)
assert.NotEmpty(t, webhook.Message) assert.NotEmpty(t, webhook.Message)
} }
time.Sleep(2 * time.Second)
// Allow setting multiple webhooks for same event
for _, eventType := range s.TestInfo.TestWebhookEventTypes {
webhook, err := resolvers.AddWebhookResolver(ctx, model.AddWebhookRequest{
EventName: eventType,
Endpoint: s.TestInfo.WebhookEndpoint,
Enabled: true,
EventDescription: refs.NewStringRef(eventType + "-2"),
Headers: map[string]interface{}{
"x-test": "foo",
},
})
assert.NoError(t, err)
assert.NotNil(t, webhook)
assert.NotEmpty(t, webhook.Message)
}
}) })
} }

View File

@@ -25,6 +25,6 @@ func adminSignupTests(t *testing.T, s TestSetup) {
_, err = resolvers.AdminSignupResolver(ctx, model.AdminSignupInput{ _, err = resolvers.AdminSignupResolver(ctx, model.AdminSignupInput{
AdminSecret: "admin123", AdminSecret: "admin123",
}) })
assert.Nil(t, err) assert.NoError(t, err)
}) })
} }

View File

@@ -25,7 +25,7 @@ func deleteWebhookTest(t *testing.T, s TestSetup) {
// get all webhooks // get all webhooks
webhooks, err := db.Provider.ListWebhook(ctx, model.Pagination{ webhooks, err := db.Provider.ListWebhook(ctx, model.Pagination{
Limit: 10, Limit: 20,
Page: 1, Page: 1,
Offset: 0, Offset: 0,
}) })
@@ -42,7 +42,7 @@ func deleteWebhookTest(t *testing.T, s TestSetup) {
} }
webhooks, err = db.Provider.ListWebhook(ctx, model.Pagination{ webhooks, err = db.Provider.ListWebhook(ctx, model.Pagination{
Limit: 10, Limit: 20,
Page: 1, Page: 1,
Offset: 0, Offset: 0,
}) })

View File

@@ -23,6 +23,8 @@ func enableAccessTest(t *testing.T, s TestSetup) {
}) })
assert.NoError(t, err) assert.NoError(t, err)
verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeMagicLinkLogin) verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeMagicLinkLogin)
assert.NoError(t, err)
assert.NotNil(t, verificationRequest)
verifyRes, err := resolvers.VerifyEmailResolver(ctx, model.VerifyEmailInput{ verifyRes, err := resolvers.VerifyEmailResolver(ctx, model.VerifyEmailInput{
Token: verificationRequest.Token, Token: verificationRequest.Token,
}) })

View File

@@ -15,17 +15,18 @@ func forgotPasswordTest(t *testing.T, s TestSetup) {
t.Run(`should run forgot password`, func(t *testing.T) { t.Run(`should run forgot password`, func(t *testing.T) {
_, ctx := createContext(s) _, ctx := createContext(s)
email := "forgot_password." + s.TestInfo.Email email := "forgot_password." + s.TestInfo.Email
_, err := resolvers.SignupResolver(ctx, model.SignUpInput{ res, err := resolvers.SignupResolver(ctx, model.SignUpInput{
Email: email, Email: email,
Password: s.TestInfo.Password, Password: s.TestInfo.Password,
ConfirmPassword: s.TestInfo.Password, ConfirmPassword: s.TestInfo.Password,
}) })
assert.NoError(t, err)
_, err = resolvers.ForgotPasswordResolver(ctx, model.ForgotPasswordInput{ assert.NotNil(t, res)
forgotPasswordRes, err := resolvers.ForgotPasswordResolver(ctx, model.ForgotPasswordInput{
Email: email, Email: email,
}) })
assert.Nil(t, err, "no errors for forgot password") assert.Nil(t, err, "no errors for forgot password")
assert.NotNil(t, forgotPasswordRes)
verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeForgotPassword) verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeForgotPassword)
assert.Nil(t, err) assert.Nil(t, err)

View File

@@ -41,21 +41,20 @@ func inviteUserTest(t *testing.T, s TestSetup) {
res, err = resolvers.InviteMembersResolver(ctx, model.InviteMemberInput{ res, err = resolvers.InviteMembersResolver(ctx, model.InviteMemberInput{
Emails: invalidEmailsTest, Emails: invalidEmailsTest,
}) })
assert.Error(t, err)
assert.Nil(t, res)
// valid test // valid test
res, err = resolvers.InviteMembersResolver(ctx, model.InviteMemberInput{ res, err = resolvers.InviteMembersResolver(ctx, model.InviteMemberInput{
Emails: emails, Emails: emails,
}) })
assert.Nil(t, err) assert.Nil(t, err)
assert.NotNil(t, res) assert.NotNil(t, res)
// duplicate error test // duplicate error test
res, err = resolvers.InviteMembersResolver(ctx, model.InviteMemberInput{ res, err = resolvers.InviteMembersResolver(ctx, model.InviteMemberInput{
Emails: emails, Emails: emails,
}) })
assert.Error(t, err) assert.Error(t, err)
assert.Nil(t, res) assert.Nil(t, res)
cleanData(emails[0]) cleanData(emails[0])
}) })
} }

View File

@@ -16,12 +16,13 @@ func loginTests(t *testing.T, s TestSetup) {
t.Run(`should login`, func(t *testing.T) { t.Run(`should login`, func(t *testing.T) {
_, ctx := createContext(s) _, ctx := createContext(s)
email := "login." + s.TestInfo.Email email := "login." + s.TestInfo.Email
_, err := resolvers.SignupResolver(ctx, model.SignUpInput{ signUpRes, err := resolvers.SignupResolver(ctx, model.SignUpInput{
Email: email, Email: email,
Password: s.TestInfo.Password, Password: s.TestInfo.Password,
ConfirmPassword: s.TestInfo.Password, ConfirmPassword: s.TestInfo.Password,
}) })
assert.NoError(t, err)
assert.NotNil(t, signUpRes)
res, err := resolvers.LoginResolver(ctx, model.LoginInput{ res, err := resolvers.LoginResolver(ctx, model.LoginInput{
Email: email, Email: email,
Password: s.TestInfo.Password, Password: s.TestInfo.Password,
@@ -30,6 +31,8 @@ func loginTests(t *testing.T, s TestSetup) {
assert.NotNil(t, err, "should fail because email is not verified") assert.NotNil(t, err, "should fail because email is not verified")
assert.Nil(t, res) assert.Nil(t, res)
verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeBasicAuthSignup) verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeBasicAuthSignup)
assert.NoError(t, err)
assert.NotNil(t, verificationRequest)
n, err := utils.EncryptNonce(verificationRequest.Nonce) n, err := utils.EncryptNonce(verificationRequest.Nonce)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotEmpty(t, n) assert.NotEmpty(t, n)

View File

@@ -20,22 +20,24 @@ func logoutTests(t *testing.T, s TestSetup) {
req, ctx := createContext(s) req, ctx := createContext(s)
email := "logout." + s.TestInfo.Email email := "logout." + s.TestInfo.Email
_, err := resolvers.MagicLinkLoginResolver(ctx, model.MagicLinkLoginInput{ magicLoginRes, err := resolvers.MagicLinkLoginResolver(ctx, model.MagicLinkLoginInput{
Email: email, Email: email,
}) })
assert.NoError(t, err)
assert.NotNil(t, magicLoginRes)
verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeMagicLinkLogin) verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeMagicLinkLogin)
assert.NoError(t, err)
assert.NotNil(t, verificationRequest)
verifyRes, err := resolvers.VerifyEmailResolver(ctx, model.VerifyEmailInput{ verifyRes, err := resolvers.VerifyEmailResolver(ctx, model.VerifyEmailInput{
Token: verificationRequest.Token, Token: verificationRequest.Token,
}) })
assert.NoError(t, err)
assert.NotNil(t, verifyRes)
accessToken := *verifyRes.AccessToken accessToken := *verifyRes.AccessToken
assert.NotEmpty(t, accessToken) assert.NotEmpty(t, accessToken)
claims, err := token.ParseJWTToken(accessToken) claims, err := token.ParseJWTToken(accessToken)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotEmpty(t, claims) assert.NotEmpty(t, claims)
loginMethod := claims["login_method"] loginMethod := claims["login_method"]
sessionKey := verifyRes.User.ID sessionKey := verifyRes.User.ID
if loginMethod != nil && loginMethod != "" { if loginMethod != nil && loginMethod != "" {

View File

@@ -30,6 +30,8 @@ func magicLinkLoginTests(t *testing.T, s TestSetup) {
assert.Nil(t, err, "signup should be successful") assert.Nil(t, err, "signup should be successful")
verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeMagicLinkLogin) verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeMagicLinkLogin)
assert.NoError(t, err)
assert.NotNil(t, verificationRequest)
verifyRes, err := resolvers.VerifyEmailResolver(ctx, model.VerifyEmailInput{ verifyRes, err := resolvers.VerifyEmailResolver(ctx, model.VerifyEmailInput{
Token: verificationRequest.Token, Token: verificationRequest.Token,
}) })

View File

@@ -29,24 +29,25 @@ func mobileSingupTest(t *testing.T, s TestSetup) {
Password: "test", Password: "test",
ConfirmPassword: "test", ConfirmPassword: "test",
}) })
assert.NotNil(t, err, "invalid password") assert.Error(t, err)
assert.Nil(t, res)
memorystore.Provider.UpdateEnvVariable(constants.EnvKeyDisableSignUp, true) memorystore.Provider.UpdateEnvVariable(constants.EnvKeyDisableSignUp, true)
res, err = resolvers.MobileSignupResolver(ctx, &model.MobileSignUpInput{ res, err = resolvers.MobileSignupResolver(ctx, &model.MobileSignUpInput{
Email: refs.NewStringRef(email), Email: refs.NewStringRef(email),
Password: s.TestInfo.Password, Password: s.TestInfo.Password,
ConfirmPassword: s.TestInfo.Password, ConfirmPassword: s.TestInfo.Password,
}) })
assert.NotNil(t, err, "singup disabled") assert.Error(t, err)
assert.Nil(t, res)
memorystore.Provider.UpdateEnvVariable(constants.EnvKeyDisableSignUp, false) memorystore.Provider.UpdateEnvVariable(constants.EnvKeyDisableSignUp, false)
memorystore.Provider.UpdateEnvVariable(constants.EnvKeyDisableMobileBasicAuthentication, true) memorystore.Provider.UpdateEnvVariable(constants.EnvKeyDisableMobileBasicAuthentication, true)
res, err = resolvers.MobileSignupResolver(ctx, &model.MobileSignUpInput{ res, err = resolvers.MobileSignupResolver(ctx, &model.MobileSignUpInput{
Email: refs.NewStringRef(email), Email: refs.NewStringRef(email),
Password: s.TestInfo.Password, Password: s.TestInfo.Password,
ConfirmPassword: s.TestInfo.Password, ConfirmPassword: s.TestInfo.Password,
}) })
assert.NotNil(t, err, "singup disabled") assert.Error(t, err)
assert.Nil(t, res)
memorystore.Provider.UpdateEnvVariable(constants.EnvKeyDisableMobileBasicAuthentication, false) memorystore.Provider.UpdateEnvVariable(constants.EnvKeyDisableMobileBasicAuthentication, false)
res, err = resolvers.MobileSignupResolver(ctx, &model.MobileSignUpInput{ res, err = resolvers.MobileSignupResolver(ctx, &model.MobileSignUpInput{
@@ -54,14 +55,16 @@ func mobileSingupTest(t *testing.T, s TestSetup) {
Password: s.TestInfo.Password, Password: s.TestInfo.Password,
ConfirmPassword: s.TestInfo.Password, ConfirmPassword: s.TestInfo.Password,
}) })
assert.NotNil(t, err, "invalid mobile") assert.Error(t, err)
assert.Nil(t, res)
res, err = resolvers.MobileSignupResolver(ctx, &model.MobileSignUpInput{ res, err = resolvers.MobileSignupResolver(ctx, &model.MobileSignUpInput{
PhoneNumber: "test", PhoneNumber: "test",
Password: s.TestInfo.Password, Password: s.TestInfo.Password,
ConfirmPassword: s.TestInfo.Password, ConfirmPassword: s.TestInfo.Password,
}) })
assert.NotNil(t, err, "invalid mobile") assert.Error(t, err)
assert.Nil(t, res)
res, err = resolvers.MobileSignupResolver(ctx, &model.MobileSignUpInput{ res, err = resolvers.MobileSignupResolver(ctx, &model.MobileSignUpInput{
PhoneNumber: "1234567890", PhoneNumber: "1234567890",
@@ -77,7 +80,8 @@ func mobileSingupTest(t *testing.T, s TestSetup) {
Password: s.TestInfo.Password, Password: s.TestInfo.Password,
ConfirmPassword: s.TestInfo.Password, ConfirmPassword: s.TestInfo.Password,
}) })
assert.Error(t, err, "user exists") assert.Error(t, err)
assert.Nil(t, res)
cleanData(email) cleanData(email)
cleanData("1234567890@authorizer.dev") cleanData("1234567890@authorizer.dev")

View File

@@ -27,6 +27,8 @@ func profileTests(t *testing.T, s TestSetup) {
assert.NotNil(t, err, "unauthorized") assert.NotNil(t, err, "unauthorized")
verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeBasicAuthSignup) verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeBasicAuthSignup)
assert.NoError(t, err)
assert.NotNil(t, verificationRequest)
verifyRes, err := resolvers.VerifyEmailResolver(ctx, model.VerifyEmailInput{ verifyRes, err := resolvers.VerifyEmailResolver(ctx, model.VerifyEmailInput{
Token: verificationRequest.Token, Token: verificationRequest.Token,
}) })

View File

@@ -44,10 +44,11 @@ func resendOTPTest(t *testing.T, s TestSetup) {
// Using access token update profile // Using access token update profile
s.GinContext.Request.Header.Set("Authorization", "Bearer "+refs.StringValue(verifyRes.AccessToken)) s.GinContext.Request.Header.Set("Authorization", "Bearer "+refs.StringValue(verifyRes.AccessToken))
ctx = context.WithValue(req.Context(), "GinContextKey", s.GinContext) ctx = context.WithValue(req.Context(), "GinContextKey", s.GinContext)
_, err = resolvers.UpdateProfileResolver(ctx, model.UpdateProfileInput{ updateRes, err := resolvers.UpdateProfileResolver(ctx, model.UpdateProfileInput{
IsMultiFactorAuthEnabled: refs.NewBoolRef(true), IsMultiFactorAuthEnabled: refs.NewBoolRef(true),
}) })
assert.NoError(t, err)
assert.NotNil(t, updateRes)
// Resend otp should return error as no initial opt is being sent // Resend otp should return error as no initial opt is being sent
resendOtpRes, err := resolvers.ResendOTPResolver(ctx, model.ResendOTPRequest{ resendOtpRes, err := resolvers.ResendOTPResolver(ctx, model.ResendOTPRequest{
Email: email, Email: email,
@@ -87,7 +88,7 @@ func resendOTPTest(t *testing.T, s TestSetup) {
Otp: otp.Otp, Otp: otp.Otp,
}) })
assert.Error(t, err) assert.Error(t, err)
assert.Nil(t, verifyOtpRes)
verifyOtpRes, err = resolvers.VerifyOtpResolver(ctx, model.VerifyOTPRequest{ verifyOtpRes, err = resolvers.VerifyOtpResolver(ctx, model.VerifyOTPRequest{
Email: email, Email: email,
Otp: newOtp.Otp, Otp: newOtp.Otp,

View File

@@ -19,13 +19,12 @@ func resendVerifyEmailTests(t *testing.T, s TestSetup) {
Password: s.TestInfo.Password, Password: s.TestInfo.Password,
ConfirmPassword: s.TestInfo.Password, ConfirmPassword: s.TestInfo.Password,
}) })
assert.NoError(t, err)
_, err = resolvers.ResendVerifyEmailResolver(ctx, model.ResendVerifyEmailInput{ _, err = resolvers.ResendVerifyEmailResolver(ctx, model.ResendVerifyEmailInput{
Email: email, Email: email,
Identifier: constants.VerificationTypeBasicAuthSignup, Identifier: constants.VerificationTypeBasicAuthSignup,
}) })
assert.NoError(t, err)
assert.Nil(t, err)
cleanData(email) cleanData(email)
}) })

View File

@@ -20,7 +20,7 @@ func resetPasswordTest(t *testing.T, s TestSetup) {
Password: s.TestInfo.Password, Password: s.TestInfo.Password,
ConfirmPassword: s.TestInfo.Password, ConfirmPassword: s.TestInfo.Password,
}) })
assert.NoError(t, err)
_, err = resolvers.ForgotPasswordResolver(ctx, model.ForgotPasswordInput{ _, err = resolvers.ForgotPasswordResolver(ctx, model.ForgotPasswordInput{
Email: email, Email: email,
}) })
@@ -28,7 +28,7 @@ func resetPasswordTest(t *testing.T, s TestSetup) {
verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeForgotPassword) verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeForgotPassword)
assert.Nil(t, err, "should get forgot password request") assert.Nil(t, err, "should get forgot password request")
assert.NotNil(t, verificationRequest)
_, err = resolvers.ResetPasswordResolver(ctx, model.ResetPasswordInput{ _, err = resolvers.ResetPasswordResolver(ctx, model.ResetPasswordInput{
Token: verificationRequest.Token, Token: verificationRequest.Token,
Password: "test1", Password: "test1",

View File

@@ -23,6 +23,8 @@ func revokeAccessTest(t *testing.T, s TestSetup) {
}) })
assert.NoError(t, err) assert.NoError(t, err)
verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeMagicLinkLogin) verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeMagicLinkLogin)
assert.NoError(t, err)
assert.NotNil(t, verificationRequest)
verifyRes, err := resolvers.VerifyEmailResolver(ctx, model.VerifyEmailInput{ verifyRes, err := resolvers.VerifyEmailResolver(ctx, model.VerifyEmailInput{
Token: verificationRequest.Token, Token: verificationRequest.Token,
}) })
@@ -33,7 +35,7 @@ func revokeAccessTest(t *testing.T, s TestSetup) {
UserID: verifyRes.User.ID, UserID: verifyRes.User.ID,
}) })
assert.Error(t, err) assert.Error(t, err)
assert.Nil(t, res)
adminSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret) adminSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret)
assert.Nil(t, err) assert.Nil(t, err)

View File

@@ -30,10 +30,13 @@ func sessionTests(t *testing.T, s TestSetup) {
assert.NotNil(t, err, "unauthorized") assert.NotNil(t, err, "unauthorized")
verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeBasicAuthSignup) verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeBasicAuthSignup)
assert.NoError(t, err)
assert.NotNil(t, verificationRequest)
verifyRes, err := resolvers.VerifyEmailResolver(ctx, model.VerifyEmailInput{ verifyRes, err := resolvers.VerifyEmailResolver(ctx, model.VerifyEmailInput{
Token: verificationRequest.Token, Token: verificationRequest.Token,
}) })
assert.NoError(t, err)
assert.NotNil(t, verifyRes)
accessToken := *verifyRes.AccessToken accessToken := *verifyRes.AccessToken
assert.NotEmpty(t, accessToken) assert.NotEmpty(t, accessToken)

View File

@@ -22,14 +22,14 @@ func signupTests(t *testing.T, s TestSetup) {
ConfirmPassword: s.TestInfo.Password + "s", ConfirmPassword: s.TestInfo.Password + "s",
}) })
assert.NotNil(t, err, "invalid password") assert.NotNil(t, err, "invalid password")
assert.Nil(t, res)
res, err = resolvers.SignupResolver(ctx, model.SignUpInput{ res, err = resolvers.SignupResolver(ctx, model.SignUpInput{
Email: email, Email: email,
Password: "test", Password: "test",
ConfirmPassword: "test", ConfirmPassword: "test",
}) })
assert.NotNil(t, err, "invalid password") assert.NotNil(t, err, "invalid password")
assert.Nil(t, res)
memorystore.Provider.UpdateEnvVariable(constants.EnvKeyDisableSignUp, true) memorystore.Provider.UpdateEnvVariable(constants.EnvKeyDisableSignUp, true)
res, err = resolvers.SignupResolver(ctx, model.SignUpInput{ res, err = resolvers.SignupResolver(ctx, model.SignUpInput{
Email: email, Email: email,
@@ -37,7 +37,7 @@ func signupTests(t *testing.T, s TestSetup) {
ConfirmPassword: s.TestInfo.Password, ConfirmPassword: s.TestInfo.Password,
}) })
assert.NotNil(t, err, "singup disabled") assert.NotNil(t, err, "singup disabled")
assert.Nil(t, res)
memorystore.Provider.UpdateEnvVariable(constants.EnvKeyDisableSignUp, false) memorystore.Provider.UpdateEnvVariable(constants.EnvKeyDisableSignUp, false)
res, err = resolvers.SignupResolver(ctx, model.SignUpInput{ res, err = resolvers.SignupResolver(ctx, model.SignUpInput{
Email: email, Email: email,
@@ -48,15 +48,13 @@ func signupTests(t *testing.T, s TestSetup) {
user := *res.User user := *res.User
assert.Equal(t, email, user.Email) assert.Equal(t, email, user.Email)
assert.Nil(t, res.AccessToken, "access token should be nil") assert.Nil(t, res.AccessToken, "access token should be nil")
res, err = resolvers.SignupResolver(ctx, model.SignUpInput{ res, err = resolvers.SignupResolver(ctx, model.SignUpInput{
Email: email, Email: email,
Password: s.TestInfo.Password, Password: s.TestInfo.Password,
ConfirmPassword: s.TestInfo.Password, ConfirmPassword: s.TestInfo.Password,
}) })
assert.NotNil(t, err, "should throw duplicate email error") assert.NotNil(t, err, "should throw duplicate email error")
assert.Nil(t, res)
verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeBasicAuthSignup) verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeBasicAuthSignup)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, email, verificationRequest.Email) assert.Equal(t, email, verificationRequest.Email)

View File

@@ -40,31 +40,49 @@ func cleanData(email string) {
verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeBasicAuthSignup) verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeBasicAuthSignup)
if err == nil { if err == nil {
err = db.Provider.DeleteVerificationRequest(ctx, verificationRequest) err = db.Provider.DeleteVerificationRequest(ctx, verificationRequest)
if err != nil {
log.Debug("DeleteVerificationRequest err", err)
}
} }
verificationRequest, err = db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeForgotPassword) verificationRequest, err = db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeForgotPassword)
if err == nil { if err == nil {
err = db.Provider.DeleteVerificationRequest(ctx, verificationRequest) err = db.Provider.DeleteVerificationRequest(ctx, verificationRequest)
if err != nil {
log.Debug("DeleteVerificationRequest err", err)
}
} }
verificationRequest, err = db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeUpdateEmail) verificationRequest, err = db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeUpdateEmail)
if err == nil { if err == nil {
err = db.Provider.DeleteVerificationRequest(ctx, verificationRequest) err = db.Provider.DeleteVerificationRequest(ctx, verificationRequest)
if err != nil {
log.Debug("DeleteVerificationRequest err", err)
}
} }
verificationRequest, err = db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeMagicLinkLogin) verificationRequest, err = db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeMagicLinkLogin)
if err == nil { if err == nil {
err = db.Provider.DeleteVerificationRequest(ctx, verificationRequest) err = db.Provider.DeleteVerificationRequest(ctx, verificationRequest)
if err != nil {
log.Debug("DeleteVerificationRequest err", err)
}
} }
otp, err := db.Provider.GetOTPByEmail(ctx, email) otp, err := db.Provider.GetOTPByEmail(ctx, email)
if err == nil { if err == nil {
err = db.Provider.DeleteOTP(ctx, otp) err = db.Provider.DeleteOTP(ctx, otp)
if err != nil {
log.Debug("DeleteOTP err", err)
}
} }
dbUser, err := db.Provider.GetUserByEmail(ctx, email) dbUser, err := db.Provider.GetUserByEmail(ctx, email)
if err == nil { if err == nil {
db.Provider.DeleteUser(ctx, dbUser) err = db.Provider.DeleteUser(ctx, dbUser)
if err != nil {
log.Debug("DeleteUser err", err)
}
} }
} }

View File

@@ -17,15 +17,12 @@ func updateAllUsersTest(t *testing.T, s TestSetup) {
t.Helper() t.Helper()
t.Run("Should update all users", func(t *testing.T) { t.Run("Should update all users", func(t *testing.T) {
_, ctx := createContext(s) _, ctx := createContext(s)
users := []models.User{}
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
user := models.User{ user := models.User{
Email: fmt.Sprintf("update_all_user_%d_%s", i, s.TestInfo.Email), Email: fmt.Sprintf("update_all_user_%d_%s", i, s.TestInfo.Email),
SignupMethods: constants.AuthRecipeMethodBasicAuth, SignupMethods: constants.AuthRecipeMethodBasicAuth,
Roles: "user", Roles: "user",
} }
users = append(users, user)
u, err := db.Provider.AddUser(ctx, user) u, err := db.Provider.AddUser(ctx, user)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, u) assert.NotNil(t, u)
@@ -56,12 +53,15 @@ func updateAllUsersTest(t *testing.T, s TestSetup) {
Limit: 20, Limit: 20,
Offset: 0, Offset: 0,
}) })
assert.NoError(t, err)
assert.NotNil(t, listUsers)
for _, u := range listUsers.Users { for _, u := range listUsers.Users {
if utils.StringSliceContains(updateIds, u.ID) { if utils.StringSliceContains(updateIds, u.ID) {
assert.False(t, refs.BoolValue(u.IsMultiFactorAuthEnabled)) assert.False(t, refs.BoolValue(u.IsMultiFactorAuthEnabled))
} else { } else {
assert.True(t, refs.BoolValue(u.IsMultiFactorAuthEnabled)) assert.True(t, refs.BoolValue(u.IsMultiFactorAuthEnabled))
} }
cleanData(u.Email)
} }
}) })
} }

View File

@@ -30,11 +30,13 @@ func updateProfileTests(t *testing.T, s TestSetup) {
assert.NotNil(t, err, "unauthorized") assert.NotNil(t, err, "unauthorized")
verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeBasicAuthSignup) verificationRequest, err := db.Provider.GetVerificationRequestByEmail(ctx, email, constants.VerificationTypeBasicAuthSignup)
assert.NoError(t, err)
assert.NotNil(t, verificationRequest)
verifyRes, err := resolvers.VerifyEmailResolver(ctx, model.VerifyEmailInput{ verifyRes, err := resolvers.VerifyEmailResolver(ctx, model.VerifyEmailInput{
Token: verificationRequest.Token, Token: verificationRequest.Token,
}) })
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, verifyRes)
s.GinContext.Request.Header.Set("Authorization", "Bearer "+*verifyRes.AccessToken) s.GinContext.Request.Header.Set("Authorization", "Bearer "+*verifyRes.AccessToken)
ctx = context.WithValue(req.Context(), "GinContextKey", s.GinContext) ctx = context.WithValue(req.Context(), "GinContextKey", s.GinContext)

View File

@@ -24,45 +24,73 @@ func updateWebhookTest(t *testing.T, s TestSetup) {
assert.NoError(t, err) assert.NoError(t, err)
req.Header.Set("Cookie", fmt.Sprintf("%s=%s", constants.AdminCookieName, h)) req.Header.Set("Cookie", fmt.Sprintf("%s=%s", constants.AdminCookieName, h))
// get webhook // get webhook
webhook, err := db.Provider.GetWebhookByEventName(ctx, constants.UserDeletedWebhookEvent) webhooks, err := db.Provider.GetWebhookByEventName(ctx, constants.UserDeletedWebhookEvent)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, webhook) assert.NotNil(t, webhooks)
// it should completely replace headers assert.Equal(t, 2, len(webhooks))
webhook.Headers = map[string]interface{}{ for _, webhook := range webhooks {
"x-new-test": "test", // it should completely replace headers
webhook.Headers = map[string]interface{}{
"x-new-test": "test",
}
res, err := resolvers.UpdateWebhookResolver(ctx, model.UpdateWebhookRequest{
ID: webhook.ID,
Headers: webhook.Headers,
Enabled: refs.NewBoolRef(false),
Endpoint: refs.NewStringRef("https://sometest.com"),
})
assert.NoError(t, err)
assert.NotEmpty(t, res)
assert.NotEmpty(t, res.Message)
} }
if len(webhooks) == 0 {
// avoid index out of range error
return
}
// Test updating webhook name
w := webhooks[0]
res, err := resolvers.UpdateWebhookResolver(ctx, model.UpdateWebhookRequest{ res, err := resolvers.UpdateWebhookResolver(ctx, model.UpdateWebhookRequest{
ID: webhook.ID, ID: w.ID,
Headers: webhook.Headers, EventName: refs.NewStringRef(constants.UserAccessEnabledWebhookEvent),
Enabled: refs.NewBoolRef(false),
Endpoint: refs.NewStringRef("https://sometest.com"),
}) })
assert.NoError(t, err) assert.NoError(t, err)
assert.NotEmpty(t, res) assert.NotNil(t, res)
assert.NotEmpty(t, res.Message) // Check if webhooks with new name is as per expected len
accessWebhooks, err := db.Provider.GetWebhookByEventName(ctx, constants.UserAccessEnabledWebhookEvent)
updatedWebhook, err := db.Provider.GetWebhookByEventName(ctx, constants.UserDeletedWebhookEvent)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, updatedWebhook) assert.Equal(t, 3, len(accessWebhooks))
assert.Equal(t, webhook.ID, updatedWebhook.ID) // Revert name change
assert.Equal(t, refs.StringValue(webhook.EventName), refs.StringValue(updatedWebhook.EventName))
assert.Len(t, updatedWebhook.Headers, 1)
assert.False(t, refs.BoolValue(updatedWebhook.Enabled))
for key, val := range updatedWebhook.Headers {
assert.Equal(t, val, webhook.Headers[key])
}
assert.Equal(t, refs.StringValue(updatedWebhook.Endpoint), "https://sometest.com")
res, err = resolvers.UpdateWebhookResolver(ctx, model.UpdateWebhookRequest{ res, err = resolvers.UpdateWebhookResolver(ctx, model.UpdateWebhookRequest{
ID: webhook.ID, ID: w.ID,
Headers: webhook.Headers, EventName: refs.NewStringRef(constants.UserDeletedWebhookEvent),
Enabled: refs.NewBoolRef(true),
Endpoint: refs.NewStringRef(s.TestInfo.WebhookEndpoint),
}) })
assert.NoError(t, err) assert.NoError(t, err)
assert.NotEmpty(t, res) assert.NotNil(t, res)
assert.NotEmpty(t, res.Message) updatedWebhooks, err := db.Provider.GetWebhookByEventName(ctx, constants.UserDeletedWebhookEvent)
assert.NoError(t, err)
assert.NotNil(t, updatedWebhooks)
assert.Equal(t, 2, len(updatedWebhooks))
for _, updatedWebhook := range updatedWebhooks {
assert.Contains(t, refs.StringValue(updatedWebhook.EventName), constants.UserDeletedWebhookEvent)
assert.Len(t, updatedWebhook.Headers, 1)
assert.False(t, refs.BoolValue(updatedWebhook.Enabled))
foundUpdatedHeader := false
for key, val := range updatedWebhook.Headers {
if key == "x-new-test" && val == "test" {
foundUpdatedHeader = true
}
}
assert.True(t, foundUpdatedHeader)
assert.Equal(t, "https://sometest.com", refs.StringValue(updatedWebhook.Endpoint))
res, err := resolvers.UpdateWebhookResolver(ctx, model.UpdateWebhookRequest{
ID: updatedWebhook.ID,
Headers: updatedWebhook.Headers,
Enabled: refs.NewBoolRef(true),
Endpoint: refs.NewStringRef(s.TestInfo.WebhookEndpoint),
})
assert.NoError(t, err)
assert.NotEmpty(t, res)
assert.NotEmpty(t, res.Message)
}
}) })
} }

View File

@@ -34,7 +34,7 @@ func usersTest(t *testing.T, s TestSetup) {
usersRes, err := resolvers.UsersResolver(ctx, pagination) usersRes, err := resolvers.UsersResolver(ctx, pagination)
assert.NotNil(t, err, "unauthorized") assert.NotNil(t, err, "unauthorized")
assert.Nil(t, usersRes)
adminSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret) adminSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret)
assert.Nil(t, err) assert.Nil(t, err)
h, err := crypto.EncryptPassword(adminSecret) h, err := crypto.EncryptPassword(adminSecret)

View File

@@ -53,11 +53,13 @@ func validateJwtTokenTest(t *testing.T, s TestSetup) {
sessionKey := constants.AuthRecipeMethodBasicAuth + ":" + user.ID sessionKey := constants.AuthRecipeMethodBasicAuth + ":" + user.ID
nonce := uuid.New().String() nonce := uuid.New().String()
authToken, err := token.CreateAuthToken(gc, user, roles, scope, constants.AuthRecipeMethodBasicAuth, nonce, "") authToken, err := token.CreateAuthToken(gc, user, roles, scope, constants.AuthRecipeMethodBasicAuth, nonce, "")
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash) assert.NoError(t, err)
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token) assert.NotNil(t, authToken)
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeSessionToken+"_"+authToken.FingerPrint, authToken.FingerPrintHash, authToken.SessionTokenExpiresAt)
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeAccessToken+"_"+authToken.FingerPrint, authToken.AccessToken.Token, authToken.AccessToken.ExpiresAt)
if authToken.RefreshToken != nil { if authToken.RefreshToken != nil {
memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token) memorystore.Provider.SetUserSession(sessionKey, constants.TokenTypeRefreshToken+"_"+authToken.FingerPrint, authToken.RefreshToken.Token, authToken.RefreshToken.ExpiresAt)
} }
t.Run(`should validate the access token`, func(t *testing.T) { t.Run(`should validate the access token`, func(t *testing.T) {
@@ -74,8 +76,8 @@ func validateJwtTokenTest(t *testing.T, s TestSetup) {
Token: authToken.AccessToken.Token, Token: authToken.AccessToken.Token,
Roles: []string{"invalid_role"}, Roles: []string{"invalid_role"},
}) })
assert.Error(t, err) assert.Error(t, err)
assert.Nil(t, res)
}) })
t.Run(`should validate the refresh token`, func(t *testing.T) { t.Run(`should validate the refresh token`, func(t *testing.T) {

View File

@@ -17,17 +17,14 @@ func verificationRequestsTest(t *testing.T, s TestSetup) {
t.Run(`should get verification requests with admin secret only`, func(t *testing.T) { t.Run(`should get verification requests with admin secret only`, func(t *testing.T) {
req, ctx := createContext(s) req, ctx := createContext(s)
email := "verification_requests." + s.TestInfo.Email email := "verification_requests." + s.TestInfo.Email
res, err := resolvers.SignupResolver(ctx, model.SignUpInput{ res, err := resolvers.SignupResolver(ctx, model.SignUpInput{
Email: email, Email: email,
Password: s.TestInfo.Password, Password: s.TestInfo.Password,
ConfirmPassword: s.TestInfo.Password, ConfirmPassword: s.TestInfo.Password,
}) })
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, res) assert.NotNil(t, res)
limit := int64(10) limit := int64(10)
page := int64(1) page := int64(1)
pagination := &model.PaginatedInput{ pagination := &model.PaginatedInput{
@@ -39,6 +36,7 @@ func verificationRequestsTest(t *testing.T, s TestSetup) {
requests, err := resolvers.VerificationRequestsResolver(ctx, pagination) requests, err := resolvers.VerificationRequestsResolver(ctx, pagination)
assert.NotNil(t, err, "unauthorized") assert.NotNil(t, err, "unauthorized")
assert.Nil(t, requests)
adminSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret) adminSecret, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyAdminSecret)
assert.Nil(t, err) assert.Nil(t, err)

View File

@@ -20,7 +20,8 @@ func verifyEmailTest(t *testing.T, s TestSetup) {
Password: s.TestInfo.Password, Password: s.TestInfo.Password,
ConfirmPassword: s.TestInfo.Password, ConfirmPassword: s.TestInfo.Password,
}) })
assert.NoError(t, err)
assert.NotNil(t, res)
user := *res.User user := *res.User
assert.Equal(t, email, user.Email) assert.Equal(t, email, user.Email)
assert.Nil(t, res.AccessToken, "access token should be nil") assert.Nil(t, res.AccessToken, "access token should be nil")

View File

@@ -29,17 +29,20 @@ func webhookLogsTest(t *testing.T, s TestSetup) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Greater(t, len(webhookLogs.WebhookLogs), 1) assert.Greater(t, len(webhookLogs.WebhookLogs), 1)
webhooks, err := resolvers.WebhooksResolver(ctx, nil) webhooks, err := resolvers.WebhooksResolver(ctx, &model.PaginatedInput{
Pagination: &model.PaginationInput{
Limit: refs.NewInt64Ref(20),
},
})
assert.NoError(t, err) assert.NoError(t, err)
assert.NotEmpty(t, webhooks) assert.NotEmpty(t, webhooks)
for _, w := range webhooks.Webhooks { for _, w := range webhooks.Webhooks {
t.Run(fmt.Sprintf("should get webhook for webhook_id:%s", w.ID), func(t *testing.T) { t.Run(fmt.Sprintf("should get webhook for webhook_id:%s", w.ID), func(t *testing.T) {
webhookLogs, err := resolvers.WebhookLogsResolver(ctx, &model.ListWebhookLogRequest{ webhookLogs, err := resolvers.WebhookLogsResolver(ctx, &model.ListWebhookLogRequest{
WebhookID: &w.ID, WebhookID: &w.ID,
}) })
assert.NoError(t, err) assert.NoError(t, err)
assert.GreaterOrEqual(t, len(webhookLogs.WebhookLogs), 1) assert.GreaterOrEqual(t, len(webhookLogs.WebhookLogs), 1, refs.StringValue(w.EventName))
for _, wl := range webhookLogs.WebhookLogs { for _, wl := range webhookLogs.WebhookLogs {
assert.Equal(t, refs.StringValue(wl.WebhookID), w.ID) assert.Equal(t, refs.StringValue(wl.WebhookID), w.ID)
} }

View File

@@ -25,18 +25,20 @@ func webhookTest(t *testing.T, s TestSetup) {
req.Header.Set("Cookie", fmt.Sprintf("%s=%s", constants.AdminCookieName, h)) req.Header.Set("Cookie", fmt.Sprintf("%s=%s", constants.AdminCookieName, h))
// get webhook by event name // get webhook by event name
webhook, err := db.Provider.GetWebhookByEventName(ctx, constants.UserCreatedWebhookEvent) webhooks, err := db.Provider.GetWebhookByEventName(ctx, constants.UserCreatedWebhookEvent)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, webhook) assert.NotNil(t, webhooks)
assert.Equal(t, 2, len(webhooks))
res, err := resolvers.WebhookResolver(ctx, model.WebhookRequest{ for _, webhook := range webhooks {
ID: webhook.ID, res, err := resolvers.WebhookResolver(ctx, model.WebhookRequest{
}) ID: webhook.ID,
assert.NoError(t, err) })
assert.Equal(t, res.ID, webhook.ID) assert.NoError(t, err)
assert.Equal(t, refs.StringValue(res.Endpoint), refs.StringValue(webhook.Endpoint)) assert.Equal(t, res.ID, webhook.ID)
assert.Equal(t, refs.StringValue(res.EventName), refs.StringValue(webhook.EventName)) assert.Equal(t, refs.StringValue(res.Endpoint), refs.StringValue(webhook.Endpoint))
assert.Equal(t, refs.BoolValue(res.Enabled), refs.BoolValue(webhook.Enabled)) // assert.Equal(t, refs.StringValue(res.EventName), refs.StringValue(webhook.EventName))
assert.Len(t, res.Headers, len(webhook.Headers)) assert.Equal(t, refs.BoolValue(res.Enabled), refs.BoolValue(webhook.Enabled))
assert.Len(t, res.Headers, len(webhook.Headers))
}
}) })
} }

View File

@@ -6,7 +6,9 @@ import (
"github.com/authorizerdev/authorizer/server/constants" "github.com/authorizerdev/authorizer/server/constants"
"github.com/authorizerdev/authorizer/server/crypto" "github.com/authorizerdev/authorizer/server/crypto"
"github.com/authorizerdev/authorizer/server/graph/model"
"github.com/authorizerdev/authorizer/server/memorystore" "github.com/authorizerdev/authorizer/server/memorystore"
"github.com/authorizerdev/authorizer/server/refs"
"github.com/authorizerdev/authorizer/server/resolvers" "github.com/authorizerdev/authorizer/server/resolvers"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@@ -21,9 +23,13 @@ func webhooksTest(t *testing.T, s TestSetup) {
assert.NoError(t, err) assert.NoError(t, err)
req.Header.Set("Cookie", fmt.Sprintf("%s=%s", constants.AdminCookieName, h)) req.Header.Set("Cookie", fmt.Sprintf("%s=%s", constants.AdminCookieName, h))
webhooks, err := resolvers.WebhooksResolver(ctx, nil) webhooks, err := resolvers.WebhooksResolver(ctx, &model.PaginatedInput{
Pagination: &model.PaginationInput{
Limit: refs.NewInt64Ref(20),
},
})
assert.NoError(t, err) assert.NoError(t, err)
assert.NotEmpty(t, webhooks) assert.NotEmpty(t, webhooks)
assert.Len(t, webhooks.Webhooks, len(s.TestInfo.TestWebhookEventTypes)) assert.Len(t, webhooks.Webhooks, len(s.TestInfo.TestWebhookEventTypes)*2)
}) })
} }

View File

@@ -30,11 +30,13 @@ type JWTToken struct {
// Token object to hold the finger print and refresh token information // Token object to hold the finger print and refresh token information
type Token struct { type Token struct {
FingerPrint string `json:"fingerprint"` FingerPrint string `json:"fingerprint"`
FingerPrintHash string `json:"fingerprint_hash"` // Session Token
RefreshToken *JWTToken `json:"refresh_token"` FingerPrintHash string `json:"fingerprint_hash"`
AccessToken *JWTToken `json:"access_token"` SessionTokenExpiresAt int64 `json:"expires_at"`
IDToken *JWTToken `json:"id_token"` RefreshToken *JWTToken `json:"refresh_token"`
AccessToken *JWTToken `json:"access_token"`
IDToken *JWTToken `json:"id_token"`
} }
// SessionData // SessionData
@@ -51,7 +53,7 @@ type SessionData struct {
// CreateAuthToken creates a new auth token when userlogs in // CreateAuthToken creates a new auth token when userlogs in
func CreateAuthToken(gc *gin.Context, user models.User, roles, scope []string, loginMethod, nonce string, code string) (*Token, error) { func CreateAuthToken(gc *gin.Context, user models.User, roles, scope []string, loginMethod, nonce string, code string) (*Token, error) {
hostname := parsers.GetHost(gc) hostname := parsers.GetHost(gc)
_, fingerPrintHash, err := CreateSessionToken(user, nonce, roles, scope, loginMethod) _, fingerPrintHash, sessionTokenExpiresAt, err := CreateSessionToken(user, nonce, roles, scope, loginMethod)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -82,10 +84,11 @@ func CreateAuthToken(gc *gin.Context, user models.User, roles, scope []string, l
} }
res := &Token{ res := &Token{
FingerPrint: nonce, FingerPrint: nonce,
FingerPrintHash: fingerPrintHash, FingerPrintHash: fingerPrintHash,
AccessToken: &JWTToken{Token: accessToken, ExpiresAt: accessTokenExpiresAt}, SessionTokenExpiresAt: sessionTokenExpiresAt,
IDToken: &JWTToken{Token: idToken, ExpiresAt: idTokenExpiresAt}, AccessToken: &JWTToken{Token: accessToken, ExpiresAt: accessTokenExpiresAt},
IDToken: &JWTToken{Token: idToken, ExpiresAt: idTokenExpiresAt},
} }
if utils.StringSliceContains(scope, "offline_access") { if utils.StringSliceContains(scope, "offline_access") {
@@ -101,7 +104,8 @@ func CreateAuthToken(gc *gin.Context, user models.User, roles, scope []string, l
} }
// CreateSessionToken creates a new session token // CreateSessionToken creates a new session token
func CreateSessionToken(user models.User, nonce string, roles, scope []string, loginMethod string) (*SessionData, string, error) { func CreateSessionToken(user models.User, nonce string, roles, scope []string, loginMethod string) (*SessionData, string, int64, error) {
expiresAt := time.Now().AddDate(1, 0, 0).Unix()
fingerPrintMap := &SessionData{ fingerPrintMap := &SessionData{
Nonce: nonce, Nonce: nonce,
Roles: roles, Roles: roles,
@@ -109,15 +113,15 @@ func CreateSessionToken(user models.User, nonce string, roles, scope []string, l
Scope: scope, Scope: scope,
LoginMethod: loginMethod, LoginMethod: loginMethod,
IssuedAt: time.Now().Unix(), IssuedAt: time.Now().Unix(),
ExpiresAt: time.Now().AddDate(1, 0, 0).Unix(), ExpiresAt: expiresAt,
} }
fingerPrintBytes, _ := json.Marshal(fingerPrintMap) fingerPrintBytes, _ := json.Marshal(fingerPrintMap)
fingerPrintHash, err := crypto.EncryptAES(string(fingerPrintBytes)) fingerPrintHash, err := crypto.EncryptAES(string(fingerPrintBytes))
if err != nil { if err != nil {
return nil, "", err return nil, "", 0, err
} }
return fingerPrintMap, fingerPrintHash, nil return fingerPrintMap, fingerPrintHash, expiresAt, nil
} }
// CreateRefreshToken util to create JWT token // CreateRefreshToken util to create JWT token

View File

@@ -17,98 +17,101 @@ import (
) )
func RegisterEvent(ctx context.Context, eventName string, authRecipe string, user models.User) error { func RegisterEvent(ctx context.Context, eventName string, authRecipe string, user models.User) error {
webhook, err := db.Provider.GetWebhookByEventName(ctx, eventName) webhooks, err := db.Provider.GetWebhookByEventName(ctx, eventName)
if err != nil { if err != nil {
log.Debug("Error getting webhook: %v", err)
return err return err
} }
for _, webhook := range webhooks {
if !refs.BoolValue(webhook.Enabled) {
continue
}
userBytes, err := json.Marshal(user.AsAPIUser())
if err != nil {
log.Debug("error marshalling user obj: ", err)
continue
}
userMap := map[string]interface{}{}
err = json.Unmarshal(userBytes, &userMap)
if err != nil {
log.Debug("error un-marshalling user obj: ", err)
continue
}
if !refs.BoolValue(webhook.Enabled) { reqBody := map[string]interface{}{
return nil "webhook_id": webhook.ID,
} "event_name": eventName,
"event_description": webhook.EventDescription,
"user": userMap,
}
userBytes, err := json.Marshal(user.AsAPIUser()) if eventName == constants.UserLoginWebhookEvent || eventName == constants.UserSignUpWebhookEvent {
if err != nil { reqBody["auth_recipe"] = authRecipe
log.Debug("error marshalling user obj: ", err) }
return err
}
userMap := map[string]interface{}{}
err = json.Unmarshal(userBytes, &userMap)
if err != nil {
log.Debug("error un-marshalling user obj: ", err)
return err
}
reqBody := map[string]interface{}{ requestBody, err := json.Marshal(reqBody)
"event_name": eventName, if err != nil {
"user": userMap, log.Debug("error marshalling requestBody obj: ", err)
} continue
}
if eventName == constants.UserLoginWebhookEvent || eventName == constants.UserSignUpWebhookEvent { // dont trigger webhook call in case of test
reqBody["auth_recipe"] = authRecipe envKey, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyEnv)
} if err != nil {
continue
}
if envKey == constants.TestEnv {
_, err := db.Provider.AddWebhookLog(ctx, models.WebhookLog{
HttpStatus: 200,
Request: string(requestBody),
Response: string(`{"message": "test"}`),
WebhookID: webhook.ID,
})
if err != nil {
log.Debug("error saving webhook log:", err)
}
continue
}
requestBody, err := json.Marshal(reqBody) requestBytesBuffer := bytes.NewBuffer(requestBody)
if err != nil { req, err := http.NewRequest("POST", refs.StringValue(webhook.Endpoint), requestBytesBuffer)
log.Debug("error marshalling requestBody obj: ", err) if err != nil {
return err log.Debug("error creating webhook post request: ", err)
} continue
}
req.Header.Set("Content-Type", "application/json")
// dont trigger webhook call in case of test if webhook.Headers != nil {
envKey, err := memorystore.Provider.GetStringStoreEnvVariable(constants.EnvKeyEnv) for key, val := range webhook.Headers {
if err != nil { req.Header.Set(key, val.(string))
return err }
} }
if envKey == constants.TestEnv {
db.Provider.AddWebhookLog(ctx, models.WebhookLog{ client := &http.Client{Timeout: time.Second * 30}
HttpStatus: 200, resp, err := client.Do(req)
if err != nil {
log.Debug("error making request: ", err)
continue
}
defer resp.Body.Close()
responseBytes, err := ioutil.ReadAll(resp.Body)
if err != nil {
log.Debug("error reading response: ", err)
continue
}
statusCode := int64(resp.StatusCode)
_, err = db.Provider.AddWebhookLog(ctx, models.WebhookLog{
HttpStatus: statusCode,
Request: string(requestBody), Request: string(requestBody),
Response: string(`{"message": "test"}`), Response: string(responseBytes),
WebhookID: webhook.ID, WebhookID: webhook.ID,
}) })
if err != nil {
return nil log.Debug("failed to add webhook log: ", err)
} continue
requestBytesBuffer := bytes.NewBuffer(requestBody)
req, err := http.NewRequest("POST", refs.StringValue(webhook.Endpoint), requestBytesBuffer)
if err != nil {
log.Debug("error creating webhook post request: ", err)
return err
}
req.Header.Set("Content-Type", "application/json")
if webhook.Headers != nil {
for key, val := range webhook.Headers {
req.Header.Set(key, val.(string))
} }
} }
client := &http.Client{Timeout: time.Second * 30}
resp, err := client.Do(req)
if err != nil {
log.Debug("error making request: ", err)
return err
}
defer resp.Body.Close()
responseBytes, err := ioutil.ReadAll(resp.Body)
if err != nil {
log.Debug("error reading response: ", err)
return err
}
statusCode := int64(resp.StatusCode)
_, err = db.Provider.AddWebhookLog(ctx, models.WebhookLog{
HttpStatus: statusCode,
Request: string(requestBody),
Response: string(responseBytes),
WebhookID: webhook.ID,
})
if err != nil {
log.Debug("failed to add webhook log: ", err)
return err
}
return nil return nil
} }