first commit

This commit is contained in:
2024-07-23 10:23:43 +08:00
commit 7b4c2521a3
126 changed files with 15931 additions and 0 deletions

29
pkg/auth/config.go Normal file
View File

@ -0,0 +1,29 @@
package auth
import "time"
// Config authorization configuration parameters
type Config struct {
// access token expiration time, 0 means it doesn't expire
AccessTokenExp time.Duration
// refresh token expiration time, 0 means it doesn't expire
RefreshTokenExp time.Duration
// whether to generate the refreshing token
IsGenerateRefresh bool
}
// RefreshConfig refreshing token config
type RefreshConfig struct {
// whether to reset the refreshing creation time
IsResetRefreshTime bool
// whether to remove access token
IsRemoveAccess bool
// whether to remove refreshing token
IsRemoveRefreshing bool
}
// default configs
var (
DefaultAccessTokenCfg = &Config{AccessTokenExp: time.Hour * 24, RefreshTokenExp: time.Hour * 24 * 7, IsGenerateRefresh: true}
DefaultRefreshTokenCfg = &RefreshConfig{IsResetRefreshTime: true, IsRemoveAccess: true, IsRemoveRefreshing: true}
)

12
pkg/auth/error.go Normal file
View File

@ -0,0 +1,12 @@
package auth
import "errors"
var New = errors.New
var (
ErrInvalidAccessToken = errors.New("invalid access token")
ErrInvalidRefreshToken = errors.New("invalid refresh token")
ErrExpiredAccessToken = errors.New("expired access token")
ErrExpiredRefreshToken = errors.New("expired refresh token")
)

17
pkg/auth/generate.go Normal file
View File

@ -0,0 +1,17 @@
package auth
import (
"time"
)
type (
GenerateBasic struct {
UserID string
CreateAt time.Time
TokenInfo TokenInfo
}
AccessGenerate interface {
Token(data *GenerateBasic, isGenRefresh bool) (access, refresh string, err error)
}
)

97
pkg/auth/jwt_access.go Normal file
View File

@ -0,0 +1,97 @@
package auth
import (
"encoding/base64"
"errors"
"strings"
"time"
"github.com/golang-jwt/jwt/v4"
"github.com/google/uuid"
)
// JWTAccessClaims jwt claims
type JWTAccessClaims struct {
jwt.RegisteredClaims
}
// Valid claims verification
func (a *JWTAccessClaims) Valid() error {
if a.ExpiresAt.Before(time.Now()) {
return ErrInvalidAccessToken
}
return nil
}
// NewJWTAccessGenerate create to generate the jwt access token instance
func NewJWTAccessGenerate(key []byte, method jwt.SigningMethod) *JWTAccessGenerate {
return &JWTAccessGenerate{
SignedKey: key,
SignedMethod: method,
}
}
// JWTAccessGenerate generate the jwt access token
type JWTAccessGenerate struct {
SignedKey []byte
SignedMethod jwt.SigningMethod
}
// Token based on the UUID generated token
func (a *JWTAccessGenerate) Token(data *GenerateBasic, isGenRefresh bool) (string, string, error) {
claims := &JWTAccessClaims{
RegisteredClaims: jwt.RegisteredClaims{
Issuer: "BvBeJ",
Subject: data.UserID,
ExpiresAt: jwt.NewNumericDate(data.TokenInfo.GetAccessCreateAt().Add(data.TokenInfo.GetAccessExpiresIn())),
},
}
token := jwt.NewWithClaims(a.SignedMethod, claims)
var key any
if a.isEs() {
v, err := jwt.ParseECPrivateKeyFromPEM(a.SignedKey)
if err != nil {
return "", "", err
}
key = v
} else if a.isRsOrPS() {
v, err := jwt.ParseRSAPrivateKeyFromPEM(a.SignedKey)
if err != nil {
return "", "", err
}
key = v
} else if a.isHs() {
key = a.SignedKey
} else {
return "", "", errors.New("unsupported sign method")
}
access, err := token.SignedString(key)
if err != nil {
return "", "", err
}
refresh := ""
if isGenRefresh {
t := uuid.NewSHA1(uuid.Must(uuid.NewRandom()), []byte(access)).String()
refresh = base64.URLEncoding.EncodeToString([]byte(t))
refresh = strings.ToUpper(strings.TrimRight(refresh, "="))
}
return access, refresh, nil
}
func (a *JWTAccessGenerate) isEs() bool {
return strings.HasPrefix(a.SignedMethod.Alg(), "ES")
}
func (a *JWTAccessGenerate) isRsOrPS() bool {
isRs := strings.HasPrefix(a.SignedMethod.Alg(), "RS")
isPs := strings.HasPrefix(a.SignedMethod.Alg(), "PS")
return isRs || isPs
}
func (a *JWTAccessGenerate) isHs() bool {
return strings.HasPrefix(a.SignedMethod.Alg(), "HS")
}

