base-golang/pkg/auth/manager.go

195 lines
4.5 KiB
Go
Raw Normal View History

2024-07-23 10:23:43 +08:00
package auth
import (
"time"
)
// NewManager returns a Manager that produces token strings with ag and
// persists token information in ts.
//
// The manager starts with DefaultAccessTokenCfg and DefaultRefreshTokenCfg.
// These are shared by pointer, not copied, so every manager created here
// references the same default values until SetConfig or
// SetRefreshTokenConfig installs different ones.
func NewManager(ag AccessGenerate, ts TokenStore) *Manager {
	mgr := new(Manager)
	mgr.cfg = DefaultAccessTokenCfg
	mgr.rCfg = DefaultRefreshTokenCfg
	mgr.accessGenerate = ag
	mgr.tokenStore = ts
	return mgr
}
// SetConfig replaces the configuration used when issuing and refreshing
// access tokens: the access/refresh lifetimes and whether a refresh token is
// generated.
//
// The pointer is stored as-is (not copied or validated). It must not be nil,
// because GenerateAccessToken and RefreshAccessToken dereference it.
func (m *Manager) SetConfig(cfg *Config) {
	m.cfg = cfg
}
// SetRefreshTokenConfig replaces the settings RefreshAccessToken consults:
// whether to reset the refresh creation time, and whether to remove the old
// access and/or refresh token after a refresh.
//
// The pointer is stored as-is (not copied or validated). It must not be nil,
// because RefreshAccessToken dereferences it.
func (m *Manager) SetRefreshTokenConfig(rc *RefreshConfig) {
	m.rCfg = rc
}
// Manager issues, refreshes, loads and revokes tokens. It delegates
// token-string generation to an AccessGenerate and persistence to a
// TokenStore.
//
// Construct it with NewManager. The zero value has nil configs, and the
// token-issuing methods dereference them.
type Manager struct {
	cfg            *Config        // lifetimes and refresh-generation flag for issued tokens
	rCfg           *RefreshConfig // reset/removal behaviour applied by RefreshAccessToken
	accessGenerate AccessGenerate // produces the access and refresh token strings
	tokenStore     TokenStore     // persists and looks up TokenInfo records
}
// GenerateAccessToken issues a new token for userID, stores it and returns it.
//
// The access creation time is set to now and its lifetime to
// Config.AccessTokenExp. When Config.IsGenerateRefresh is set, the refresh
// creation time and Config.RefreshTokenExp lifetime are recorded as well, and
// the generator is asked for a refresh token. A refresh value is only set on
// the TokenInfo if the generator returns a non-empty one.
//
// Any error from the generator or the store is returned with a nil TokenInfo.
func (m *Manager) GenerateAccessToken(userID string) (TokenInfo, error) {
	now := time.Now()
	withRefresh := m.cfg.IsGenerateRefresh

	info := NewToken()
	info.SetUserID(userID)
	info.SetAccessCreateAt(now)
	info.SetAccessExpiresIn(m.cfg.AccessTokenExp)
	if withRefresh {
		info.SetRefreshCreateAt(now)
		info.SetRefreshExpiresIn(m.cfg.RefreshTokenExp)
	}

	access, refresh, err := m.accessGenerate.Token(&GenerateBasic{
		UserID:    userID,
		CreateAt:  now,
		TokenInfo: info,
	}, withRefresh)
	if err != nil {
		return nil, err
	}

	info.SetAccess(access)
	if refresh != "" {
		info.SetRefresh(refresh)
	}

	if err := m.tokenStore.Create(info); err != nil {
		return nil, err
	}
	return info, nil
}
// RefreshAccessToken exchanges a valid refresh token for a new access token.
//
// The refresh token is first validated with LoadRefreshToken. The loaded
// TokenInfo is then updated in place:
//   - its access creation time is reset to now;
//   - the access and refresh lifetimes are overwritten with the positive
//     values from the Config (non-positive values leave them untouched);
//   - the refresh creation time is reset too when
//     RefreshConfig.IsResetRefreshTime is set.
//
// New token strings are generated and the updated record is passed to
// TokenStore.Create. Only after that are the previous access token
// (IsRemoveAccess) and the previous refresh token (IsRemoveRefreshing, and
// only if a new refresh token was issued) removed.
//
// If the generator returns no new refresh token, the returned TokenInfo has
// its refresh value cleared and its refresh lifetime set to zero. This happens
// after Create. NOTE(review): whether the persisted record is affected depends
// on whether the TokenStore keeps a reference to the TokenInfo. Confirm
// against the store implementation.
func (m *Manager) RefreshAccessToken(refresh string) (TokenInfo, error) {
	info, err := m.LoadRefreshToken(refresh)
	if err != nil {
		return nil, err
	}
	// Remember the current values before they are overwritten, so they can be
	// revoked once the replacement has been stored.
	prevAccess, prevRefresh := info.GetAccess(), info.GetRefresh()

	basic := &GenerateBasic{
		UserID:    info.GetUserID(),
		CreateAt:  time.Now(),
		TokenInfo: info,
	}
	now := basic.CreateAt

	info.SetAccessCreateAt(now)
	if exp := m.cfg.AccessTokenExp; exp > 0 {
		info.SetAccessExpiresIn(exp)
	}
	if exp := m.cfg.RefreshTokenExp; exp > 0 {
		info.SetRefreshExpiresIn(exp)
	}
	if m.rCfg.IsResetRefreshTime {
		info.SetRefreshCreateAt(now)
	}

	newAccess, newRefresh, err := m.accessGenerate.Token(basic, m.cfg.IsGenerateRefresh)
	if err != nil {
		return nil, err
	}
	info.SetAccess(newAccess)
	rotated := newRefresh != ""
	if rotated {
		info.SetRefresh(newRefresh)
	}

	// Persist the replacement before revoking anything old.
	if err := m.tokenStore.Create(info); err != nil {
		return nil, err
	}
	if m.rCfg.IsRemoveAccess {
		if err := m.tokenStore.RemoveByAccess(prevAccess); err != nil {
			return nil, err
		}
	}
	// The old refresh token is only revoked when a new one was issued.
	if m.rCfg.IsRemoveRefreshing && rotated {
		if err := m.tokenStore.RemoveByRefresh(prevRefresh); err != nil {
			return nil, err
		}
	}

	if !rotated {
		info.SetRefresh("")
		info.SetRefreshCreateAt(time.Now())
		info.SetRefreshExpiresIn(0)
	}
	return info, nil
}
// RemoveAccessToken deletes the token information associated with the given
// access token.
//
// An empty access token is rejected with ErrInvalidAccessToken without
// touching the store. Otherwise the store's result is returned unchanged.
func (m *Manager) RemoveAccessToken(access string) error {
	if len(access) == 0 {
		return ErrInvalidAccessToken
	}
	return m.tokenStore.RemoveByAccess(access)
}
// RemoveRefreshToken deletes the token information associated with the given
// refresh token.
//
// An empty refresh token is rejected with ErrInvalidRefreshToken without
// touching the store. This is the same sentinel LoadRefreshToken returns for
// that input; previously this method returned ErrInvalidAccessToken, which
// misreported which token was invalid. Otherwise the store's result is
// returned unchanged.
func (m *Manager) RemoveRefreshToken(refresh string) error {
	if refresh == "" {
		return ErrInvalidRefreshToken
	}
	return m.tokenStore.RemoveByRefresh(refresh)
}
// LoadAccessToken looks up the token information bound to an access token and
// checks that it is still usable. It returns:
//   - ErrInvalidAccessToken when access is empty, the store returns no record,
//     or the stored access value differs from access;
//   - ErrExpiredRefreshToken when the record carries a refresh token whose
//     non-zero lifetime has elapsed (checked before the access lifetime);
//   - ErrExpiredAccessToken when the access token's non-zero lifetime has
//     elapsed.
//
// A zero lifetime skips the corresponding expiry check. Store errors are
// returned unchanged.
func (m *Manager) LoadAccessToken(access string) (TokenInfo, error) {
	if access == "" {
		return nil, ErrInvalidAccessToken
	}

	now := time.Now()
	info, err := m.tokenStore.GetByAccess(access)
	switch {
	case err != nil:
		return nil, err
	case info == nil || info.GetAccess() != access:
		return nil, ErrInvalidAccessToken
	case info.GetRefresh() != "" && info.GetRefreshExpiresIn() != 0 &&
		info.GetRefreshCreateAt().Add(info.GetRefreshExpiresIn()).Before(now):
		return nil, ErrExpiredRefreshToken
	case info.GetAccessExpiresIn() != 0 &&
		info.GetAccessCreateAt().Add(info.GetAccessExpiresIn()).Before(now):
		return nil, ErrExpiredAccessToken
	}
	return info, nil
}
// LoadRefreshToken looks up the token information bound to a refresh token
// and checks that it is still usable. It returns:
//   - ErrInvalidRefreshToken when refresh is empty, the store returns no
//     record, or the stored refresh value differs from refresh;
//   - ErrExpiredRefreshToken when the refresh token's lifetime has elapsed.
//
// A zero refresh lifetime means the token never expires. Store errors are
// returned unchanged.
func (m *Manager) LoadRefreshToken(refresh string) (TokenInfo, error) {
	if refresh == "" {
		return nil, ErrInvalidRefreshToken
	}

	info, err := m.tokenStore.GetByRefresh(refresh)
	switch {
	case err != nil:
		return nil, err
	case info == nil || info.GetRefresh() != refresh:
		return nil, ErrInvalidRefreshToken
	case info.GetRefreshExpiresIn() != 0 && // zero lifetime: never expires
		info.GetRefreshCreateAt().Add(info.GetRefreshExpiresIn()).Before(time.Now()):
		return nil, ErrExpiredRefreshToken
	default:
		return info, nil
	}
}