first commit
This commit is contained in:
334
pkg/auth/store.go
Normal file
334
pkg/auth/store.go
Normal file
@ -0,0 +1,334 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/google/uuid"
|
||||
jsonIterator "github.com/json-iterator/go"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/tidwall/buntdb"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
jsonMarshal = jsonIterator.Marshal
|
||||
jsonUnmarshal = jsonIterator.Unmarshal
|
||||
)
|
||||
|
||||
type TokenStore interface {
|
||||
Create(info TokenInfo) error
|
||||
RemoveByAccess(access string) error
|
||||
RemoveByRefresh(refresh string) error
|
||||
GetByAccess(access string) (TokenInfo, error)
|
||||
GetByRefresh(refresh string) (TokenInfo, error)
|
||||
}
|
||||
|
||||
// NewMemoryTokenStore create a token buntStore instance based on memory
|
||||
func NewMemoryTokenStore() (TokenStore, error) {
|
||||
return NewFileTokenStore(":memory:")
|
||||
}
|
||||
|
||||
// NewFileTokenStore create a token buntStore instance based on file
|
||||
func NewFileTokenStore(filename string) (TokenStore, error) {
|
||||
db, err := buntdb.Open(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &buntStore{db: db}, nil
|
||||
}
|
||||
|
||||
// buntStore token storage based on buntdb(https://github.com/tidwall/buntdb)
|
||||
type buntStore struct {
|
||||
db *buntdb.DB
|
||||
}
|
||||
|
||||
func (ts *buntStore) remove(key string) error {
|
||||
err := ts.db.Update(func(tx *buntdb.Tx) error {
|
||||
_, err := tx.Delete(key)
|
||||
return err
|
||||
})
|
||||
if errors.Is(err, buntdb.ErrNotFound) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (ts *buntStore) getData(key string) (TokenInfo, error) {
|
||||
var ti TokenInfo
|
||||
err := ts.db.View(func(tx *buntdb.Tx) error {
|
||||
jv, err := tx.Get(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var tm Token
|
||||
err = jsonUnmarshal([]byte(jv), &tm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ti = &tm
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
if err == buntdb.ErrNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return ti, nil
|
||||
}
|
||||
|
||||
func (ts *buntStore) getBasicID(key string) (string, error) {
|
||||
var basicID string
|
||||
err := ts.db.View(func(tx *buntdb.Tx) error {
|
||||
v, err := tx.Get(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
basicID = v
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
if err == buntdb.ErrNotFound {
|
||||
return "", nil
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
return basicID, nil
|
||||
}
|
||||
|
||||
// Create and buntStore the new token information
|
||||
func (ts *buntStore) Create(info TokenInfo) error {
|
||||
ct := time.Now()
|
||||
jv, err := jsonMarshal(info)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return ts.db.Update(func(tx *buntdb.Tx) error {
|
||||
basicID := uuid.Must(uuid.NewRandom()).String()
|
||||
aexp := info.GetAccessExpiresIn()
|
||||
rexp := aexp
|
||||
expires := true
|
||||
if refresh := info.GetRefresh(); refresh != "" {
|
||||
rexp = info.GetRefreshCreateAt().Add(info.GetRefreshExpiresIn()).Sub(ct)
|
||||
if aexp.Seconds() > rexp.Seconds() {
|
||||
aexp = rexp
|
||||
}
|
||||
expires = info.GetRefreshExpiresIn() != 0
|
||||
_, _, err = tx.Set(refresh, basicID, &buntdb.SetOptions{Expires: expires, TTL: rexp})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
_, _, err = tx.Set(basicID, string(jv), &buntdb.SetOptions{Expires: expires, TTL: rexp})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, _, err = tx.Set(info.GetAccess(), basicID, &buntdb.SetOptions{Expires: expires, TTL: aexp})
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
// RemoveByAccess use the access token to delete the token information
|
||||
func (ts *buntStore) RemoveByAccess(access string) error {
|
||||
return ts.remove(access)
|
||||
}
|
||||
|
||||
// RemoveByRefresh use the refresh token to delete the token information
|
||||
func (ts *buntStore) RemoveByRefresh(refresh string) error {
|
||||
return ts.remove(refresh)
|
||||
}
|
||||
|
||||
// GetByAccess use the access token for token information data
|
||||
func (ts *buntStore) GetByAccess(access string) (TokenInfo, error) {
|
||||
basicID, err := ts.getBasicID(access)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ts.getData(basicID)
|
||||
}
|
||||
|
||||
// GetByRefresh use the refresh token for token information data
|
||||
func (ts *buntStore) GetByRefresh(refresh string) (TokenInfo, error) {
|
||||
basicID, err := ts.getBasicID(refresh)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ts.getData(basicID)
|
||||
}
|
||||
|
||||
/*------------------------------------------------------------------------------------*/
|
||||
|
||||
// NewRedisStoreWithCli create an instance of a redis store
|
||||
func NewRedisStoreWithCli(cli *redis.Client, keyNamespace string) TokenStore {
|
||||
store := &redisStore{
|
||||
cli: cli,
|
||||
ctx: context.TODO(),
|
||||
ns: keyNamespace,
|
||||
}
|
||||
return store
|
||||
}
|
||||
|
||||
// TokenStore redis token store
|
||||
type redisStore struct {
|
||||
cli *redis.Client
|
||||
ctx context.Context
|
||||
ns string
|
||||
}
|
||||
|
||||
func (s *redisStore) wrapperKey(key string) string {
|
||||
return fmt.Sprintf("%s%s", s.ns, key)
|
||||
}
|
||||
|
||||
func (s *redisStore) checkError(result redis.Cmder) (bool, error) {
|
||||
if err := result.Err(); err != nil {
|
||||
if err == redis.Nil {
|
||||
return true, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (s *redisStore) remove(key string) error {
|
||||
result := s.cli.Del(s.ctx, s.wrapperKey(key))
|
||||
_, err := s.checkError(result)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *redisStore) removeToken(tokenString string, isRefresh bool) error {
|
||||
basicID, err := s.getBasicID(tokenString)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if basicID == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = s.remove(tokenString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
token, err := s.getToken(basicID)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if token == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
checkToken := token.GetRefresh()
|
||||
if isRefresh {
|
||||
checkToken = token.GetAccess()
|
||||
}
|
||||
result := s.cli.Exists(s.ctx, s.wrapperKey(checkToken))
|
||||
if err = result.Err(); err != nil && err != redis.Nil {
|
||||
return err
|
||||
} else if result.Val() == 0 {
|
||||
return s.remove(basicID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *redisStore) parseToken(result *redis.StringCmd) (TokenInfo, error) {
|
||||
if ok, err := s.checkError(result); err != nil {
|
||||
return nil, err
|
||||
} else if ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
buf, err := result.Bytes()
|
||||
if err != nil {
|
||||
if err == redis.Nil {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var token Token
|
||||
if err = jsonUnmarshal(buf, &token); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
func (s *redisStore) getToken(key string) (TokenInfo, error) {
|
||||
result := s.cli.Get(s.ctx, s.wrapperKey(key))
|
||||
return s.parseToken(result)
|
||||
}
|
||||
|
||||
func (s *redisStore) parseBasicID(result *redis.StringCmd) (string, error) {
|
||||
if ok, err := s.checkError(result); err != nil {
|
||||
return "", err
|
||||
} else if ok {
|
||||
return "", nil
|
||||
}
|
||||
return result.Val(), nil
|
||||
}
|
||||
|
||||
func (s *redisStore) getBasicID(token string) (string, error) {
|
||||
result := s.cli.Get(s.ctx, s.wrapperKey(token))
|
||||
return s.parseBasicID(result)
|
||||
}
|
||||
|
||||
// Create and store the new token information
|
||||
func (s *redisStore) Create(info TokenInfo) error {
|
||||
ct := time.Now()
|
||||
jv, err := jsonMarshal(info)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pipe := s.cli.TxPipeline()
|
||||
basicID := uuid.Must(uuid.NewRandom()).String()
|
||||
aexp := info.GetAccessExpiresIn()
|
||||
rexp := aexp
|
||||
|
||||
if refresh := info.GetRefresh(); refresh != "" {
|
||||
rexp = info.GetRefreshCreateAt().Add(info.GetRefreshExpiresIn()).Sub(ct)
|
||||
if aexp.Seconds() > rexp.Seconds() {
|
||||
aexp = rexp
|
||||
}
|
||||
pipe.Set(s.ctx, s.wrapperKey(refresh), basicID, rexp)
|
||||
}
|
||||
|
||||
pipe.Set(s.ctx, s.wrapperKey(info.GetAccess()), basicID, aexp)
|
||||
pipe.Set(s.ctx, s.wrapperKey(basicID), jv, rexp)
|
||||
|
||||
if _, err = pipe.Exec(s.ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveByAccess Use the access token to delete the token information
|
||||
func (s *redisStore) RemoveByAccess(access string) error {
|
||||
return s.removeToken(access, false)
|
||||
}
|
||||
|
||||
// RemoveByRefresh Use the refresh token to delete the token information
|
||||
func (s *redisStore) RemoveByRefresh(refresh string) error {
|
||||
return s.removeToken(refresh, true)
|
||||
}
|
||||
|
||||
// GetByAccess Use the access token for token information data
|
||||
func (s *redisStore) GetByAccess(access string) (TokenInfo, error) {
|
||||
basicID, err := s.getBasicID(access)
|
||||
if err != nil || basicID == "" {
|
||||
return nil, err
|
||||
}
|
||||
return s.getToken(basicID)
|
||||
}
|
||||
|
||||
// GetByRefresh Use the refresh token for token information data
|
||||
func (s *redisStore) GetByRefresh(refresh string) (TokenInfo, error) {
|
||||
basicID, err := s.getBasicID(refresh)
|
||||
if err != nil || basicID == "" {
|
||||
return nil, err
|
||||
}
|
||||
return s.getToken(basicID)
|
||||
}
|
Reference in New Issue
Block a user