194
pkg/auth/manager.go Normal file
View File

@ -0,0 +1,194 @@
package auth
import (
"time"
)
// NewManager create to authorization management instance
func NewManager(ag AccessGenerate, ts TokenStore) *Manager {
return &Manager{
cfg: DefaultAccessTokenCfg,
rCfg: DefaultRefreshTokenCfg,
accessGenerate: ag,
tokenStore: ts,
}
}
// SetConfig mapping the access token generate config
func (m *Manager) SetConfig(cfg *Config) {
m.cfg = cfg
}
// SetRefreshTokenConfig mapping the token refresh config
func (m *Manager) SetRefreshTokenConfig(store *RefreshConfig) {
m.rCfg = store
}
// Manager provide authorization management
type Manager struct {
cfg *Config
rCfg *RefreshConfig
accessGenerate AccessGenerate
tokenStore TokenStore
}
// GenerateAccessToken generate the access token
func (m *Manager) GenerateAccessToken(userID string) (TokenInfo, error) {
ti := NewToken()
ti.SetUserID(userID)
createAt := time.Now()
ti.SetAccessCreateAt(createAt)
// set access token expires
ti.SetAccessExpiresIn(m.cfg.AccessTokenExp)
if m.cfg.IsGenerateRefresh {
ti.SetRefreshCreateAt(createAt)
ti.SetRefreshExpiresIn(m.cfg.RefreshTokenExp)
}
td := &GenerateBasic{
UserID: userID,
CreateAt: createAt,
TokenInfo: ti,
}
av, rv, err := m.accessGenerate.Token(td, m.cfg.IsGenerateRefresh)
if err != nil {
return nil, err
}
ti.SetAccess(av)
if rv != "" {
ti.SetRefresh(rv)
}
err = m.tokenStore.Create(ti)
if err != nil {
return nil, err
}
return ti, nil
}
// RefreshAccessToken refreshing an access token
func (m *Manager) RefreshAccessToken(refresh string) (TokenInfo, error) {
ti, err := m.LoadRefreshToken(refresh)
if err != nil {
return nil, err
}
oldAccess, oldRefresh := ti.GetAccess(), ti.GetRefresh()
td := &GenerateBasic{
UserID: ti.GetUserID(),
CreateAt: time.Now(),
TokenInfo: ti,
}
ti.SetAccessCreateAt(td.CreateAt)
if v := m.cfg.AccessTokenExp; v > 0 {
ti.SetAccessExpiresIn(v)
}
if v := m.cfg.RefreshTokenExp; v > 0 {
ti.SetRefreshExpiresIn(v)
}
if m.rCfg.IsResetRefreshTime {
ti.SetRefreshCreateAt(td.CreateAt)
}
tv, rv, err := m.accessGenerate.Token(td, m.cfg.IsGenerateRefresh)
if err != nil {
return nil, err
}
ti.SetAccess(tv)
if rv != "" {
ti.SetRefresh(rv)
}
if err = m.tokenStore.Create(ti); err != nil {
return nil, err
}
if m.rCfg.IsRemoveAccess {
// remove the old access token
if err = m.tokenStore.RemoveByAccess(oldAccess); err != nil {
return nil, err
}
}
if m.rCfg.IsRemoveRefreshing && rv != "" {
// remove the old refresh token
if err = m.tokenStore.RemoveByRefresh(oldRefresh); err != nil {
return nil, err
}
}
if rv == "" {
ti.SetRefresh("")
ti.SetRefreshCreateAt(time.Now())
ti.SetRefreshExpiresIn(0)
}
return ti, nil
}
// RemoveAccessToken use the access token to delete the token information
func (m *Manager) RemoveAccessToken(access string) error {
if access == "" {
return ErrInvalidAccessToken
}
return m.tokenStore.RemoveByAccess(access)
}
// RemoveRefreshToken use the refresh token to delete the token information
func (m *Manager) RemoveRefreshToken(refresh string) error {
if refresh == "" {
return ErrInvalidAccessToken
}
return m.tokenStore.RemoveByRefresh(refresh)
}
// LoadAccessToken according to the access token for corresponding token information
func (m *Manager) LoadAccessToken(access string) (TokenInfo, error) {
if access == "" {
return nil, ErrInvalidAccessToken
}
ct := time.Now()
ti, err := m.tokenStore.GetByAccess(access)
if err != nil {
return nil, err
} else if ti == nil || ti.GetAccess() != access {
return nil, ErrInvalidAccessToken
} else if ti.GetRefresh() != "" && ti.GetRefreshExpiresIn() != 0 &&
ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(ct) {
return nil, ErrExpiredRefreshToken
} else if ti.GetAccessExpiresIn() != 0 &&
ti.GetAccessCreateAt().Add(ti.GetAccessExpiresIn()).Before(ct) {
return nil, ErrExpiredAccessToken
}
return ti, nil
}
// LoadRefreshToken according to the refresh token for corresponding token information
func (m *Manager) LoadRefreshToken(refresh string) (TokenInfo, error) {
if refresh == "" {
return nil, ErrInvalidRefreshToken
}
ti, err := m.tokenStore.GetByRefresh(refresh)
if err != nil {
return nil, err
} else if ti == nil || ti.GetRefresh() != refresh {
return nil, ErrInvalidRefreshToken
} else if ti.GetRefreshExpiresIn() != 0 && // refresh token set to not expire
ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(time.Now()) {
return nil, ErrExpiredRefreshToken
}
return ti, nil
}

