335 lines
7.6 KiB
Go
335 lines
7.6 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"github.com/google/uuid"
|
|
jsonIterator "github.com/json-iterator/go"
|
|
"github.com/redis/go-redis/v9"
|
|
"github.com/tidwall/buntdb"
|
|
"time"
|
|
)
|
|
|
|
var (
|
|
jsonMarshal = jsonIterator.Marshal
|
|
jsonUnmarshal = jsonIterator.Unmarshal
|
|
)
|
|
|
|
type TokenStore interface {
|
|
Create(info TokenInfo) error
|
|
RemoveByAccess(access string) error
|
|
RemoveByRefresh(refresh string) error
|
|
GetByAccess(access string) (TokenInfo, error)
|
|
GetByRefresh(refresh string) (TokenInfo, error)
|
|
}
|
|
|
|
// NewMemoryTokenStore create a token buntStore instance based on memory
|
|
func NewMemoryTokenStore() (TokenStore, error) {
|
|
return NewFileTokenStore(":memory:")
|
|
}
|
|
|
|
// NewFileTokenStore create a token buntStore instance based on file
|
|
func NewFileTokenStore(filename string) (TokenStore, error) {
|
|
db, err := buntdb.Open(filename)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &buntStore{db: db}, nil
|
|
}
|
|
|
|
// buntStore token storage based on buntdb(https://github.com/tidwall/buntdb)
|
|
type buntStore struct {
|
|
db *buntdb.DB
|
|
}
|
|
|
|
func (ts *buntStore) remove(key string) error {
|
|
err := ts.db.Update(func(tx *buntdb.Tx) error {
|
|
_, err := tx.Delete(key)
|
|
return err
|
|
})
|
|
if errors.Is(err, buntdb.ErrNotFound) {
|
|
return nil
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (ts *buntStore) getData(key string) (TokenInfo, error) {
|
|
var ti TokenInfo
|
|
err := ts.db.View(func(tx *buntdb.Tx) error {
|
|
jv, err := tx.Get(key)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var tm Token
|
|
err = jsonUnmarshal([]byte(jv), &tm)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
ti = &tm
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
if err == buntdb.ErrNotFound {
|
|
return nil, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
return ti, nil
|
|
}
|
|
|
|
func (ts *buntStore) getBasicID(key string) (string, error) {
|
|
var basicID string
|
|
err := ts.db.View(func(tx *buntdb.Tx) error {
|
|
v, err := tx.Get(key)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
basicID = v
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
if err == buntdb.ErrNotFound {
|
|
return "", nil
|
|
}
|
|
return "", err
|
|
}
|
|
return basicID, nil
|
|
}
|
|
|
|
// Create and buntStore the new token information
|
|
func (ts *buntStore) Create(info TokenInfo) error {
|
|
ct := time.Now()
|
|
jv, err := jsonMarshal(info)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return ts.db.Update(func(tx *buntdb.Tx) error {
|
|
basicID := uuid.Must(uuid.NewRandom()).String()
|
|
aexp := info.GetAccessExpiresIn()
|
|
rexp := aexp
|
|
expires := true
|
|
if refresh := info.GetRefresh(); refresh != "" {
|
|
rexp = info.GetRefreshCreateAt().Add(info.GetRefreshExpiresIn()).Sub(ct)
|
|
if aexp.Seconds() > rexp.Seconds() {
|
|
aexp = rexp
|
|
}
|
|
expires = info.GetRefreshExpiresIn() != 0
|
|
_, _, err = tx.Set(refresh, basicID, &buntdb.SetOptions{Expires: expires, TTL: rexp})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
_, _, err = tx.Set(basicID, string(jv), &buntdb.SetOptions{Expires: expires, TTL: rexp})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, _, err = tx.Set(info.GetAccess(), basicID, &buntdb.SetOptions{Expires: expires, TTL: aexp})
|
|
return err
|
|
})
|
|
}
|
|
|
|
// RemoveByAccess use the access token to delete the token information
|
|
func (ts *buntStore) RemoveByAccess(access string) error {
|
|
return ts.remove(access)
|
|
}
|
|
|
|
// RemoveByRefresh use the refresh token to delete the token information
|
|
func (ts *buntStore) RemoveByRefresh(refresh string) error {
|
|
return ts.remove(refresh)
|
|
}
|
|
|
|
// GetByAccess use the access token for token information data
|
|
func (ts *buntStore) GetByAccess(access string) (TokenInfo, error) {
|
|
basicID, err := ts.getBasicID(access)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return ts.getData(basicID)
|
|
}
|
|
|
|
// GetByRefresh use the refresh token for token information data
|
|
func (ts *buntStore) GetByRefresh(refresh string) (TokenInfo, error) {
|
|
basicID, err := ts.getBasicID(refresh)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return ts.getData(basicID)
|
|
}
|
|
|
|
/*------------------------------------------------------------------------------------*/
|
|
|
|
// NewRedisStoreWithCli create an instance of a redis store
|
|
func NewRedisStoreWithCli(cli *redis.Client, keyNamespace string) TokenStore {
|
|
store := &redisStore{
|
|
cli: cli,
|
|
ctx: context.TODO(),
|
|
ns: keyNamespace,
|
|
}
|
|
return store
|
|
}
|
|
|
|
// TokenStore redis token store
|
|
type redisStore struct {
|
|
cli *redis.Client
|
|
ctx context.Context
|
|
ns string
|
|
}
|
|
|
|
func (s *redisStore) wrapperKey(key string) string {
|
|
return fmt.Sprintf("%s%s", s.ns, key)
|
|
}
|
|
|
|
func (s *redisStore) checkError(result redis.Cmder) (bool, error) {
|
|
if err := result.Err(); err != nil {
|
|
if err == redis.Nil {
|
|
return true, nil
|
|
}
|
|
return false, err
|
|
}
|
|
return false, nil
|
|
}
|
|
|
|
func (s *redisStore) remove(key string) error {
|
|
result := s.cli.Del(s.ctx, s.wrapperKey(key))
|
|
_, err := s.checkError(result)
|
|
return err
|
|
}
|
|
|
|
func (s *redisStore) removeToken(tokenString string, isRefresh bool) error {
|
|
basicID, err := s.getBasicID(tokenString)
|
|
if err != nil {
|
|
return err
|
|
} else if basicID == "" {
|
|
return nil
|
|
}
|
|
|
|
err = s.remove(tokenString)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
token, err := s.getToken(basicID)
|
|
if err != nil {
|
|
return err
|
|
} else if token == nil {
|
|
return nil
|
|
}
|
|
|
|
checkToken := token.GetRefresh()
|
|
if isRefresh {
|
|
checkToken = token.GetAccess()
|
|
}
|
|
result := s.cli.Exists(s.ctx, s.wrapperKey(checkToken))
|
|
if err = result.Err(); err != nil && err != redis.Nil {
|
|
return err
|
|
} else if result.Val() == 0 {
|
|
return s.remove(basicID)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *redisStore) parseToken(result *redis.StringCmd) (TokenInfo, error) {
|
|
if ok, err := s.checkError(result); err != nil {
|
|
return nil, err
|
|
} else if ok {
|
|
return nil, nil
|
|
}
|
|
|
|
buf, err := result.Bytes()
|
|
if err != nil {
|
|
if err == redis.Nil {
|
|
return nil, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
var token Token
|
|
if err = jsonUnmarshal(buf, &token); err != nil {
|
|
return nil, err
|
|
}
|
|
return &token, nil
|
|
}
|
|
|
|
func (s *redisStore) getToken(key string) (TokenInfo, error) {
|
|
result := s.cli.Get(s.ctx, s.wrapperKey(key))
|
|
return s.parseToken(result)
|
|
}
|
|
|
|
func (s *redisStore) parseBasicID(result *redis.StringCmd) (string, error) {
|
|
if ok, err := s.checkError(result); err != nil {
|
|
return "", err
|
|
} else if ok {
|
|
return "", nil
|
|
}
|
|
return result.Val(), nil
|
|
}
|
|
|
|
func (s *redisStore) getBasicID(token string) (string, error) {
|
|
result := s.cli.Get(s.ctx, s.wrapperKey(token))
|
|
return s.parseBasicID(result)
|
|
}
|
|
|
|
// Create and store the new token information
|
|
func (s *redisStore) Create(info TokenInfo) error {
|
|
ct := time.Now()
|
|
jv, err := jsonMarshal(info)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
pipe := s.cli.TxPipeline()
|
|
basicID := uuid.Must(uuid.NewRandom()).String()
|
|
aexp := info.GetAccessExpiresIn()
|
|
rexp := aexp
|
|
|
|
if refresh := info.GetRefresh(); refresh != "" {
|
|
rexp = info.GetRefreshCreateAt().Add(info.GetRefreshExpiresIn()).Sub(ct)
|
|
if aexp.Seconds() > rexp.Seconds() {
|
|
aexp = rexp
|
|
}
|
|
pipe.Set(s.ctx, s.wrapperKey(refresh), basicID, rexp)
|
|
}
|
|
|
|
pipe.Set(s.ctx, s.wrapperKey(info.GetAccess()), basicID, aexp)
|
|
pipe.Set(s.ctx, s.wrapperKey(basicID), jv, rexp)
|
|
|
|
if _, err = pipe.Exec(s.ctx); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// RemoveByAccess Use the access token to delete the token information
|
|
func (s *redisStore) RemoveByAccess(access string) error {
|
|
return s.removeToken(access, false)
|
|
}
|
|
|
|
// RemoveByRefresh Use the refresh token to delete the token information
|
|
func (s *redisStore) RemoveByRefresh(refresh string) error {
|
|
return s.removeToken(refresh, true)
|
|
}
|
|
|
|
// GetByAccess Use the access token for token information data
|
|
func (s *redisStore) GetByAccess(access string) (TokenInfo, error) {
|
|
basicID, err := s.getBasicID(access)
|
|
if err != nil || basicID == "" {
|
|
return nil, err
|
|
}
|
|
return s.getToken(basicID)
|
|
}
|
|
|
|
// GetByRefresh Use the refresh token for token information data
|
|
func (s *redisStore) GetByRefresh(refresh string) (TokenInfo, error) {
|
|
basicID, err := s.getBasicID(refresh)
|
|
if err != nil || basicID == "" {
|
|
return nil, err
|
|
}
|
|
return s.getToken(basicID)
|
|
}
|