fix: session storage
This commit is contained in:
14
server/memorystore/providers/inmemory/provider_test.go
Normal file
14
server/memorystore/providers/inmemory/provider_test.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package inmemory
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/authorizerdev/authorizer/server/memorystore/providers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestInMemoryProvider(t *testing.T) {
|
||||
p, err := NewInMemoryProvider()
|
||||
assert.NoError(t, err)
|
||||
providers.ProviderTests(t, p)
|
||||
}
|
@@ -8,39 +8,31 @@ import (
|
||||
)
|
||||
|
||||
// SetUserSession sets the user session for given user identifier in form recipe:user_id
|
||||
func (c *provider) SetUserSession(userId, key, token string) error {
|
||||
c.sessionStore.Set(userId, key, token)
|
||||
func (c *provider) SetUserSession(userId, key, token string, expiration int64) error {
|
||||
c.sessionStore.Set(userId, key, token, expiration)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserSession returns value for given session token
|
||||
func (c *provider) GetUserSession(userId, sessionToken string) (string, error) {
|
||||
return c.sessionStore.Get(userId, sessionToken), nil
|
||||
val := c.sessionStore.Get(userId, sessionToken)
|
||||
if val == "" {
|
||||
return "", fmt.Errorf("Not found")
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// DeleteAllUserSessions deletes all the user sessions from in-memory store.
|
||||
func (c *provider) DeleteAllUserSessions(userId string) error {
|
||||
namespaces := []string{
|
||||
constants.AuthRecipeMethodBasicAuth,
|
||||
constants.AuthRecipeMethodMagicLinkLogin,
|
||||
constants.AuthRecipeMethodApple,
|
||||
constants.AuthRecipeMethodFacebook,
|
||||
constants.AuthRecipeMethodGithub,
|
||||
constants.AuthRecipeMethodGoogle,
|
||||
constants.AuthRecipeMethodLinkedIn,
|
||||
constants.AuthRecipeMethodTwitter,
|
||||
constants.AuthRecipeMethodMicrosoft,
|
||||
}
|
||||
|
||||
for _, namespace := range namespaces {
|
||||
c.sessionStore.RemoveAll(namespace + ":" + userId)
|
||||
}
|
||||
c.sessionStore.RemoveAll(userId)
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteUserSession deletes the user session from the in-memory store.
|
||||
func (c *provider) DeleteUserSession(userId, sessionToken string) error {
|
||||
c.sessionStore.Remove(userId, sessionToken)
|
||||
c.sessionStore.Remove(userId, constants.TokenTypeSessionToken+"_"+sessionToken)
|
||||
c.sessionStore.Remove(userId, constants.TokenTypeAccessToken+"_"+sessionToken)
|
||||
c.sessionStore.Remove(userId, constants.TokenTypeRefreshToken+"_"+sessionToken)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@@ -1,8 +1,15 @@
|
||||
package stores
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// Maximum entries to keep in session storage
|
||||
maxCacheSize = 1000
|
||||
)
|
||||
|
||||
// SessionEntry is the struct for entry stored in store
|
||||
@@ -13,15 +20,16 @@ type SessionEntry struct {
|
||||
|
||||
// SessionStore struct to store the env variables
|
||||
type SessionStore struct {
|
||||
mutex sync.Mutex
|
||||
store map[string]map[string]*SessionEntry
|
||||
mutex sync.Mutex
|
||||
store map[string]*SessionEntry
|
||||
itemsToEvict []string
|
||||
}
|
||||
|
||||
// NewSessionStore create a new session store
|
||||
func NewSessionStore() *SessionStore {
|
||||
return &SessionStore{
|
||||
mutex: sync.Mutex{},
|
||||
store: make(map[string]map[string]*SessionEntry),
|
||||
store: make(map[string]*SessionEntry),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,53 +37,59 @@ func NewSessionStore() *SessionStore {
|
||||
func (s *SessionStore) Get(key, subKey string) string {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
return s.store[key][subKey].Value
|
||||
currentTime := time.Now().Unix()
|
||||
k := fmt.Sprintf("%s:%s", key, subKey)
|
||||
if v, ok := s.store[k]; ok {
|
||||
if v.ExpiresAt > currentTime {
|
||||
return v.Value
|
||||
}
|
||||
s.itemsToEvict = append(s.itemsToEvict, k)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Set sets the value of the key in state store
|
||||
func (s *SessionStore) Set(key string, subKey, value string) {
|
||||
func (s *SessionStore) Set(key string, subKey, value string, expiration int64) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if _, ok := s.store[key]; !ok {
|
||||
s.store[key] = make(map[string]string)
|
||||
k := fmt.Sprintf("%s:%s", key, subKey)
|
||||
if _, ok := s.store[k]; !ok {
|
||||
s.store[k] = &SessionEntry{
|
||||
Value: value,
|
||||
ExpiresAt: expiration,
|
||||
// TODO add expire time
|
||||
}
|
||||
}
|
||||
s.store[k] = &SessionEntry{
|
||||
Value: value,
|
||||
ExpiresAt: expiration,
|
||||
// TODO add expire time
|
||||
}
|
||||
s.store[key][subKey] = value
|
||||
}
|
||||
|
||||
// RemoveAll all values for given key
|
||||
func (s *SessionStore) RemoveAll(key string) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
delete(s.store, key)
|
||||
for k := range s.store {
|
||||
if strings.Contains(k, key) {
|
||||
delete(s.store, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove value for given key and subkey
|
||||
func (s *SessionStore) Remove(key, subKey string) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
if _, ok := s.store[key]; ok {
|
||||
delete(s.store[key], subKey)
|
||||
}
|
||||
}
|
||||
|
||||
// Get all the values for given key
|
||||
func (s *SessionStore) GetAll(key string) map[string]string {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if _, ok := s.store[key]; !ok {
|
||||
s.store[key] = make(map[string]string)
|
||||
}
|
||||
return s.store[key]
|
||||
k := fmt.Sprintf("%s:%s", key, subKey)
|
||||
delete(s.store, k)
|
||||
}
|
||||
|
||||
// RemoveByNamespace to delete session for a given namespace example google,github
|
||||
func (s *SessionStore) RemoveByNamespace(namespace string) error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
for key := range s.store {
|
||||
if strings.Contains(key, namespace+":") {
|
||||
delete(s.store, key)
|
||||
|
115
server/memorystore/providers/provider_tests.go
Normal file
115
server/memorystore/providers/provider_tests.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// ProviderTests runs all provider tests
|
||||
func ProviderTests(t *testing.T, p Provider) {
|
||||
|
||||
err := p.SetUserSession("auth_provider:123", "session_token_key", "test_hash123", time.Now().Add(60*time.Second).Unix())
|
||||
assert.NoError(t, err)
|
||||
err = p.SetUserSession("auth_provider:123", "access_token_key", "test_jwt123", time.Now().Add(60*time.Second).Unix())
|
||||
assert.NoError(t, err)
|
||||
// Same user multiple session
|
||||
err = p.SetUserSession("auth_provider:123", "session_token_key1", "test_hash1123", time.Now().Add(60*time.Second).Unix())
|
||||
assert.NoError(t, err)
|
||||
err = p.SetUserSession("auth_provider:123", "access_token_key1", "test_jwt1123", time.Now().Add(60*time.Second).Unix())
|
||||
assert.NoError(t, err)
|
||||
// Different user session
|
||||
err = p.SetUserSession("auth_provider:124", "session_token_key", "test_hash124", time.Now().Add(5*time.Second).Unix())
|
||||
assert.NoError(t, err)
|
||||
err = p.SetUserSession("auth_provider:124", "access_token_key", "test_jwt124", time.Now().Add(5*time.Second).Unix())
|
||||
assert.NoError(t, err)
|
||||
// Different provider session
|
||||
err = p.SetUserSession("auth_provider1:124", "session_token_key", "test_hash124", time.Now().Add(60*time.Second).Unix())
|
||||
assert.NoError(t, err)
|
||||
err = p.SetUserSession("auth_provider1:124", "access_token_key", "test_jwt124", time.Now().Add(60*time.Second).Unix())
|
||||
assert.NoError(t, err)
|
||||
// Different provider session
|
||||
err = p.SetUserSession("auth_provider1:123", "session_token_key", "test_hash1123", time.Now().Add(60*time.Second).Unix())
|
||||
assert.NoError(t, err)
|
||||
err = p.SetUserSession("auth_provider1:123", "access_token_key", "test_jwt1123", time.Now().Add(60*time.Second).Unix())
|
||||
assert.NoError(t, err)
|
||||
// Get session
|
||||
key, err := p.GetUserSession("auth_provider:123", "session_token_key")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "test_hash123", key)
|
||||
key, err = p.GetUserSession("auth_provider:123", "access_token_key")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "test_jwt123", key)
|
||||
key, err = p.GetUserSession("auth_provider:124", "session_token_key")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "test_hash124", key)
|
||||
key, err = p.GetUserSession("auth_provider:124", "access_token_key")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "test_jwt124", key)
|
||||
// Expire some tokens and make sure they are empty
|
||||
time.Sleep(5 * time.Second)
|
||||
key, err = p.GetUserSession("auth_provider:124", "session_token_key")
|
||||
assert.Empty(t, key)
|
||||
assert.Error(t, err)
|
||||
key, err = p.GetUserSession("auth_provider:124", "access_token_key")
|
||||
assert.Empty(t, key)
|
||||
assert.Error(t, err)
|
||||
// Delete user session
|
||||
err = p.DeleteUserSession("auth_provider:123", "key")
|
||||
assert.NoError(t, err)
|
||||
err = p.DeleteUserSession("auth_provider:123", "key")
|
||||
assert.NoError(t, err)
|
||||
key, err = p.GetUserSession("auth_provider:123", "key")
|
||||
assert.Empty(t, key)
|
||||
assert.Error(t, err)
|
||||
key, err = p.GetUserSession("auth_provider:123", "access_token_key")
|
||||
assert.Empty(t, key)
|
||||
assert.Error(t, err)
|
||||
// Delete all user session
|
||||
err = p.DeleteAllUserSessions("123")
|
||||
assert.NoError(t, err)
|
||||
err = p.DeleteAllUserSessions("123")
|
||||
assert.NoError(t, err)
|
||||
key, err = p.GetUserSession("auth_provider:123", "session_token_key1")
|
||||
assert.Empty(t, key)
|
||||
assert.Error(t, err)
|
||||
key, err = p.GetUserSession("auth_provider:123", "access_token_key1")
|
||||
assert.Empty(t, key)
|
||||
assert.Error(t, err)
|
||||
key, err = p.GetUserSession("auth_provider1:123", "session_token_key")
|
||||
assert.Empty(t, key)
|
||||
assert.Error(t, err)
|
||||
key, err = p.GetUserSession("auth_provider1:123", "access_token_key")
|
||||
assert.Empty(t, key)
|
||||
assert.Error(t, err)
|
||||
// Delete namespace
|
||||
err = p.DeleteSessionForNamespace("auth_provider")
|
||||
assert.NoError(t, err)
|
||||
err = p.DeleteSessionForNamespace("auth_provider1")
|
||||
assert.NoError(t, err)
|
||||
key, err = p.GetUserSession("auth_provider:123", "session_token_key1")
|
||||
assert.Empty(t, key)
|
||||
assert.Error(t, err)
|
||||
key, err = p.GetUserSession("auth_provider:123", "access_token_key1")
|
||||
assert.Empty(t, key)
|
||||
assert.Error(t, err)
|
||||
key, err = p.GetUserSession("auth_provider1:123", "session_token_key")
|
||||
assert.Empty(t, key)
|
||||
assert.Error(t, err)
|
||||
key, err = p.GetUserSession("auth_provider1:123", "access_token_key")
|
||||
assert.Empty(t, key)
|
||||
assert.Error(t, err)
|
||||
key, err = p.GetUserSession("auth_provider:124", "session_token_key1")
|
||||
assert.Empty(t, key)
|
||||
assert.Error(t, err)
|
||||
key, err = p.GetUserSession("auth_provider:124", "access_token_key1")
|
||||
assert.Empty(t, key)
|
||||
assert.Error(t, err)
|
||||
key, err = p.GetUserSession("auth_provider1:124", "session_token_key")
|
||||
assert.Empty(t, key)
|
||||
assert.Error(t, err)
|
||||
key, err = p.GetUserSession("auth_provider1:124", "access_token_key")
|
||||
assert.Empty(t, key)
|
||||
assert.Error(t, err)
|
||||
}
|
@@ -3,7 +3,7 @@ package providers
|
||||
// Provider defines current memory store provider
|
||||
type Provider interface {
|
||||
// SetUserSession sets the user session for given user identifier in form recipe:user_id
|
||||
SetUserSession(userId, key, token string) error
|
||||
SetUserSession(userId, key, token string, expiration int64) error
|
||||
// GetUserSession returns the session token for given token
|
||||
GetUserSession(userId, key string) (string, error)
|
||||
// DeleteUserSession deletes the user session
|
||||
|
@@ -32,7 +32,6 @@ type provider struct {
|
||||
// NewRedisProvider returns a new redis provider
|
||||
func NewRedisProvider(redisURL string) (*provider, error) {
|
||||
redisURLHostPortsList := strings.Split(redisURL, ",")
|
||||
|
||||
if len(redisURLHostPortsList) > 1 {
|
||||
opt, err := redis.ParseURL(redisURLHostPortsList[0])
|
||||
if err != nil {
|
||||
|
15
server/memorystore/providers/redis/provider_test.go
Normal file
15
server/memorystore/providers/redis/provider_test.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/authorizerdev/authorizer/server/memorystore/providers"
|
||||
)
|
||||
|
||||
func TestRedisProvider(t *testing.T) {
|
||||
p, err := NewRedisProvider("redis://127.0.0.1:6379")
|
||||
assert.NoError(t, err)
|
||||
providers.ProviderTests(t, p)
|
||||
}
|
@@ -3,6 +3,7 @@ package redis
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/authorizerdev/authorizer/server/constants"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -16,8 +17,11 @@ var (
|
||||
)
|
||||
|
||||
// SetUserSession sets the user session for given user identifier in form recipe:user_id
|
||||
func (c *provider) SetUserSession(userId, key, token string) error {
|
||||
err := c.store.Set(c.ctx, fmt.Sprintf("%s:%s", userId, key), token, 0).Err()
|
||||
func (c *provider) SetUserSession(userId, key, token string, expiration int64) error {
|
||||
currentTime := time.Now()
|
||||
expireTime := time.Unix(expiration, 0)
|
||||
duration := expireTime.Sub(currentTime)
|
||||
err := c.store.Set(c.ctx, fmt.Sprintf("%s:%s", userId, key), token, duration).Err()
|
||||
if err != nil {
|
||||
log.Debug("Error saving user session to redis: ", err)
|
||||
return err
|
||||
@@ -38,37 +42,35 @@ func (c *provider) GetUserSession(userId, key string) (string, error) {
|
||||
func (c *provider) DeleteUserSession(userId, key string) error {
|
||||
if err := c.store.Del(c.ctx, fmt.Sprintf("%s:%s", userId, constants.TokenTypeSessionToken+"_"+key)).Err(); err != nil {
|
||||
log.Debug("Error deleting user session from redis: ", err)
|
||||
return err
|
||||
fmt.Println("Error deleting user session from redis: ", err, userId, constants.TokenTypeSessionToken, key)
|
||||
// continue
|
||||
}
|
||||
if err := c.store.Del(c.ctx, fmt.Sprintf("%s:%s", userId, constants.TokenTypeAccessToken+"_"+key)).Err(); err != nil {
|
||||
log.Debug("Error deleting user session from redis: ", err)
|
||||
return err
|
||||
fmt.Println("Error deleting user session from redis: ", err, userId, constants.TokenTypeAccessToken, key)
|
||||
// continue
|
||||
}
|
||||
if err := c.store.Del(c.ctx, fmt.Sprintf("%s:%s", userId, constants.TokenTypeRefreshToken+"_"+key)).Err(); err != nil {
|
||||
log.Debug("Error deleting user session from redis: ", err)
|
||||
return err
|
||||
fmt.Println("Error deleting user session from redis: ", err, userId, constants.TokenTypeRefreshToken, key)
|
||||
// continue
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteAllUserSessions deletes all the user session from redis
|
||||
func (c *provider) DeleteAllUserSessions(userID string) error {
|
||||
namespaces := []string{
|
||||
constants.AuthRecipeMethodBasicAuth,
|
||||
constants.AuthRecipeMethodMagicLinkLogin,
|
||||
constants.AuthRecipeMethodApple,
|
||||
constants.AuthRecipeMethodFacebook,
|
||||
constants.AuthRecipeMethodGithub,
|
||||
constants.AuthRecipeMethodGoogle,
|
||||
constants.AuthRecipeMethodLinkedIn,
|
||||
constants.AuthRecipeMethodTwitter,
|
||||
constants.AuthRecipeMethodMicrosoft,
|
||||
res := c.store.Keys(c.ctx, fmt.Sprintf("*%s*", userID))
|
||||
if res.Err() != nil {
|
||||
log.Debug("Error getting all user sessions from redis: ", res.Err())
|
||||
return res.Err()
|
||||
}
|
||||
for _, namespace := range namespaces {
|
||||
err := c.store.Del(c.ctx, namespace+":"+userID).Err()
|
||||
keys := res.Val()
|
||||
for _, key := range keys {
|
||||
err := c.store.Del(c.ctx, key).Err()
|
||||
if err != nil {
|
||||
log.Debug("Error deleting all user sessions from redis: ", err)
|
||||
return err
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -76,27 +78,19 @@ func (c *provider) DeleteAllUserSessions(userID string) error {
|
||||
|
||||
// DeleteSessionForNamespace to delete session for a given namespace example google,github
|
||||
func (c *provider) DeleteSessionForNamespace(namespace string) error {
|
||||
var cursor uint64
|
||||
for {
|
||||
keys := []string{}
|
||||
keys, cursor, err := c.store.Scan(c.ctx, cursor, namespace+":*", 0).Result()
|
||||
res := c.store.Keys(c.ctx, fmt.Sprintf("%s:*", namespace))
|
||||
if res.Err() != nil {
|
||||
log.Debug("Error getting all user sessions from redis: ", res.Err())
|
||||
return res.Err()
|
||||
}
|
||||
keys := res.Val()
|
||||
for _, key := range keys {
|
||||
err := c.store.Del(c.ctx, key).Err()
|
||||
if err != nil {
|
||||
log.Debugf("Error scanning keys for %s namespace: %s", namespace, err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
for _, key := range keys {
|
||||
err := c.store.Del(c.ctx, key).Err()
|
||||
if err != nil {
|
||||
log.Debugf("Error deleting sessions for %s namespace: %s", namespace, err.Error())
|
||||
return err
|
||||
}
|
||||
}
|
||||
if cursor == 0 { // no more keys
|
||||
break
|
||||
log.Debug("Error deleting all user sessions from redis: ", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user