334
pkg/auth/store.go Normal file
View File

@ -0,0 +1,334 @@
package auth
import (
"context"
"errors"
"fmt"
"github.com/google/uuid"
jsonIterator "github.com/json-iterator/go"
"github.com/redis/go-redis/v9"
"github.com/tidwall/buntdb"
"time"
)
var (
jsonMarshal = jsonIterator.Marshal
jsonUnmarshal = jsonIterator.Unmarshal
)
type TokenStore interface {
Create(info TokenInfo) error
RemoveByAccess(access string) error
RemoveByRefresh(refresh string) error
GetByAccess(access string) (TokenInfo, error)
GetByRefresh(refresh string) (TokenInfo, error)
}
// NewMemoryTokenStore create a token buntStore instance based on memory
func NewMemoryTokenStore() (TokenStore, error) {
return NewFileTokenStore(":memory:")
}
// NewFileTokenStore create a token buntStore instance based on file
func NewFileTokenStore(filename string) (TokenStore, error) {
db, err := buntdb.Open(filename)
if err != nil {
return nil, err
}
return &buntStore{db: db}, nil
}
// buntStore token storage based on buntdb(https://github.com/tidwall/buntdb)
type buntStore struct {
db *buntdb.DB
}
func (ts *buntStore) remove(key string) error {
err := ts.db.Update(func(tx *buntdb.Tx) error {
_, err := tx.Delete(key)
return err
})
if errors.Is(err, buntdb.ErrNotFound) {
return nil
}
return err
}
func (ts *buntStore) getData(key string) (TokenInfo, error) {
var ti TokenInfo
err := ts.db.View(func(tx *buntdb.Tx) error {
jv, err := tx.Get(key)
if err != nil {
return err
}
var tm Token
err = jsonUnmarshal([]byte(jv), &tm)
if err != nil {
return err
}
ti = &tm
return nil
})
if err != nil {
if err == buntdb.ErrNotFound {
return nil, nil
}
return nil, err
}
return ti, nil
}
func (ts *buntStore) getBasicID(key string) (string, error) {
var basicID string
err := ts.db.View(func(tx *buntdb.Tx) error {
v, err := tx.Get(key)
if err != nil {
return err
}
basicID = v
return nil
})
if err != nil {
if err == buntdb.ErrNotFound {
return "", nil
}
return "", err
}
return basicID, nil
}
// Create and buntStore the new token information
func (ts *buntStore) Create(info TokenInfo) error {
ct := time.Now()
jv, err := jsonMarshal(info)
if err != nil {
return err
}
return ts.db.Update(func(tx *buntdb.Tx) error {
basicID := uuid.Must(uuid.NewRandom()).String()
aexp := info.GetAccessExpiresIn()
rexp := aexp
expires := true
if refresh := info.GetRefresh(); refresh != "" {
rexp = info.GetRefreshCreateAt().Add(info.GetRefreshExpiresIn()).Sub(ct)
if aexp.Seconds() > rexp.Seconds() {
aexp = rexp
}
expires = info.GetRefreshExpiresIn() != 0
_, _, err = tx.Set(refresh, basicID, &buntdb.SetOptions{Expires: expires, TTL: rexp})
if err != nil {
return err
}
}
_, _, err = tx.Set(basicID, string(jv), &buntdb.SetOptions{Expires: expires, TTL: rexp})
if err != nil {
return err
}
_, _, err = tx.Set(info.GetAccess(), basicID, &buntdb.SetOptions{Expires: expires, TTL: aexp})
return err
})
}
// RemoveByAccess use the access token to delete the token information
func (ts *buntStore) RemoveByAccess(access string) error {
return ts.remove(access)
}
// RemoveByRefresh use the refresh token to delete the token information
func (ts *buntStore) RemoveByRefresh(refresh string) error {
return ts.remove(refresh)
}
// GetByAccess use the access token for token information data
func (ts *buntStore) GetByAccess(access string) (TokenInfo, error) {
basicID, err := ts.getBasicID(access)
if err != nil {
return nil, err
}
return ts.getData(basicID)
}
// GetByRefresh use the refresh token for token information data
func (ts *buntStore) GetByRefresh(refresh string) (TokenInfo, error) {
basicID, err := ts.getBasicID(refresh)
if err != nil {
return nil, err
}
return ts.getData(basicID)
}
/*------------------------------------------------------------------------------------*/
// NewRedisStoreWithCli create an instance of a redis store
func NewRedisStoreWithCli(cli *redis.Client, keyNamespace string) TokenStore {
store := &redisStore{
cli: cli,
ctx: context.TODO(),
ns: keyNamespace,
}
return store
}
// TokenStore redis token store
type redisStore struct {
cli *redis.Client
ctx context.Context
ns string
}
func (s *redisStore) wrapperKey(key string) string {
return fmt.Sprintf("%s%s", s.ns, key)
}
func (s *redisStore) checkError(result redis.Cmder) (bool, error) {
if err := result.Err(); err != nil {
if err == redis.Nil {
return true, nil
}
return false, err
}
return false, nil
}
func (s *redisStore) remove(key string) error {
result := s.cli.Del(s.ctx, s.wrapperKey(key))
_, err := s.checkError(result)
return err
}
func (s *redisStore) removeToken(tokenString string, isRefresh bool) error {
basicID, err := s.getBasicID(tokenString)
if err != nil {
return err
} else if basicID == "" {
return nil
}
err = s.remove(tokenString)
if err != nil {
return err
}
token, err := s.getToken(basicID)
if err != nil {
return err
} else if token == nil {
return nil
}
checkToken := token.GetRefresh()
if isRefresh {
checkToken = token.GetAccess()
}
result := s.cli.Exists(s.ctx, s.wrapperKey(checkToken))
if err = result.Err(); err != nil && err != redis.Nil {
return err
} else if result.Val() == 0 {
return s.remove(basicID)
}
return nil
}
func (s *redisStore) parseToken(result *redis.StringCmd) (TokenInfo, error) {
if ok, err := s.checkError(result); err != nil {
return nil, err
} else if ok {
return nil, nil
}
buf, err := result.Bytes()
if err != nil {
if err == redis.Nil {
return nil, nil
}
return nil, err
}
var token Token
if err = jsonUnmarshal(buf, &token); err != nil {
return nil, err
}
return &token, nil
}
func (s *redisStore) getToken(key string) (TokenInfo, error) {
result := s.cli.Get(s.ctx, s.wrapperKey(key))
return s.parseToken(result)
}
func (s *redisStore) parseBasicID(result *redis.StringCmd) (string, error) {
if ok, err := s.checkError(result); err != nil {
return "", err
} else if ok {
return "", nil
}
return result.Val(), nil
}
func (s *redisStore) getBasicID(token string) (string, error) {
result := s.cli.Get(s.ctx, s.wrapperKey(token))
return s.parseBasicID(result)
}
// Create and store the new token information
func (s *redisStore) Create(info TokenInfo) error {
ct := time.Now()
jv, err := jsonMarshal(info)
if err != nil {
return err
}
pipe := s.cli.TxPipeline()
basicID := uuid.Must(uuid.NewRandom()).String()
aexp := info.GetAccessExpiresIn()
rexp := aexp
if refresh := info.GetRefresh(); refresh != "" {
rexp = info.GetRefreshCreateAt().Add(info.GetRefreshExpiresIn()).Sub(ct)
if aexp.Seconds() > rexp.Seconds() {
aexp = rexp
}
pipe.Set(s.ctx, s.wrapperKey(refresh), basicID, rexp)
}
pipe.Set(s.ctx, s.wrapperKey(info.GetAccess()), basicID, aexp)
pipe.Set(s.ctx, s.wrapperKey(basicID), jv, rexp)
if _, err = pipe.Exec(s.ctx); err != nil {
return err
}
return nil
}
// RemoveByAccess Use the access token to delete the token information
func (s *redisStore) RemoveByAccess(access string) error {
return s.removeToken(access, false)
}
// RemoveByRefresh Use the refresh token to delete the token information
func (s *redisStore) RemoveByRefresh(refresh string) error {
return s.removeToken(refresh, true)
}
// GetByAccess Use the access token for token information data
func (s *redisStore) GetByAccess(access string) (TokenInfo, error) {
basicID, err := s.getBasicID(access)
if err != nil || basicID == "" {
return nil, err
}
return s.getToken(basicID)
}
// GetByRefresh Use the refresh token for token information data
func (s *redisStore) GetByRefresh(refresh string) (TokenInfo, error) {
basicID, err := s.getBasicID(refresh)
if err != nil || basicID == "" {
return nil, err
}
return s.getToken(basicID)
}

