base-golang/pkg/auth/store.go
2024-07-23 10:23:43 +08:00

335 lines
7.6 KiB
Go

package auth
import (
"context"
"errors"
"fmt"
"github.com/google/uuid"
jsonIterator "github.com/json-iterator/go"
"github.com/redis/go-redis/v9"
"github.com/tidwall/buntdb"
"time"
)
// JSON codec used to (de)serialize token records; bound to json-iterator's
// standard-library-compatible functions.
var (
	// jsonMarshal encodes a TokenInfo before it is written to a store.
	jsonMarshal = jsonIterator.Marshal
	// jsonUnmarshal decodes a stored token record back into a Token.
	jsonUnmarshal = jsonIterator.Unmarshal
)
// TokenStore persists issued tokens and looks them up again by either their
// access token or their refresh token. The implementations in this package
// return (nil, nil) from the Get methods when nothing is stored.
type TokenStore interface {
	// Create stores a newly issued token.
	Create(info TokenInfo) error
	// GetByAccess returns the token stored for the given access token.
	GetByAccess(access string) (TokenInfo, error)
	// GetByRefresh returns the token stored for the given refresh token.
	GetByRefresh(refresh string) (TokenInfo, error)
	// RemoveByAccess deletes the entry addressed by the given access token.
	RemoveByAccess(access string) error
	// RemoveByRefresh deletes the entry addressed by the given refresh token.
	RemoveByRefresh(refresh string) error
}
// NewMemoryTokenStore returns a buntdb-backed TokenStore opened on buntdb's
// special ":memory:" path, i.e. an in-memory database.
func NewMemoryTokenStore() (TokenStore, error) {
	const inMemoryPath = ":memory:"
	return NewFileTokenStore(inMemoryPath)
}
// NewFileTokenStore opens the buntdb database at filename and returns a
// TokenStore backed by it. Any error from buntdb.Open is returned unchanged.
func NewFileTokenStore(filename string) (TokenStore, error) {
	db, err := buntdb.Open(filename)
	if err == nil {
		return &buntStore{db: db}, nil
	}
	return nil, err
}
// buntStore is a TokenStore backed by buntdb (https://github.com/tidwall/buntdb).
//
// Layout: the JSON-encoded token lives under a random "basic ID" key, and the
// access-token and refresh-token keys hold that basic ID as their value.
type buntStore struct {
	db *buntdb.DB // underlying database handle
}
// remove deletes key in its own write transaction. A key that does not exist
// (buntdb.ErrNotFound) is not reported as an error.
func (ts *buntStore) remove(key string) error {
	txErr := ts.db.Update(func(tx *buntdb.Tx) error {
		_, delErr := tx.Delete(key)
		return delErr
	})
	if txErr != nil && !errors.Is(txErr, buntdb.ErrNotFound) {
		return txErr
	}
	return nil
}
// getData loads the JSON token record stored under key and decodes it into a
// Token.
//
// It returns (nil, nil) when buntdb reports the key as not found. Any other
// read or decode error is returned as-is.
func (ts *buntStore) getData(key string) (TokenInfo, error) {
	var ti TokenInfo
	err := ts.db.View(func(tx *buntdb.Tx) error {
		jv, err := tx.Get(key)
		if err != nil {
			return err
		}
		var tm Token
		if err := jsonUnmarshal([]byte(jv), &tm); err != nil {
			return err
		}
		ti = &tm
		return nil
	})
	// errors.Is matches how remove() detects the same sentinel.
	if errors.Is(err, buntdb.ErrNotFound) {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	return ti, nil
}
// getBasicID returns the basic ID stored as the value of an access-token or
// refresh-token key. It returns ("", nil) when buntdb reports the key as not
// found.
func (ts *buntStore) getBasicID(key string) (string, error) {
	var basicID string
	err := ts.db.View(func(tx *buntdb.Tx) error {
		v, err := tx.Get(key)
		if err != nil {
			return err
		}
		basicID = v
		return nil
	})
	// errors.Is matches how remove() detects the same sentinel.
	if errors.Is(err, buntdb.ErrNotFound) {
		return "", nil
	}
	if err != nil {
		return "", err
	}
	return basicID, nil
}
// Create stores info and indexes it by its access token and, when present, by
// its refresh token. All writes happen in a single buntdb update transaction.
//
// Layout: the JSON-encoded token is stored under a freshly generated basic ID,
// and the access/refresh keys hold that ID.
//
// Lifetimes:
//   - With a refresh token, the record and refresh key live until
//     RefreshCreateAt+RefreshExpiresIn. The access key lives for
//     AccessExpiresIn, capped at that lifetime. A zero RefreshExpiresIn
//     disables expiry for every key.
//   - Without a refresh token, every key uses AccessExpiresIn. A zero value
//     disables expiry.
func (ts *buntStore) Create(info TokenInfo) error {
	ct := time.Now()
	jv, err := jsonMarshal(info)
	if err != nil {
		return err
	}
	// Generate the ID up front and return a crypto/rand failure instead of
	// panicking (uuid.Must) inside library code.
	id, err := uuid.NewRandom()
	if err != nil {
		return err
	}
	basicID := id.String()

	return ts.db.Update(func(tx *buntdb.Tx) error {
		aexp := info.GetAccessExpiresIn()
		rexp := aexp
		// buntdb treats Expires:true with TTL 0 as already expired. A zero
		// duration here means "no expiry" (as the refresh branch below treats
		// it), so only request expiry when there is a real TTL.
		expires := aexp != 0
		refresh := info.GetRefresh()
		if refresh != "" {
			rexp = info.GetRefreshCreateAt().Add(info.GetRefreshExpiresIn()).Sub(ct)
			// The access key must not outlive the record it points to. An
			// access token with no own expiry inherits the refresh lifetime
			// instead of getting an instantly-expiring TTL of 0.
			if aexp == 0 || aexp > rexp {
				aexp = rexp
			}
			expires = info.GetRefreshExpiresIn() != 0
		}

		set := func(key, value string, ttl time.Duration) error {
			_, _, err := tx.Set(key, value, &buntdb.SetOptions{Expires: expires, TTL: ttl})
			return err
		}
		if refresh != "" {
			if err := set(refresh, basicID, rexp); err != nil {
				return err
			}
		}
		if err := set(basicID, string(jv), rexp); err != nil {
			return err
		}
		return set(info.GetAccess(), basicID, aexp)
	})
}
// RemoveByAccess deletes the access-token key. The token record and any
// refresh-token key are not touched by this call.
func (ts *buntStore) RemoveByAccess(access string) error {
	err := ts.remove(access)
	return err
}
// RemoveByRefresh deletes the refresh-token key. The token record and the
// access-token key are not touched by this call.
func (ts *buntStore) RemoveByRefresh(refresh string) error {
	err := ts.remove(refresh)
	return err
}
// GetByAccess resolves the access token to its basic ID and returns the
// decoded token record. It returns (nil, nil) when either lookup finds
// nothing.
func (ts *buntStore) GetByAccess(access string) (TokenInfo, error) {
	basicID, err := ts.getBasicID(access)
	// An unknown access token yields an empty ID. Stop here rather than
	// looking up the empty key, which may hold an unrelated non-JSON value.
	if err != nil || basicID == "" {
		return nil, err
	}
	return ts.getData(basicID)
}
// GetByRefresh resolves the refresh token to its basic ID and returns the
// decoded token record. It returns (nil, nil) when either lookup finds
// nothing.
func (ts *buntStore) GetByRefresh(refresh string) (TokenInfo, error) {
	basicID, err := ts.getBasicID(refresh)
	// An unknown refresh token yields an empty ID. Stop here rather than
	// looking up the empty key, which may hold an unrelated non-JSON value.
	if err != nil || basicID == "" {
		return nil, err
	}
	return ts.getData(basicID)
}
/*------------------------------------------------------------------------------------*/
// NewRedisStoreWithCli returns a TokenStore that keeps tokens in Redis through
// cli. Every key it writes is prefixed with keyNamespace. Commands are issued
// with context.TODO(), because the TokenStore methods take no context.
func NewRedisStoreWithCli(cli *redis.Client, keyNamespace string) TokenStore {
	return &redisStore{
		ns:  keyNamespace,
		cli: cli,
		ctx: context.TODO(),
	}
}
// redisStore is a TokenStore backed by Redis.
//
// Layout: the JSON-encoded token lives under a random basic-ID key, and the
// access-token and refresh-token keys hold that basic ID. Every key is
// prefixed with ns.
type redisStore struct {
	cli *redis.Client   // Redis client used for every command
	ctx context.Context // context passed to every command
	ns  string          // key namespace prefix
}
// wrapperKey returns key prefixed with the store's namespace.
func (s *redisStore) wrapperKey(key string) string {
	// Both operands are strings, so fmt.Sprint inserts no separator.
	return fmt.Sprint(s.ns, key)
}
// checkError classifies the error of a finished command:
//   - (false, nil) on success;
//   - (true, nil) when the reply was redis.Nil (key missing);
//   - (false, err) for any other failure.
func (s *redisStore) checkError(result redis.Cmder) (bool, error) {
	cmdErr := result.Err()
	switch {
	case cmdErr == nil:
		return false, nil
	case cmdErr == redis.Nil:
		return true, nil
	default:
		return false, cmdErr
	}
}
// remove deletes the namespaced key. Only real command failures are returned.
func (s *redisStore) remove(key string) error {
	if _, delErr := s.checkError(s.cli.Del(s.ctx, s.wrapperKey(key))); delErr != nil {
		return delErr
	}
	return nil
}
// removeToken deletes the key for tokenString. It then deletes the shared
// token record as well if nothing else still refers to it, i.e. if the
// counterpart key no longer exists. The counterpart is the refresh key when
// removing by access and the access key when removing by refresh
// (isRefresh == true). An unknown token, or a record that is already gone,
// is not an error.
func (s *redisStore) removeToken(tokenString string, isRefresh bool) error {
	basicID, err := s.getBasicID(tokenString)
	if err != nil || basicID == "" {
		return err
	}
	if err := s.remove(tokenString); err != nil {
		return err
	}

	token, err := s.getToken(basicID)
	if err != nil || token == nil {
		return err
	}

	counterpart := token.GetRefresh()
	if isRefresh {
		counterpart = token.GetAccess()
	}
	// No counterpart token means nothing else points at the record. Drop it
	// directly instead of probing the bare namespace key with EXISTS.
	if counterpart == "" {
		return s.remove(basicID)
	}

	n, err := s.cli.Exists(s.ctx, s.wrapperKey(counterpart)).Result()
	if err != nil && err != redis.Nil {
		return err
	}
	if n == 0 {
		return s.remove(basicID)
	}
	return nil
}
// parseToken decodes the JSON token record carried by a GET reply.
// A missing key (redis.Nil) yields (nil, nil). Command and decode errors are
// returned as-is.
func (s *redisStore) parseToken(result *redis.StringCmd) (TokenInfo, error) {
	if missing, err := s.checkError(result); err != nil || missing {
		return nil, err
	}
	// checkError has already returned on every command error, redis.Nil
	// included. Bytes reports the same command error, so no redis.Nil
	// special case is needed here.
	buf, err := result.Bytes()
	if err != nil {
		return nil, err
	}
	var token Token
	if err := jsonUnmarshal(buf, &token); err != nil {
		return nil, err
	}
	return &token, nil
}
// getToken fetches the token record stored under the namespaced basic-ID key
// and decodes it. It returns (nil, nil) when the key does not exist.
func (s *redisStore) getToken(key string) (TokenInfo, error) {
	return s.parseToken(s.cli.Get(s.ctx, s.wrapperKey(key)))
}
// parseBasicID extracts the basic ID from a GET reply on a token key.
// A missing key (redis.Nil) yields ("", nil).
func (s *redisStore) parseBasicID(result *redis.StringCmd) (string, error) {
	missing, cmdErr := s.checkError(result)
	switch {
	case cmdErr != nil:
		return "", cmdErr
	case missing:
		return "", nil
	}
	return result.Val(), nil
}
// getBasicID looks up the basic ID that an access or refresh token key points
// to. It returns ("", nil) when the key does not exist.
func (s *redisStore) getBasicID(token string) (string, error) {
	return s.parseBasicID(s.cli.Get(s.ctx, s.wrapperKey(token)))
}
// Create stores info in Redis and indexes it by its access token and, when
// present, by its refresh token. All writes are queued on one transactional
// pipeline (TxPipeline) and executed together.
//
// Layout: the JSON-encoded token lives under a freshly generated basic ID, and
// the access/refresh keys hold that ID.
//
// Lifetimes:
//   - With a refresh token, the record and refresh key expire at
//     RefreshCreateAt+RefreshExpiresIn. The access key expires after
//     AccessExpiresIn, capped at that lifetime. A zero RefreshExpiresIn
//     stores every key without a TTL.
//   - Without a refresh token, every key uses AccessExpiresIn. A zero value
//     means no TTL.
//
// If the refresh token has already expired, nothing is written: the entry
// could never be read back as valid.
func (s *redisStore) Create(info TokenInfo) error {
	ct := time.Now()
	jv, err := jsonMarshal(info)
	if err != nil {
		return err
	}
	// Return a crypto/rand failure instead of panicking (uuid.Must) in
	// library code.
	id, err := uuid.NewRandom()
	if err != nil {
		return err
	}
	basicID := id.String()

	aexp := info.GetAccessExpiresIn()
	rexp := aexp
	refresh := info.GetRefresh()
	if refresh != "" {
		if rei := info.GetRefreshExpiresIn(); rei == 0 {
			// Never-expiring refresh token. go-redis documents a zero
			// expiration as "no TTL". The computed lifetime would instead be
			// a negative duration, which is not documented and, at exactly
			// -1ns, means KEEPTTL.
			rexp, aexp = 0, 0
		} else {
			rexp = info.GetRefreshCreateAt().Add(rei).Sub(ct)
			if rexp <= 0 {
				// Already expired. go-redis sends no EX/PX for a non-positive
				// duration, so writing would keep these keys forever.
				return nil
			}
			// The access key must not outlive the record it points to.
			if aexp == 0 || aexp > rexp {
				aexp = rexp
			}
		}
	}

	pipe := s.cli.TxPipeline()
	if refresh != "" {
		pipe.Set(s.ctx, s.wrapperKey(refresh), basicID, rexp)
	}
	pipe.Set(s.ctx, s.wrapperKey(info.GetAccess()), basicID, aexp)
	pipe.Set(s.ctx, s.wrapperKey(basicID), jv, rexp)
	_, err = pipe.Exec(s.ctx)
	return err
}
// RemoveByAccess deletes the access-token key. It also deletes the shared
// token record once no refresh-token key refers to it.
func (s *redisStore) RemoveByAccess(access string) error {
	const byRefresh = false
	return s.removeToken(access, byRefresh)
}
// RemoveByRefresh deletes the refresh-token key. It also deletes the shared
// token record once no access-token key refers to it.
func (s *redisStore) RemoveByRefresh(refresh string) error {
	const byRefresh = true
	return s.removeToken(refresh, byRefresh)
}
// GetByAccess resolves the access token to its basic ID and returns the token
// record. It returns (nil, nil) if either lookup finds nothing.
func (s *redisStore) GetByAccess(access string) (TokenInfo, error) {
	basicID, lookupErr := s.getBasicID(access)
	switch {
	case lookupErr != nil:
		return nil, lookupErr
	case basicID == "":
		return nil, nil
	default:
		return s.getToken(basicID)
	}
}
// GetByRefresh resolves the refresh token to its basic ID and returns the
// token record. It returns (nil, nil) if either lookup finds nothing.
func (s *redisStore) GetByRefresh(refresh string) (TokenInfo, error) {
	basicID, lookupErr := s.getBasicID(refresh)
	switch {
	case lookupErr != nil:
		return nil, lookupErr
	case basicID == "":
		return nil, nil
	default:
		return s.getToken(basicID)
	}
}