195 lines
4.5 KiB
Go
195 lines
4.5 KiB
Go
|
package auth
|
||
|
|
||
|
import (
|
||
|
"time"
|
||
|
)
|
||
|
|
||
|
// NewManager create to authorization management instance
|
||
|
func NewManager(ag AccessGenerate, ts TokenStore) *Manager {
|
||
|
return &Manager{
|
||
|
cfg: DefaultAccessTokenCfg,
|
||
|
rCfg: DefaultRefreshTokenCfg,
|
||
|
accessGenerate: ag,
|
||
|
tokenStore: ts,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// SetConfig mapping the access token generate config
|
||
|
func (m *Manager) SetConfig(cfg *Config) {
|
||
|
m.cfg = cfg
|
||
|
}
|
||
|
|
||
|
// SetRefreshTokenConfig mapping the token refresh config
|
||
|
func (m *Manager) SetRefreshTokenConfig(store *RefreshConfig) {
|
||
|
m.rCfg = store
|
||
|
}
|
||
|
|
||
|
// Manager provide authorization management
|
||
|
type Manager struct {
|
||
|
cfg *Config
|
||
|
rCfg *RefreshConfig
|
||
|
accessGenerate AccessGenerate
|
||
|
tokenStore TokenStore
|
||
|
}
|
||
|
|
||
|
// GenerateAccessToken generate the access token
|
||
|
func (m *Manager) GenerateAccessToken(userID string) (TokenInfo, error) {
|
||
|
ti := NewToken()
|
||
|
ti.SetUserID(userID)
|
||
|
|
||
|
createAt := time.Now()
|
||
|
ti.SetAccessCreateAt(createAt)
|
||
|
|
||
|
// set access token expires
|
||
|
ti.SetAccessExpiresIn(m.cfg.AccessTokenExp)
|
||
|
if m.cfg.IsGenerateRefresh {
|
||
|
ti.SetRefreshCreateAt(createAt)
|
||
|
ti.SetRefreshExpiresIn(m.cfg.RefreshTokenExp)
|
||
|
}
|
||
|
|
||
|
td := &GenerateBasic{
|
||
|
UserID: userID,
|
||
|
CreateAt: createAt,
|
||
|
TokenInfo: ti,
|
||
|
}
|
||
|
|
||
|
av, rv, err := m.accessGenerate.Token(td, m.cfg.IsGenerateRefresh)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
ti.SetAccess(av)
|
||
|
|
||
|
if rv != "" {
|
||
|
ti.SetRefresh(rv)
|
||
|
}
|
||
|
|
||
|
err = m.tokenStore.Create(ti)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
return ti, nil
|
||
|
}
|
||
|
|
||
|
// RefreshAccessToken refreshing an access token
|
||
|
func (m *Manager) RefreshAccessToken(refresh string) (TokenInfo, error) {
|
||
|
ti, err := m.LoadRefreshToken(refresh)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
oldAccess, oldRefresh := ti.GetAccess(), ti.GetRefresh()
|
||
|
|
||
|
td := &GenerateBasic{
|
||
|
UserID: ti.GetUserID(),
|
||
|
CreateAt: time.Now(),
|
||
|
TokenInfo: ti,
|
||
|
}
|
||
|
|
||
|
ti.SetAccessCreateAt(td.CreateAt)
|
||
|
if v := m.cfg.AccessTokenExp; v > 0 {
|
||
|
ti.SetAccessExpiresIn(v)
|
||
|
}
|
||
|
|
||
|
if v := m.cfg.RefreshTokenExp; v > 0 {
|
||
|
ti.SetRefreshExpiresIn(v)
|
||
|
}
|
||
|
|
||
|
if m.rCfg.IsResetRefreshTime {
|
||
|
ti.SetRefreshCreateAt(td.CreateAt)
|
||
|
}
|
||
|
|
||
|
tv, rv, err := m.accessGenerate.Token(td, m.cfg.IsGenerateRefresh)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
ti.SetAccess(tv)
|
||
|
if rv != "" {
|
||
|
ti.SetRefresh(rv)
|
||
|
}
|
||
|
|
||
|
if err = m.tokenStore.Create(ti); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
if m.rCfg.IsRemoveAccess {
|
||
|
// remove the old access token
|
||
|
if err = m.tokenStore.RemoveByAccess(oldAccess); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if m.rCfg.IsRemoveRefreshing && rv != "" {
|
||
|
// remove the old refresh token
|
||
|
if err = m.tokenStore.RemoveByRefresh(oldRefresh); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if rv == "" {
|
||
|
ti.SetRefresh("")
|
||
|
ti.SetRefreshCreateAt(time.Now())
|
||
|
ti.SetRefreshExpiresIn(0)
|
||
|
}
|
||
|
|
||
|
return ti, nil
|
||
|
}
|
||
|
|
||
|
// RemoveAccessToken use the access token to delete the token information
|
||
|
func (m *Manager) RemoveAccessToken(access string) error {
|
||
|
if access == "" {
|
||
|
return ErrInvalidAccessToken
|
||
|
}
|
||
|
return m.tokenStore.RemoveByAccess(access)
|
||
|
}
|
||
|
|
||
|
// RemoveRefreshToken use the refresh token to delete the token information
|
||
|
func (m *Manager) RemoveRefreshToken(refresh string) error {
|
||
|
if refresh == "" {
|
||
|
return ErrInvalidAccessToken
|
||
|
}
|
||
|
return m.tokenStore.RemoveByRefresh(refresh)
|
||
|
}
|
||
|
|
||
|
// LoadAccessToken according to the access token for corresponding token information
|
||
|
func (m *Manager) LoadAccessToken(access string) (TokenInfo, error) {
|
||
|
if access == "" {
|
||
|
return nil, ErrInvalidAccessToken
|
||
|
}
|
||
|
|
||
|
ct := time.Now()
|
||
|
ti, err := m.tokenStore.GetByAccess(access)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
} else if ti == nil || ti.GetAccess() != access {
|
||
|
return nil, ErrInvalidAccessToken
|
||
|
} else if ti.GetRefresh() != "" && ti.GetRefreshExpiresIn() != 0 &&
|
||
|
ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(ct) {
|
||
|
return nil, ErrExpiredRefreshToken
|
||
|
} else if ti.GetAccessExpiresIn() != 0 &&
|
||
|
ti.GetAccessCreateAt().Add(ti.GetAccessExpiresIn()).Before(ct) {
|
||
|
return nil, ErrExpiredAccessToken
|
||
|
}
|
||
|
return ti, nil
|
||
|
}
|
||
|
|
||
|
// LoadRefreshToken according to the refresh token for corresponding token information
|
||
|
func (m *Manager) LoadRefreshToken(refresh string) (TokenInfo, error) {
|
||
|
if refresh == "" {
|
||
|
return nil, ErrInvalidRefreshToken
|
||
|
}
|
||
|
|
||
|
ti, err := m.tokenStore.GetByRefresh(refresh)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
} else if ti == nil || ti.GetRefresh() != refresh {
|
||
|
return nil, ErrInvalidRefreshToken
|
||
|
} else if ti.GetRefreshExpiresIn() != 0 && // refresh token set to not expire
|
||
|
ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(time.Now()) {
|
||
|
return nil, ErrExpiredRefreshToken
|
||
|
}
|
||
|
|
||
|
return ti, nil
|
||
|
}
|