118
pkg/auth/token.go Normal file
View File

@ -0,0 +1,118 @@
package auth
import (
"time"
)
// TokenInfo the token information model interface
type TokenInfo interface {
New() TokenInfo
GetUserID() string
SetUserID(string)
GetAccess() string
SetAccess(string)
GetAccessCreateAt() time.Time
SetAccessCreateAt(time.Time)
GetAccessExpiresIn() time.Duration
SetAccessExpiresIn(time.Duration)
GetRefresh() string
SetRefresh(string)
GetRefreshCreateAt() time.Time
SetRefreshCreateAt(time.Time)
GetRefreshExpiresIn() time.Duration
SetRefreshExpiresIn(time.Duration)
}
// NewToken create to token model instance
func NewToken() *Token {
return &Token{}
}
// Token token model
type Token struct {
UserID string `bson:"UserID"`
Access string `bson:"Access"`
AccessCreateAt time.Time `bson:"AccessCreateAt"`
AccessExpiresIn time.Duration `bson:"AccessExpiresIn"`
Refresh string `bson:"Refresh"`
RefreshCreateAt time.Time `bson:"RefreshCreateAt"`
RefreshExpiresIn time.Duration `bson:"RefreshExpiresIn"`
}
// New create to token model instance
func (t *Token) New() TokenInfo {
return NewToken()
}
// GetUserID the user id
func (t *Token) GetUserID() string {
return t.UserID
}
// SetUserID the user id
func (t *Token) SetUserID(userID string) {
t.UserID = userID
}
// GetAccess access Token
func (t *Token) GetAccess() string {
return t.Access
}
// SetAccess access Token
func (t *Token) SetAccess(access string) {
t.Access = access
}
// GetAccessCreateAt create Time
func (t *Token) GetAccessCreateAt() time.Time {
return t.AccessCreateAt
}
// SetAccessCreateAt create Time
func (t *Token) SetAccessCreateAt(createAt time.Time) {
t.AccessCreateAt = createAt
}
// GetAccessExpiresIn the lifetime in seconds of the access token
func (t *Token) GetAccessExpiresIn() time.Duration {
return t.AccessExpiresIn
}
// SetAccessExpiresIn the lifetime in seconds of the access token
func (t *Token) SetAccessExpiresIn(exp time.Duration) {
t.AccessExpiresIn = exp
}
// GetRefresh refresh Token
func (t *Token) GetRefresh() string {
return t.Refresh
}
// SetRefresh refresh Token
func (t *Token) SetRefresh(refresh string) {
t.Refresh = refresh
}
// GetRefreshCreateAt create Time
func (t *Token) GetRefreshCreateAt() time.Time {
return t.RefreshCreateAt
}
// SetRefreshCreateAt create Time
func (t *Token) SetRefreshCreateAt(createAt time.Time) {
t.RefreshCreateAt = createAt
}
// GetRefreshExpiresIn the lifetime in seconds of the refresh token
func (t *Token) GetRefreshExpiresIn() time.Duration {
return t.RefreshExpiresIn
}
// SetRefreshExpiresIn the lifetime in seconds of the refresh token
func (t *Token) SetRefreshExpiresIn(exp time.Duration) {
t.RefreshExpiresIn = exp
}