first commit
This commit is contained in:
		
							
								
								
									
										29
									
								
								pkg/auth/config.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								pkg/auth/config.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,29 @@
 | 
			
		||||
package auth
 | 
			
		||||
 | 
			
		||||
import "time"
 | 
			
		||||
 | 
			
		||||
// Config authorization configuration parameters
 | 
			
		||||
type Config struct {
 | 
			
		||||
	// access token expiration time, 0 means it doesn't expire
 | 
			
		||||
	AccessTokenExp time.Duration
 | 
			
		||||
	// refresh token expiration time, 0 means it doesn't expire
 | 
			
		||||
	RefreshTokenExp time.Duration
 | 
			
		||||
	// whether to generate the refreshing token
 | 
			
		||||
	IsGenerateRefresh bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RefreshConfig refreshing token config
 | 
			
		||||
type RefreshConfig struct {
 | 
			
		||||
	// whether to reset the refreshing creation time
 | 
			
		||||
	IsResetRefreshTime bool
 | 
			
		||||
	// whether to remove access token
 | 
			
		||||
	IsRemoveAccess bool
 | 
			
		||||
	// whether to remove refreshing token
 | 
			
		||||
	IsRemoveRefreshing bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// default configs
 | 
			
		||||
var (
 | 
			
		||||
	DefaultAccessTokenCfg  = &Config{AccessTokenExp: time.Hour * 24, RefreshTokenExp: time.Hour * 24 * 7, IsGenerateRefresh: true}
 | 
			
		||||
	DefaultRefreshTokenCfg = &RefreshConfig{IsResetRefreshTime: true, IsRemoveAccess: true, IsRemoveRefreshing: true}
 | 
			
		||||
)
 | 
			
		||||
							
								
								
									
										12
									
								
								pkg/auth/error.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										12
									
								
								pkg/auth/error.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,12 @@
 | 
			
		||||
package auth
 | 
			
		||||
 | 
			
		||||
import "errors"
 | 
			
		||||
 | 
			
		||||
var New = errors.New
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	ErrInvalidAccessToken  = errors.New("invalid access token")
 | 
			
		||||
	ErrInvalidRefreshToken = errors.New("invalid refresh token")
 | 
			
		||||
	ErrExpiredAccessToken  = errors.New("expired access token")
 | 
			
		||||
	ErrExpiredRefreshToken = errors.New("expired refresh token")
 | 
			
		||||
)
 | 
			
		||||
							
								
								
									
										17
									
								
								pkg/auth/generate.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								pkg/auth/generate.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,17 @@
 | 
			
		||||
package auth
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type (
 | 
			
		||||
	GenerateBasic struct {
 | 
			
		||||
		UserID    string
 | 
			
		||||
		CreateAt  time.Time
 | 
			
		||||
		TokenInfo TokenInfo
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	AccessGenerate interface {
 | 
			
		||||
		Token(data *GenerateBasic, isGenRefresh bool) (access, refresh string, err error)
 | 
			
		||||
	}
 | 
			
		||||
)
 | 
			
		||||
							
								
								
									
										97
									
								
								pkg/auth/jwt_access.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										97
									
								
								pkg/auth/jwt_access.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,97 @@
 | 
			
		||||
package auth
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/golang-jwt/jwt/v4"
 | 
			
		||||
	"github.com/google/uuid"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// JWTAccessClaims jwt claims
 | 
			
		||||
type JWTAccessClaims struct {
 | 
			
		||||
	jwt.RegisteredClaims
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Valid claims verification
 | 
			
		||||
func (a *JWTAccessClaims) Valid() error {
 | 
			
		||||
	if a.ExpiresAt.Before(time.Now()) {
 | 
			
		||||
		return ErrInvalidAccessToken
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewJWTAccessGenerate create to generate the jwt access token instance
 | 
			
		||||
func NewJWTAccessGenerate(key []byte, method jwt.SigningMethod) *JWTAccessGenerate {
 | 
			
		||||
	return &JWTAccessGenerate{
 | 
			
		||||
		SignedKey:    key,
 | 
			
		||||
		SignedMethod: method,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// JWTAccessGenerate generate the jwt access token
 | 
			
		||||
type JWTAccessGenerate struct {
 | 
			
		||||
	SignedKey    []byte
 | 
			
		||||
	SignedMethod jwt.SigningMethod
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Token based on the UUID generated token
 | 
			
		||||
func (a *JWTAccessGenerate) Token(data *GenerateBasic, isGenRefresh bool) (string, string, error) {
 | 
			
		||||
	claims := &JWTAccessClaims{
 | 
			
		||||
		RegisteredClaims: jwt.RegisteredClaims{
 | 
			
		||||
			Issuer:    "BvBeJ",
 | 
			
		||||
			Subject:   data.UserID,
 | 
			
		||||
			ExpiresAt: jwt.NewNumericDate(data.TokenInfo.GetAccessCreateAt().Add(data.TokenInfo.GetAccessExpiresIn())),
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	token := jwt.NewWithClaims(a.SignedMethod, claims)
 | 
			
		||||
	var key any
 | 
			
		||||
	if a.isEs() {
 | 
			
		||||
		v, err := jwt.ParseECPrivateKeyFromPEM(a.SignedKey)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return "", "", err
 | 
			
		||||
		}
 | 
			
		||||
		key = v
 | 
			
		||||
	} else if a.isRsOrPS() {
 | 
			
		||||
		v, err := jwt.ParseRSAPrivateKeyFromPEM(a.SignedKey)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return "", "", err
 | 
			
		||||
		}
 | 
			
		||||
		key = v
 | 
			
		||||
	} else if a.isHs() {
 | 
			
		||||
		key = a.SignedKey
 | 
			
		||||
	} else {
 | 
			
		||||
		return "", "", errors.New("unsupported sign method")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	access, err := token.SignedString(key)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", "", err
 | 
			
		||||
	}
 | 
			
		||||
	refresh := ""
 | 
			
		||||
 | 
			
		||||
	if isGenRefresh {
 | 
			
		||||
		t := uuid.NewSHA1(uuid.Must(uuid.NewRandom()), []byte(access)).String()
 | 
			
		||||
		refresh = base64.URLEncoding.EncodeToString([]byte(t))
 | 
			
		||||
		refresh = strings.ToUpper(strings.TrimRight(refresh, "="))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return access, refresh, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *JWTAccessGenerate) isEs() bool {
 | 
			
		||||
	return strings.HasPrefix(a.SignedMethod.Alg(), "ES")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *JWTAccessGenerate) isRsOrPS() bool {
 | 
			
		||||
	isRs := strings.HasPrefix(a.SignedMethod.Alg(), "RS")
 | 
			
		||||
	isPs := strings.HasPrefix(a.SignedMethod.Alg(), "PS")
 | 
			
		||||
	return isRs || isPs
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *JWTAccessGenerate) isHs() bool {
 | 
			
		||||
	return strings.HasPrefix(a.SignedMethod.Alg(), "HS")
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										194
									
								
								pkg/auth/manager.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										194
									
								
								pkg/auth/manager.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,194 @@
 | 
			
		||||
package auth
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// NewManager create to authorization management instance
 | 
			
		||||
func NewManager(ag AccessGenerate, ts TokenStore) *Manager {
 | 
			
		||||
	return &Manager{
 | 
			
		||||
		cfg:            DefaultAccessTokenCfg,
 | 
			
		||||
		rCfg:           DefaultRefreshTokenCfg,
 | 
			
		||||
		accessGenerate: ag,
 | 
			
		||||
		tokenStore:     ts,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SetConfig mapping the access token generate config
 | 
			
		||||
func (m *Manager) SetConfig(cfg *Config) {
 | 
			
		||||
	m.cfg = cfg
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SetRefreshTokenConfig  mapping the token refresh config
 | 
			
		||||
func (m *Manager) SetRefreshTokenConfig(store *RefreshConfig) {
 | 
			
		||||
	m.rCfg = store
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Manager provide authorization management
 | 
			
		||||
type Manager struct {
 | 
			
		||||
	cfg            *Config
 | 
			
		||||
	rCfg           *RefreshConfig
 | 
			
		||||
	accessGenerate AccessGenerate
 | 
			
		||||
	tokenStore     TokenStore
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GenerateAccessToken generate the access token
 | 
			
		||||
func (m *Manager) GenerateAccessToken(userID string) (TokenInfo, error) {
 | 
			
		||||
	ti := NewToken()
 | 
			
		||||
	ti.SetUserID(userID)
 | 
			
		||||
 | 
			
		||||
	createAt := time.Now()
 | 
			
		||||
	ti.SetAccessCreateAt(createAt)
 | 
			
		||||
 | 
			
		||||
	// set access token expires
 | 
			
		||||
	ti.SetAccessExpiresIn(m.cfg.AccessTokenExp)
 | 
			
		||||
	if m.cfg.IsGenerateRefresh {
 | 
			
		||||
		ti.SetRefreshCreateAt(createAt)
 | 
			
		||||
		ti.SetRefreshExpiresIn(m.cfg.RefreshTokenExp)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	td := &GenerateBasic{
 | 
			
		||||
		UserID:    userID,
 | 
			
		||||
		CreateAt:  createAt,
 | 
			
		||||
		TokenInfo: ti,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	av, rv, err := m.accessGenerate.Token(td, m.cfg.IsGenerateRefresh)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	ti.SetAccess(av)
 | 
			
		||||
 | 
			
		||||
	if rv != "" {
 | 
			
		||||
		ti.SetRefresh(rv)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = m.tokenStore.Create(ti)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return ti, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RefreshAccessToken refreshing an access token
 | 
			
		||||
func (m *Manager) RefreshAccessToken(refresh string) (TokenInfo, error) {
 | 
			
		||||
	ti, err := m.LoadRefreshToken(refresh)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	oldAccess, oldRefresh := ti.GetAccess(), ti.GetRefresh()
 | 
			
		||||
 | 
			
		||||
	td := &GenerateBasic{
 | 
			
		||||
		UserID:    ti.GetUserID(),
 | 
			
		||||
		CreateAt:  time.Now(),
 | 
			
		||||
		TokenInfo: ti,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ti.SetAccessCreateAt(td.CreateAt)
 | 
			
		||||
	if v := m.cfg.AccessTokenExp; v > 0 {
 | 
			
		||||
		ti.SetAccessExpiresIn(v)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if v := m.cfg.RefreshTokenExp; v > 0 {
 | 
			
		||||
		ti.SetRefreshExpiresIn(v)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if m.rCfg.IsResetRefreshTime {
 | 
			
		||||
		ti.SetRefreshCreateAt(td.CreateAt)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tv, rv, err := m.accessGenerate.Token(td, m.cfg.IsGenerateRefresh)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ti.SetAccess(tv)
 | 
			
		||||
	if rv != "" {
 | 
			
		||||
		ti.SetRefresh(rv)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err = m.tokenStore.Create(ti); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if m.rCfg.IsRemoveAccess {
 | 
			
		||||
		// remove the old access token
 | 
			
		||||
		if err = m.tokenStore.RemoveByAccess(oldAccess); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if m.rCfg.IsRemoveRefreshing && rv != "" {
 | 
			
		||||
		// remove the old refresh token
 | 
			
		||||
		if err = m.tokenStore.RemoveByRefresh(oldRefresh); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if rv == "" {
 | 
			
		||||
		ti.SetRefresh("")
 | 
			
		||||
		ti.SetRefreshCreateAt(time.Now())
 | 
			
		||||
		ti.SetRefreshExpiresIn(0)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return ti, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RemoveAccessToken use the access token to delete the token information
 | 
			
		||||
func (m *Manager) RemoveAccessToken(access string) error {
 | 
			
		||||
	if access == "" {
 | 
			
		||||
		return ErrInvalidAccessToken
 | 
			
		||||
	}
 | 
			
		||||
	return m.tokenStore.RemoveByAccess(access)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RemoveRefreshToken use the refresh token to delete the token information
 | 
			
		||||
func (m *Manager) RemoveRefreshToken(refresh string) error {
 | 
			
		||||
	if refresh == "" {
 | 
			
		||||
		return ErrInvalidAccessToken
 | 
			
		||||
	}
 | 
			
		||||
	return m.tokenStore.RemoveByRefresh(refresh)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// LoadAccessToken according to the access token for corresponding token information
 | 
			
		||||
func (m *Manager) LoadAccessToken(access string) (TokenInfo, error) {
 | 
			
		||||
	if access == "" {
 | 
			
		||||
		return nil, ErrInvalidAccessToken
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ct := time.Now()
 | 
			
		||||
	ti, err := m.tokenStore.GetByAccess(access)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	} else if ti == nil || ti.GetAccess() != access {
 | 
			
		||||
		return nil, ErrInvalidAccessToken
 | 
			
		||||
	} else if ti.GetRefresh() != "" && ti.GetRefreshExpiresIn() != 0 &&
 | 
			
		||||
		ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(ct) {
 | 
			
		||||
		return nil, ErrExpiredRefreshToken
 | 
			
		||||
	} else if ti.GetAccessExpiresIn() != 0 &&
 | 
			
		||||
		ti.GetAccessCreateAt().Add(ti.GetAccessExpiresIn()).Before(ct) {
 | 
			
		||||
		return nil, ErrExpiredAccessToken
 | 
			
		||||
	}
 | 
			
		||||
	return ti, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// LoadRefreshToken according to the refresh token for corresponding token information
 | 
			
		||||
func (m *Manager) LoadRefreshToken(refresh string) (TokenInfo, error) {
 | 
			
		||||
	if refresh == "" {
 | 
			
		||||
		return nil, ErrInvalidRefreshToken
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ti, err := m.tokenStore.GetByRefresh(refresh)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	} else if ti == nil || ti.GetRefresh() != refresh {
 | 
			
		||||
		return nil, ErrInvalidRefreshToken
 | 
			
		||||
	} else if ti.GetRefreshExpiresIn() != 0 && // refresh token set to not expire
 | 
			
		||||
		ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(time.Now()) {
 | 
			
		||||
		return nil, ErrExpiredRefreshToken
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return ti, nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										334
									
								
								pkg/auth/store.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										334
									
								
								pkg/auth/store.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,334 @@
 | 
			
		||||
package auth
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/google/uuid"
 | 
			
		||||
	jsonIterator "github.com/json-iterator/go"
 | 
			
		||||
	"github.com/redis/go-redis/v9"
 | 
			
		||||
	"github.com/tidwall/buntdb"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	jsonMarshal   = jsonIterator.Marshal
 | 
			
		||||
	jsonUnmarshal = jsonIterator.Unmarshal
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type TokenStore interface {
 | 
			
		||||
	Create(info TokenInfo) error
 | 
			
		||||
	RemoveByAccess(access string) error
 | 
			
		||||
	RemoveByRefresh(refresh string) error
 | 
			
		||||
	GetByAccess(access string) (TokenInfo, error)
 | 
			
		||||
	GetByRefresh(refresh string) (TokenInfo, error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewMemoryTokenStore create a token buntStore instance based on memory
 | 
			
		||||
func NewMemoryTokenStore() (TokenStore, error) {
 | 
			
		||||
	return NewFileTokenStore(":memory:")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewFileTokenStore create a token buntStore instance based on file
 | 
			
		||||
func NewFileTokenStore(filename string) (TokenStore, error) {
 | 
			
		||||
	db, err := buntdb.Open(filename)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return &buntStore{db: db}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// buntStore token storage based on buntdb(https://github.com/tidwall/buntdb)
 | 
			
		||||
type buntStore struct {
 | 
			
		||||
	db *buntdb.DB
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ts *buntStore) remove(key string) error {
 | 
			
		||||
	err := ts.db.Update(func(tx *buntdb.Tx) error {
 | 
			
		||||
		_, err := tx.Delete(key)
 | 
			
		||||
		return err
 | 
			
		||||
	})
 | 
			
		||||
	if errors.Is(err, buntdb.ErrNotFound) {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ts *buntStore) getData(key string) (TokenInfo, error) {
 | 
			
		||||
	var ti TokenInfo
 | 
			
		||||
	err := ts.db.View(func(tx *buntdb.Tx) error {
 | 
			
		||||
		jv, err := tx.Get(key)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var tm Token
 | 
			
		||||
		err = jsonUnmarshal([]byte(jv), &tm)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		ti = &tm
 | 
			
		||||
		return nil
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if err == buntdb.ErrNotFound {
 | 
			
		||||
			return nil, nil
 | 
			
		||||
		}
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return ti, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ts *buntStore) getBasicID(key string) (string, error) {
 | 
			
		||||
	var basicID string
 | 
			
		||||
	err := ts.db.View(func(tx *buntdb.Tx) error {
 | 
			
		||||
		v, err := tx.Get(key)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		basicID = v
 | 
			
		||||
		return nil
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if err == buntdb.ErrNotFound {
 | 
			
		||||
			return "", nil
 | 
			
		||||
		}
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
	return basicID, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Create and buntStore the new token information
 | 
			
		||||
func (ts *buntStore) Create(info TokenInfo) error {
 | 
			
		||||
	ct := time.Now()
 | 
			
		||||
	jv, err := jsonMarshal(info)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return ts.db.Update(func(tx *buntdb.Tx) error {
 | 
			
		||||
		basicID := uuid.Must(uuid.NewRandom()).String()
 | 
			
		||||
		aexp := info.GetAccessExpiresIn()
 | 
			
		||||
		rexp := aexp
 | 
			
		||||
		expires := true
 | 
			
		||||
		if refresh := info.GetRefresh(); refresh != "" {
 | 
			
		||||
			rexp = info.GetRefreshCreateAt().Add(info.GetRefreshExpiresIn()).Sub(ct)
 | 
			
		||||
			if aexp.Seconds() > rexp.Seconds() {
 | 
			
		||||
				aexp = rexp
 | 
			
		||||
			}
 | 
			
		||||
			expires = info.GetRefreshExpiresIn() != 0
 | 
			
		||||
			_, _, err = tx.Set(refresh, basicID, &buntdb.SetOptions{Expires: expires, TTL: rexp})
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		_, _, err = tx.Set(basicID, string(jv), &buntdb.SetOptions{Expires: expires, TTL: rexp})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		_, _, err = tx.Set(info.GetAccess(), basicID, &buntdb.SetOptions{Expires: expires, TTL: aexp})
 | 
			
		||||
		return err
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RemoveByAccess use the access token to delete the token information
 | 
			
		||||
func (ts *buntStore) RemoveByAccess(access string) error {
 | 
			
		||||
	return ts.remove(access)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RemoveByRefresh use the refresh token to delete the token information
 | 
			
		||||
func (ts *buntStore) RemoveByRefresh(refresh string) error {
 | 
			
		||||
	return ts.remove(refresh)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetByAccess use the access token for token information data
 | 
			
		||||
func (ts *buntStore) GetByAccess(access string) (TokenInfo, error) {
 | 
			
		||||
	basicID, err := ts.getBasicID(access)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return ts.getData(basicID)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetByRefresh use the refresh token for token information data
 | 
			
		||||
func (ts *buntStore) GetByRefresh(refresh string) (TokenInfo, error) {
 | 
			
		||||
	basicID, err := ts.getBasicID(refresh)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return ts.getData(basicID)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/*------------------------------------------------------------------------------------*/
 | 
			
		||||
 | 
			
		||||
// NewRedisStoreWithCli create an instance of a redis store
 | 
			
		||||
func NewRedisStoreWithCli(cli *redis.Client, keyNamespace string) TokenStore {
 | 
			
		||||
	store := &redisStore{
 | 
			
		||||
		cli: cli,
 | 
			
		||||
		ctx: context.TODO(),
 | 
			
		||||
		ns:  keyNamespace,
 | 
			
		||||
	}
 | 
			
		||||
	return store
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TokenStore redis token store
 | 
			
		||||
type redisStore struct {
 | 
			
		||||
	cli *redis.Client
 | 
			
		||||
	ctx context.Context
 | 
			
		||||
	ns  string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *redisStore) wrapperKey(key string) string {
 | 
			
		||||
	return fmt.Sprintf("%s%s", s.ns, key)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *redisStore) checkError(result redis.Cmder) (bool, error) {
 | 
			
		||||
	if err := result.Err(); err != nil {
 | 
			
		||||
		if err == redis.Nil {
 | 
			
		||||
			return true, nil
 | 
			
		||||
		}
 | 
			
		||||
		return false, err
 | 
			
		||||
	}
 | 
			
		||||
	return false, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *redisStore) remove(key string) error {
 | 
			
		||||
	result := s.cli.Del(s.ctx, s.wrapperKey(key))
 | 
			
		||||
	_, err := s.checkError(result)
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *redisStore) removeToken(tokenString string, isRefresh bool) error {
 | 
			
		||||
	basicID, err := s.getBasicID(tokenString)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	} else if basicID == "" {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = s.remove(tokenString)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	token, err := s.getToken(basicID)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	} else if token == nil {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	checkToken := token.GetRefresh()
 | 
			
		||||
	if isRefresh {
 | 
			
		||||
		checkToken = token.GetAccess()
 | 
			
		||||
	}
 | 
			
		||||
	result := s.cli.Exists(s.ctx, s.wrapperKey(checkToken))
 | 
			
		||||
	if err = result.Err(); err != nil && err != redis.Nil {
 | 
			
		||||
		return err
 | 
			
		||||
	} else if result.Val() == 0 {
 | 
			
		||||
		return s.remove(basicID)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *redisStore) parseToken(result *redis.StringCmd) (TokenInfo, error) {
 | 
			
		||||
	if ok, err := s.checkError(result); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	} else if ok {
 | 
			
		||||
		return nil, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	buf, err := result.Bytes()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if err == redis.Nil {
 | 
			
		||||
			return nil, nil
 | 
			
		||||
		}
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var token Token
 | 
			
		||||
	if err = jsonUnmarshal(buf, &token); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return &token, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *redisStore) getToken(key string) (TokenInfo, error) {
 | 
			
		||||
	result := s.cli.Get(s.ctx, s.wrapperKey(key))
 | 
			
		||||
	return s.parseToken(result)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *redisStore) parseBasicID(result *redis.StringCmd) (string, error) {
 | 
			
		||||
	if ok, err := s.checkError(result); err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	} else if ok {
 | 
			
		||||
		return "", nil
 | 
			
		||||
	}
 | 
			
		||||
	return result.Val(), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *redisStore) getBasicID(token string) (string, error) {
 | 
			
		||||
	result := s.cli.Get(s.ctx, s.wrapperKey(token))
 | 
			
		||||
	return s.parseBasicID(result)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Create and store the new token information
 | 
			
		||||
func (s *redisStore) Create(info TokenInfo) error {
 | 
			
		||||
	ct := time.Now()
 | 
			
		||||
	jv, err := jsonMarshal(info)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	pipe := s.cli.TxPipeline()
 | 
			
		||||
	basicID := uuid.Must(uuid.NewRandom()).String()
 | 
			
		||||
	aexp := info.GetAccessExpiresIn()
 | 
			
		||||
	rexp := aexp
 | 
			
		||||
 | 
			
		||||
	if refresh := info.GetRefresh(); refresh != "" {
 | 
			
		||||
		rexp = info.GetRefreshCreateAt().Add(info.GetRefreshExpiresIn()).Sub(ct)
 | 
			
		||||
		if aexp.Seconds() > rexp.Seconds() {
 | 
			
		||||
			aexp = rexp
 | 
			
		||||
		}
 | 
			
		||||
		pipe.Set(s.ctx, s.wrapperKey(refresh), basicID, rexp)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	pipe.Set(s.ctx, s.wrapperKey(info.GetAccess()), basicID, aexp)
 | 
			
		||||
	pipe.Set(s.ctx, s.wrapperKey(basicID), jv, rexp)
 | 
			
		||||
 | 
			
		||||
	if _, err = pipe.Exec(s.ctx); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RemoveByAccess Use the access token to delete the token information
 | 
			
		||||
func (s *redisStore) RemoveByAccess(access string) error {
 | 
			
		||||
	return s.removeToken(access, false)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RemoveByRefresh Use the refresh token to delete the token information
 | 
			
		||||
func (s *redisStore) RemoveByRefresh(refresh string) error {
 | 
			
		||||
	return s.removeToken(refresh, true)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetByAccess Use the access token for token information data
 | 
			
		||||
func (s *redisStore) GetByAccess(access string) (TokenInfo, error) {
 | 
			
		||||
	basicID, err := s.getBasicID(access)
 | 
			
		||||
	if err != nil || basicID == "" {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return s.getToken(basicID)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetByRefresh Use the refresh token for token information data
 | 
			
		||||
func (s *redisStore) GetByRefresh(refresh string) (TokenInfo, error) {
 | 
			
		||||
	basicID, err := s.getBasicID(refresh)
 | 
			
		||||
	if err != nil || basicID == "" {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return s.getToken(basicID)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										118
									
								
								pkg/auth/token.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										118
									
								
								pkg/auth/token.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,118 @@
 | 
			
		||||
package auth
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// TokenInfo the token information model interface
 | 
			
		||||
type TokenInfo interface {
 | 
			
		||||
	New() TokenInfo
 | 
			
		||||
 | 
			
		||||
	GetUserID() string
 | 
			
		||||
	SetUserID(string)
 | 
			
		||||
 | 
			
		||||
	GetAccess() string
 | 
			
		||||
	SetAccess(string)
 | 
			
		||||
	GetAccessCreateAt() time.Time
 | 
			
		||||
	SetAccessCreateAt(time.Time)
 | 
			
		||||
	GetAccessExpiresIn() time.Duration
 | 
			
		||||
	SetAccessExpiresIn(time.Duration)
 | 
			
		||||
 | 
			
		||||
	GetRefresh() string
 | 
			
		||||
	SetRefresh(string)
 | 
			
		||||
	GetRefreshCreateAt() time.Time
 | 
			
		||||
	SetRefreshCreateAt(time.Time)
 | 
			
		||||
	GetRefreshExpiresIn() time.Duration
 | 
			
		||||
	SetRefreshExpiresIn(time.Duration)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewToken create to token model instance
 | 
			
		||||
func NewToken() *Token {
 | 
			
		||||
	return &Token{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Token token model
 | 
			
		||||
type Token struct {
 | 
			
		||||
	UserID           string        `bson:"UserID"`
 | 
			
		||||
	Access           string        `bson:"Access"`
 | 
			
		||||
	AccessCreateAt   time.Time     `bson:"AccessCreateAt"`
 | 
			
		||||
	AccessExpiresIn  time.Duration `bson:"AccessExpiresIn"`
 | 
			
		||||
	Refresh          string        `bson:"Refresh"`
 | 
			
		||||
	RefreshCreateAt  time.Time     `bson:"RefreshCreateAt"`
 | 
			
		||||
	RefreshExpiresIn time.Duration `bson:"RefreshExpiresIn"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// New create to token model instance
 | 
			
		||||
func (t *Token) New() TokenInfo {
 | 
			
		||||
	return NewToken()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetUserID the user id
 | 
			
		||||
func (t *Token) GetUserID() string {
 | 
			
		||||
	return t.UserID
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SetUserID the user id
 | 
			
		||||
func (t *Token) SetUserID(userID string) {
 | 
			
		||||
	t.UserID = userID
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetAccess access Token
 | 
			
		||||
func (t *Token) GetAccess() string {
 | 
			
		||||
	return t.Access
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SetAccess access Token
 | 
			
		||||
func (t *Token) SetAccess(access string) {
 | 
			
		||||
	t.Access = access
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetAccessCreateAt create Time
 | 
			
		||||
func (t *Token) GetAccessCreateAt() time.Time {
 | 
			
		||||
	return t.AccessCreateAt
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SetAccessCreateAt create Time
 | 
			
		||||
func (t *Token) SetAccessCreateAt(createAt time.Time) {
 | 
			
		||||
	t.AccessCreateAt = createAt
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetAccessExpiresIn the lifetime in seconds of the access token
 | 
			
		||||
func (t *Token) GetAccessExpiresIn() time.Duration {
 | 
			
		||||
	return t.AccessExpiresIn
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SetAccessExpiresIn the lifetime in seconds of the access token
 | 
			
		||||
func (t *Token) SetAccessExpiresIn(exp time.Duration) {
 | 
			
		||||
	t.AccessExpiresIn = exp
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetRefresh refresh Token
 | 
			
		||||
func (t *Token) GetRefresh() string {
 | 
			
		||||
	return t.Refresh
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SetRefresh refresh Token
 | 
			
		||||
func (t *Token) SetRefresh(refresh string) {
 | 
			
		||||
	t.Refresh = refresh
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetRefreshCreateAt create Time
 | 
			
		||||
func (t *Token) GetRefreshCreateAt() time.Time {
 | 
			
		||||
	return t.RefreshCreateAt
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SetRefreshCreateAt create Time
 | 
			
		||||
func (t *Token) SetRefreshCreateAt(createAt time.Time) {
 | 
			
		||||
	t.RefreshCreateAt = createAt
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetRefreshExpiresIn the lifetime in seconds of the refresh token
 | 
			
		||||
func (t *Token) GetRefreshExpiresIn() time.Duration {
 | 
			
		||||
	return t.RefreshExpiresIn
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SetRefreshExpiresIn the lifetime in seconds of the refresh token
 | 
			
		||||
func (t *Token) SetRefreshExpiresIn(exp time.Duration) {
 | 
			
		||||
	t.RefreshExpiresIn = exp
 | 
			
		||||
}
 | 
			
		||||
		Reference in New Issue
	
	Block a user