first commit

This commit is contained in:
bvbej 2024-07-23 10:23:43 +08:00
commit 7b4c2521a3
126 changed files with 15931 additions and 0 deletions

5
.gitignore vendored Normal file
View File

@ -0,0 +1,5 @@
/.idea
/vendor
*.yaml.json
*_test.go

4
README.md Normal file
View File

@ -0,0 +1,4 @@
## Golang公共包
* pb协议文件生成命令
`protoc --proto_path=pkg/websocket/codec/protobuf/protocol --go_out=pkg/websocket/codec/protobuf/protocol --go_opt=paths=source_relative base.proto`

119
go.mod Normal file
View File

@ -0,0 +1,119 @@
module git.bvbej.com/bvbej/base-golang
go 1.22.4
require (
github.com/apolloconfig/agollo/v4 v4.4.0
github.com/gin-contrib/pprof v1.5.0
github.com/gin-gonic/gin v1.10.0
github.com/go-playground/validator/v10 v10.22.0
github.com/golang-jwt/jwt/v4 v4.5.0
github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.3
github.com/jinzhu/now v1.1.5
github.com/json-iterator/go v1.1.12
github.com/mojocn/base64Captcha v1.3.6
github.com/mritd/chinaid v1.0.4
github.com/panjf2000/ants/v2 v2.10.0
github.com/prometheus/client_golang v1.19.1
github.com/qiniu/go-sdk/v7 v7.21.1
github.com/redis/go-redis/v9 v9.5.3
github.com/robfig/cron/v3 v3.0.1
github.com/rs/cors v1.11.0
github.com/rs/cors/wrapper/gin v0.0.0-20240515105523-1562b1715b35
github.com/speps/go-hashids v2.0.0+incompatible
github.com/spf13/cast v1.6.0
github.com/spf13/viper v1.19.0
github.com/stretchr/testify v1.9.0
github.com/tidwall/buntdb v1.3.1
github.com/tidwall/gjson v1.17.1
github.com/tus/tusd v1.13.0
github.com/xuri/excelize/v2 v2.8.1
go.mongodb.org/mongo-driver v1.15.1
go.uber.org/atomic v1.11.0
go.uber.org/multierr v1.11.0
go.uber.org/zap v1.27.0
golang.org/x/crypto v0.24.0
golang.org/x/net v0.26.0
golang.org/x/sync v0.7.0
golang.org/x/time v0.5.0
google.golang.org/grpc v1.64.0
google.golang.org/protobuf v1.34.2
gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df
gopkg.in/natefinch/lumberjack.v2 v2.2.1
gorm.io/driver/mysql v1.5.7
gorm.io/gorm v1.25.10
)
require (
github.com/BurntSushi/toml v1.3.2 // indirect
github.com/alex-ant/gomath v0.0.0-20160516115720-89013a210a82 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/bmizerany/pat v0.0.0-20170815010413-6226ea591a40 // indirect
github.com/bytedance/sonic v1.11.6 // indirect
github.com/bytedance/sonic/loader v0.1.1 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/cloudwego/base64x v0.1.4 // indirect
github.com/cloudwego/iasm v0.2.0 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-sql-driver/mysql v1.7.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/gofrs/flock v0.8.1 // indirect
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
github.com/golang/snappy v0.0.4 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/klauspost/compress v1.17.2 // indirect
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/magiconair/properties v1.8.7 // indirect
github.com/matishsiao/goInfo v0.0.0-20210923090445-da2e3fa8d45f // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect
github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/prometheus/client_model v0.5.0 // indirect
github.com/prometheus/common v0.48.0 // indirect
github.com/prometheus/procfs v0.12.0 // indirect
github.com/richardlehane/mscfb v1.0.4 // indirect
github.com/richardlehane/msoleps v1.0.3 // indirect
github.com/sagikazarmark/locafero v0.4.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect
github.com/spf13/afero v1.11.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
github.com/tidwall/btree v1.4.2 // indirect
github.com/tidwall/grect v0.1.4 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/tidwall/rtred v0.1.2 // indirect
github.com/tidwall/tinyqueue v0.1.1 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
github.com/xdg-go/pbkdf2 v1.0.0 // indirect
github.com/xdg-go/scram v1.1.2 // indirect
github.com/xdg-go/stringprep v1.0.4 // indirect
github.com/xuri/efp v0.0.0-20231025114914-d1ff6096ae53 // indirect
github.com/xuri/nfp v0.0.0-20230919160717-d98342af3f05 // indirect
github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect
golang.org/x/image v0.14.0 // indirect
golang.org/x/sys v0.21.0 // indirect
golang.org/x/text v0.16.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240318140521-94a12d6c2237 // indirect
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

2248
go.sum Normal file

File diff suppressed because it is too large Load Diff

120
pkg/aes/aes.go Normal file
View File

@ -0,0 +1,120 @@
package aes
import (
cryptoAes "crypto/aes"
"crypto/cipher"
"encoding/base64"
)
var _ Aes = (*aes)(nil)
type Aes interface {
i()
EncryptCBC(encryptStr string, urlEncode bool) (string, error)
DecryptCBC(decryptStr string, urlEncode bool) (string, error)
EncryptCFB(plain string, urlEncode bool) (string, error)
DecryptCFB(encrypted string, urlEncode bool) (string, error)
}
type aes struct {
key string
iv string
}
func New(key, iv string) Aes {
return &aes{
key: key,
iv: iv,
}
}
func (a *aes) i() {}
func (a *aes) EncryptCBC(encryptStr string, urlEncode bool) (string, error) {
encoder := base64.StdEncoding
if urlEncode {
encoder = base64.URLEncoding
}
encryptBytes := []byte(encryptStr)
block, err := cryptoAes.NewCipher([]byte(a.key))
if err != nil {
return "", err
}
blockSize := block.BlockSize()
encryptBytes = pkcsPadding(encryptBytes, blockSize)
blockMode := cipher.NewCBCEncrypter(block, []byte(a.iv))
encrypted := make([]byte, len(encryptBytes))
blockMode.CryptBlocks(encrypted, encryptBytes)
return encoder.EncodeToString(encrypted), nil
}
func (a *aes) DecryptCBC(decryptStr string, urlEncode bool) (string, error) {
encoder := base64.StdEncoding
if urlEncode {
encoder = base64.URLEncoding
}
decryptBytes, err := encoder.DecodeString(decryptStr)
if err != nil {
return "", err
}
block, err := cryptoAes.NewCipher([]byte(a.key))
if err != nil {
return "", err
}
blockMode := cipher.NewCBCDecrypter(block, []byte(a.iv))
decrypted := make([]byte, len(decryptBytes))
blockMode.CryptBlocks(decrypted, decryptBytes)
decrypted = pkcsUnPadding(decrypted)
return string(decrypted), nil
}
func (a *aes) EncryptCFB(plain string, urlEncode bool) (string, error) {
encoder := base64.StdEncoding
if urlEncode {
encoder = base64.URLEncoding
}
block, err := cryptoAes.NewCipher([]byte(a.key))
if err != nil {
return "", err
}
encrypted := make([]byte, len(plain))
stream := cipher.NewCFBEncrypter(block, []byte(a.iv))
stream.XORKeyStream(encrypted, []byte(plain))
return encoder.EncodeToString(encrypted), nil
}
func (a *aes) DecryptCFB(encrypted string, urlEncode bool) (string, error) {
encoder := base64.StdEncoding
if urlEncode {
encoder = base64.URLEncoding
}
decryptBytes, err := encoder.DecodeString(encrypted)
if err != nil {
return "", err
}
block, err := cryptoAes.NewCipher([]byte(a.key))
if err != nil {
return "", err
}
plain := make([]byte, len(decryptBytes))
stream := cipher.NewCFBDecrypter(block, []byte(a.iv))
stream.XORKeyStream(plain, decryptBytes)
return string(plain), nil
}

25
pkg/aes/padding.go Normal file
View File

@ -0,0 +1,25 @@
package aes
import (
"bytes"
)
/**
* 填充有六种NoPadding, PKCS#5, PKCS#7, ISO 10126, ANSI X9.23和ZerosPadding
* 对于AES来说PKCS5Padding和PKCS7Padding是完全一样的不同在于PKCS5限定了块大小为8bytes而PKCS7没有限定
* 因此对于AES来说两者完全相同但是对于Rijndael就不一样了
* AES是Rijndael在块大小为8bytes时的特例对于使用其他信息块大小的Rijndael算法只能使用PKCS7
* 在AES加密当中严格来说是不能使用pkcs5的因为AES的块大小是16bytes而pkcs5只能用于8bytes通常我们在AES加密中所说的pkcs5指的就是pkcs7
*/
func pkcsPadding(cipherText []byte, blockSize int) []byte {
p := blockSize - len(cipherText)%blockSize
padText := bytes.Repeat([]byte{byte(p)}, p)
return append(cipherText, padText...)
}
func pkcsUnPadding(decrypted []byte) []byte {
length := len(decrypted)
u := int(decrypted[length-1])
return decrypted[:(length - u)]
}

View File

@ -0,0 +1,196 @@
package apk
import (
"archive/zip"
"bytes"
"fmt"
"git.bvbej.com/bvbej/base-golang/pkg/android_binary"
"image"
"io"
"os"
"strconv"
_ "image/jpeg" // handle jpeg format
_ "image/png" // handle png format
)
// Apk is an application package file for android.
type Apk struct {
f *os.File
zipreader *zip.Reader
manifest Manifest
table *android_binary.TableFile
}
// OpenFile will open the file specified by filename and return Apk
func OpenFile(filename string) (apk *Apk, err error) {
f, err := os.Open(filename)
if err != nil {
return nil, err
}
defer func() {
if err != nil {
f.Close()
}
}()
fi, err := f.Stat()
if err != nil {
return nil, err
}
apk, err = OpenZipReader(f, fi.Size())
if err != nil {
return nil, err
}
apk.f = f
return
}
// OpenZipReader has same arguments like zip.NewReader
func OpenZipReader(r io.ReaderAt, size int64) (*Apk, error) {
zipreader, err := zip.NewReader(r, size)
if err != nil {
return nil, err
}
apk := &Apk{
zipreader: zipreader,
}
if err = apk.parseResources(); err != nil {
return nil, err
}
if err = apk.parseManifest(); err != nil {
return nil, errorf("parse-manifest: %w", err)
}
return apk, nil
}
// Close is avaliable only if apk is created with OpenFile
func (k *Apk) Close() error {
if k.f == nil {
return nil
}
return k.f.Close()
}
// Icon returns the icon image of the APK.
func (k *Apk) Icon(resConfig *android_binary.ResTableConfig) (image.Image, error) {
iconPath, err := k.manifest.App.Icon.WithResTableConfig(resConfig).String()
if err != nil {
return nil, err
}
if android_binary.IsResID(iconPath) {
return nil, newError("unable to convert icon-id to icon path")
}
imgData, err := k.readZipFile(iconPath)
if err != nil {
return nil, err
}
m, _, err := image.Decode(bytes.NewReader(imgData))
return m, err
}
// Label returns the label of the APK.
func (k *Apk) Label(resConfig *android_binary.ResTableConfig) (s string, err error) {
s, err = k.manifest.App.Label.WithResTableConfig(resConfig).String()
if err != nil {
return
}
if android_binary.IsResID(s) {
err = newError("unable to convert label-id to string")
}
return
}
// Manifest returns the manifest of the APK.
func (k *Apk) Manifest() Manifest {
return k.manifest
}
// PackageName returns the package name of the APK.
func (k *Apk) PackageName() string {
return k.manifest.Package.MustString()
}
func isMainIntentFilter(intent ActivityIntentFilter) bool {
ok := false
for _, action := range intent.Actions {
s, err := action.Name.String()
if err == nil && s == "android.intent.action.MAIN" {
ok = true
break
}
}
if !ok {
return false
}
ok = false
for _, category := range intent.Categories {
s, err := category.Name.String()
if err == nil && s == "android.intent.category.LAUNCHER" {
ok = true
break
}
}
return ok
}
// MainActivity returns the name of the main activity.
func (k *Apk) MainActivity() (activity string, err error) {
for _, act := range k.manifest.App.Activities {
for _, intent := range act.IntentFilters {
if isMainIntentFilter(intent) {
return act.Name.String()
}
}
}
for _, act := range k.manifest.App.ActivityAliases {
for _, intent := range act.IntentFilters {
if isMainIntentFilter(intent) {
return act.TargetActivity.String()
}
}
}
return "", newError("No main activity found")
}
func (k *Apk) parseManifest() error {
xmlData, err := k.readZipFile("AndroidManifest.xml")
if err != nil {
return errorf("failed to read AndroidManifest.xml: %w", err)
}
xmlfile, err := android_binary.NewXMLFile(bytes.NewReader(xmlData))
if err != nil {
return errorf("failed to parse AndroidManifest.xml: %w", err)
}
return xmlfile.Decode(&k.manifest, k.table, nil)
}
func (k *Apk) parseResources() (err error) {
resData, err := k.readZipFile("resources.arsc")
if err != nil {
return
}
k.table, err = android_binary.NewTableFile(bytes.NewReader(resData))
return
}
func (k *Apk) readZipFile(name string) (data []byte, err error) {
buf := bytes.NewBuffer(nil)
for _, file := range k.zipreader.File {
if file.Name != name {
continue
}
rc, er := file.Open()
if er != nil {
err = er
return
}
defer rc.Close()
_, err = io.Copy(buf, rc)
if err != nil {
return
}
return buf.Bytes(), nil
}
return nil, fmt.Errorf("File %s not found", strconv.Quote(name))
}

View File

@ -0,0 +1,108 @@
package apk
import "git.bvbej.com/bvbej/base-golang/pkg/android_binary"
// Instrumentation is an application instrumentation code.
type Instrumentation struct {
Name android_binary.String `xml:"http://schemas.android.com/apk/res/android name,attr"`
Target android_binary.String `xml:"http://schemas.android.com/apk/res/android targetPackage,attr"`
HandleProfiling android_binary.Bool `xml:"http://schemas.android.com/apk/res/android handleProfiling,attr"`
FunctionalTest android_binary.Bool `xml:"http://schemas.android.com/apk/res/android functionalTest,attr"`
}
// ActivityAction is an action of an activity.
type ActivityAction struct {
Name android_binary.String `xml:"http://schemas.android.com/apk/res/android name,attr"`
}
// ActivityCategory is a category of an activity.
type ActivityCategory struct {
Name android_binary.String `xml:"http://schemas.android.com/apk/res/android name,attr"`
}
// ActivityIntentFilter is an androidbinary.Int32ent filter of an activity.
type ActivityIntentFilter struct {
Actions []ActivityAction `xml:"action"`
Categories []ActivityCategory `xml:"category"`
}
// AppActivity is an activity in an application.
type AppActivity struct {
Theme android_binary.String `xml:"http://schemas.android.com/apk/res/android theme,attr"`
Name android_binary.String `xml:"http://schemas.android.com/apk/res/android name,attr"`
Label android_binary.String `xml:"http://schemas.android.com/apk/res/android label,attr"`
ScreenOrientation android_binary.String `xml:"http://schemas.android.com/apk/res/android screenOrientation,attr"`
IntentFilters []ActivityIntentFilter `xml:"intent-filter"`
}
// AppActivityAlias https://developer.android.com/guide/topics/manifest/activity-alias-element
type AppActivityAlias struct {
Name android_binary.String `xml:"http://schemas.android.com/apk/res/android name,attr"`
Label android_binary.String `xml:"http://schemas.android.com/apk/res/android label,attr"`
TargetActivity android_binary.String `xml:"http://schemas.android.com/apk/res/android targetActivity,attr"`
IntentFilters []ActivityIntentFilter `xml:"intent-filter"`
}
// MetaData is a metadata in an application.
type MetaData struct {
Name android_binary.String `xml:"http://schemas.android.com/apk/res/android name,attr"`
Value android_binary.String `xml:"http://schemas.android.com/apk/res/android value,attr"`
}
// Application is an application in an APK.
type Application struct {
AllowTaskReparenting android_binary.Bool `xml:"http://schemas.android.com/apk/res/android allowTaskReparenting,attr"`
AllowBackup android_binary.Bool `xml:"http://schemas.android.com/apk/res/android allowBackup,attr"`
BackupAgent android_binary.String `xml:"http://schemas.android.com/apk/res/android backupAgent,attr"`
Debuggable android_binary.Bool `xml:"http://schemas.android.com/apk/res/android debuggable,attr"`
Description android_binary.String `xml:"http://schemas.android.com/apk/res/android description,attr"`
Enabled android_binary.Bool `xml:"http://schemas.android.com/apk/res/android enabled,attr"`
HasCode android_binary.Bool `xml:"http://schemas.android.com/apk/res/android hasCode,attr"`
HardwareAccelerated android_binary.Bool `xml:"http://schemas.android.com/apk/res/android hardwareAccelerated,attr"`
Icon android_binary.String `xml:"http://schemas.android.com/apk/res/android icon,attr"`
KillAfterRestore android_binary.Bool `xml:"http://schemas.android.com/apk/res/android killAfterRestore,attr"`
LargeHeap android_binary.Bool `xml:"http://schemas.android.com/apk/res/android largeHeap,attr"`
Label android_binary.String `xml:"http://schemas.android.com/apk/res/android label,attr"`
Logo android_binary.String `xml:"http://schemas.android.com/apk/res/android logo,attr"`
ManageSpaceActivity android_binary.String `xml:"http://schemas.android.com/apk/res/android manageSpaceActivity,attr"`
Name android_binary.String `xml:"http://schemas.android.com/apk/res/android name,attr"`
Permission android_binary.String `xml:"http://schemas.android.com/apk/res/android permission,attr"`
Persistent android_binary.Bool `xml:"http://schemas.android.com/apk/res/android persistent,attr"`
Process android_binary.String `xml:"http://schemas.android.com/apk/res/android process,attr"`
RestoreAnyVersion android_binary.Bool `xml:"http://schemas.android.com/apk/res/android restoreAnyVersion,attr"`
RequiredAccountType android_binary.String `xml:"http://schemas.android.com/apk/res/android requiredAccountType,attr"`
RestrictedAccountType android_binary.String `xml:"http://schemas.android.com/apk/res/android restrictedAccountType,attr"`
SupportsRtl android_binary.Bool `xml:"http://schemas.android.com/apk/res/android supportsRtl,attr"`
TaskAffinity android_binary.String `xml:"http://schemas.android.com/apk/res/android taskAffinity,attr"`
TestOnly android_binary.Bool `xml:"http://schemas.android.com/apk/res/android testOnly,attr"`
Theme android_binary.String `xml:"http://schemas.android.com/apk/res/android theme,attr"`
UIOptions android_binary.String `xml:"http://schemas.android.com/apk/res/android uiOptions,attr"`
VMSafeMode android_binary.Bool `xml:"http://schemas.android.com/apk/res/android vmSafeMode,attr"`
Activities []AppActivity `xml:"activity"`
ActivityAliases []AppActivityAlias `xml:"activity-alias"`
MetaData []MetaData `xml:"meta-data"`
}
// UsesSDK is target SDK version.
type UsesSDK struct {
Min android_binary.Int32 `xml:"http://schemas.android.com/apk/res/android minSdkVersion,attr"`
Target android_binary.Int32 `xml:"http://schemas.android.com/apk/res/android targetSdkVersion,attr"`
Max android_binary.Int32 `xml:"http://schemas.android.com/apk/res/android maxSdkVersion,attr"`
}
// UsesPermission is user grant the system permission.
type UsesPermission struct {
Name android_binary.String `xml:"http://schemas.android.com/apk/res/android name,attr"`
Max android_binary.Int32 `xml:"http://schemas.android.com/apk/res/android maxSdkVersion,attr"`
}
// Manifest is a manifest of an APK.
type Manifest struct {
Package android_binary.String `xml:"package,attr"`
VersionCode android_binary.Int32 `xml:"http://schemas.android.com/apk/res/android versionCode,attr"`
VersionName android_binary.String `xml:"http://schemas.android.com/apk/res/android versionName,attr"`
App Application `xml:"application"`
Instrument Instrumentation `xml:"instrumentation"`
SDK UsesSDK `xml:"uses-sdk"`
UsesPermissions []UsesPermission `xml:"uses-permission"`
}

View File

@ -0,0 +1,9 @@
package apk
import (
"errors"
"fmt"
)
var newError = errors.New
var errorf = fmt.Errorf

View File

@ -0,0 +1,258 @@
package android_binary
import (
"bytes"
"encoding/binary"
"io"
"unicode/utf16"
)
// ChunkType is a type of a resource chunk.
type ChunkType uint16
// Chunk types.
const (
ResNullChunkType ChunkType = 0x0000
ResStringPoolChunkType ChunkType = 0x0001
ResTableChunkType ChunkType = 0x0002
ResXMLChunkType ChunkType = 0x0003
// Chunk types in RES_XML_TYPE
ResXMLFirstChunkType ChunkType = 0x0100
ResXMLStartNamespaceType ChunkType = 0x0100
ResXMLEndNamespaceType ChunkType = 0x0101
ResXMLStartElementType ChunkType = 0x0102
ResXMLEndElementType ChunkType = 0x0103
ResXMLCDataType ChunkType = 0x0104
ResXMLLastChunkType ChunkType = 0x017f
// This contains a uint32_t array mapping strings in the string
// pool back to resource identifiers. It is optional.
ResXMLResourceMapType ChunkType = 0x0180
// Chunk types in RES_TABLE_TYPE
ResTablePackageType ChunkType = 0x0200
ResTableTypeType ChunkType = 0x0201
ResTableTypeSpecType ChunkType = 0x0202
)
// ResChunkHeader is a header of a resource chunk.
type ResChunkHeader struct {
Type ChunkType
HeaderSize uint16
Size uint32
}
// Flags are flags for string pool header.
type Flags uint32
// the values of Flags.
const (
SortedFlag Flags = 1 << 0
UTF8Flag Flags = 1 << 8
)
// ResStringPoolHeader is a chunk header of string pool.
type ResStringPoolHeader struct {
Header ResChunkHeader
StringCount uint32
StyleCount uint32
Flags Flags
StringStart uint32
StylesStart uint32
}
// ResStringPoolSpan is a span of style information associated with
// a string in the pool.
type ResStringPoolSpan struct {
FirstChar, LastChar uint32
}
// ResStringPool is a string pool resource.
type ResStringPool struct {
Header ResStringPoolHeader
Strings []string
Styles []ResStringPoolSpan
}
// NilResStringPoolRef is nil reference for string pool.
const NilResStringPoolRef = ResStringPoolRef(0xFFFFFFFF)
// ResStringPoolRef is a type representing a reference to a string.
type ResStringPoolRef uint32
// DataType is a type of the data value.
type DataType uint8
// The constants for DataType
const (
TypeNull DataType = 0x00
TypeReference DataType = 0x01
TypeAttribute DataType = 0x02
TypeString DataType = 0x03
TypeFloat DataType = 0x04
TypeDemention DataType = 0x05
TypeFraction DataType = 0x06
TypeFirstInt DataType = 0x10
TypeIntDec DataType = 0x10
TypeIntHex DataType = 0x11
TypeIntBoolean DataType = 0x12
TypeFirstColorInt DataType = 0x1c
TypeIntColorARGB8 DataType = 0x1c
TypeIntColorRGB8 DataType = 0x1d
TypeIntColorARGB4 DataType = 0x1e
TypeIntColorRGB4 DataType = 0x1f
TypeLastColorInt DataType = 0x1f
TypeLastInt DataType = 0x1f
)
// ResValue is a representation of a value in a resource
type ResValue struct {
Size uint16
Res0 uint8
DataType DataType
Data uint32
}
// GetString returns a string referenced by ref.
func (pool *ResStringPool) GetString(ref ResStringPoolRef) string {
return pool.Strings[int(ref)]
}
func readStringPool(sr *io.SectionReader) (*ResStringPool, error) {
sp := new(ResStringPool)
if err := binary.Read(sr, binary.LittleEndian, &sp.Header); err != nil {
return nil, err
}
stringStarts := make([]uint32, sp.Header.StringCount)
if err := binary.Read(sr, binary.LittleEndian, stringStarts); err != nil {
return nil, err
}
styleStarts := make([]uint32, sp.Header.StyleCount)
if err := binary.Read(sr, binary.LittleEndian, styleStarts); err != nil {
return nil, err
}
sp.Strings = make([]string, sp.Header.StringCount)
for i, start := range stringStarts {
var str string
var err error
if _, err := sr.Seek(int64(sp.Header.StringStart+start), io.SeekStart); err != nil {
return nil, err
}
if (sp.Header.Flags & UTF8Flag) == 0 {
str, err = readUTF16(sr)
} else {
str, err = readUTF8(sr)
}
if err != nil {
return nil, err
}
sp.Strings[i] = str
}
sp.Styles = make([]ResStringPoolSpan, sp.Header.StyleCount)
for i, start := range styleStarts {
if _, err := sr.Seek(int64(sp.Header.StylesStart+start), io.SeekStart); err != nil {
return nil, err
}
if err := binary.Read(sr, binary.LittleEndian, &sp.Styles[i]); err != nil {
return nil, err
}
}
return sp, nil
}
func readUTF16(sr *io.SectionReader) (string, error) {
// read length of string
size, err := readUTF16length(sr)
if err != nil {
return "", err
}
// read string value
buf := make([]uint16, size)
if err := binary.Read(sr, binary.LittleEndian, buf); err != nil {
return "", err
}
return string(utf16.Decode(buf)), nil
}
func readUTF16length(sr *io.SectionReader) (int, error) {
var size int
var first, second uint16
if err := binary.Read(sr, binary.LittleEndian, &first); err != nil {
return 0, err
}
if (first & 0x8000) != 0 {
if err := binary.Read(sr, binary.LittleEndian, &second); err != nil {
return 0, err
}
size = (int(first&0x7FFF) << 16) + int(second)
} else {
size = int(first)
}
return size, nil
}
func readUTF8(sr *io.SectionReader) (string, error) {
// skip utf16 length
_, err := readUTF8length(sr)
if err != nil {
return "", err
}
// read lenth of string
size, err := readUTF8length(sr)
if err != nil {
return "", err
}
buf := make([]uint8, size)
if err := binary.Read(sr, binary.LittleEndian, buf); err != nil {
return "", err
}
return string(buf), nil
}
func readUTF8length(sr *io.SectionReader) (int, error) {
var size int
var first, second uint8
if err := binary.Read(sr, binary.LittleEndian, &first); err != nil {
return 0, err
}
if (first & 0x80) != 0 {
if err := binary.Read(sr, binary.LittleEndian, &second); err != nil {
return 0, err
}
size = (int(first&0x7F) << 8) + int(second)
} else {
size = int(first)
}
return size, nil
}
func newZeroFilledReader(r io.Reader, actual int64, expected int64) (io.Reader, error) {
if actual >= expected {
// no need to fill
return r, nil
}
// read `actual' bytes from r, and
buf := new(bytes.Buffer)
if _, err := io.CopyN(buf, r, actual); err != nil {
return nil, err
}
// fill zero until `expected' bytes
for i := actual; i < expected; i++ {
if err := buf.WriteByte(0x00); err != nil {
return nil, err
}
}
return buf, nil
}

1093
pkg/android_binary/table.go Normal file

File diff suppressed because it is too large Load Diff

333
pkg/android_binary/type.go Normal file
View File

@ -0,0 +1,333 @@
package android_binary
import (
"encoding/xml"
"fmt"
"reflect"
"strconv"
)
type injector interface {
inject(table *TableFile, config *ResTableConfig)
}
var injectorType = reflect.TypeOf((*injector)(nil)).Elem()
func inject(val reflect.Value, table *TableFile, config *ResTableConfig) {
if val.Kind() == reflect.Ptr {
if val.IsNil() {
return
}
val = val.Elem()
}
if val.CanInterface() && val.Type().Implements(injectorType) {
val.Interface().(injector).inject(table, config)
return
}
if val.CanAddr() {
pv := val.Addr()
if pv.CanInterface() && pv.Type().Implements(injectorType) {
pv.Interface().(injector).inject(table, config)
return
}
}
switch val.Kind() {
default:
// ignore other types
return
case reflect.Slice, reflect.Array:
l := val.Len()
for i := 0; i < l; i++ {
inject(val.Index(i), table, config)
}
return
case reflect.Struct:
l := val.NumField()
for i := 0; i < l; i++ {
inject(val.Field(i), table, config)
}
}
}
// Bool is a boolean value in XML file.
// It may be an immediate value or a reference.
type Bool struct {
value string
table *TableFile
config *ResTableConfig
}
// WithTableFile ties TableFile to the Bool.
func (v Bool) WithTableFile(table *TableFile) Bool {
return Bool{
value: v.value,
table: table,
config: v.config,
}
}
// WithResTableConfig ties ResTableConfig to the Bool.
func (v Bool) WithResTableConfig(config *ResTableConfig) Bool {
return Bool{
value: v.value,
table: v.table,
config: config,
}
}
func (v *Bool) inject(table *TableFile, config *ResTableConfig) {
v.table = table
v.config = config
}
// SetBool sets a boolean value.
func (v *Bool) SetBool(value bool) {
v.value = strconv.FormatBool(value)
}
// SetResID sets a boolean value with the resource id.
func (v *Bool) SetResID(resID ResID) {
v.value = resID.String()
}
// UnmarshalXMLAttr implements xml.UnmarshalerAttr.
func (v *Bool) UnmarshalXMLAttr(attr xml.Attr) error {
v.value = attr.Value
return nil
}
// MarshalXMLAttr implements xml.MarshalerAttr.
func (v Bool) MarshalXMLAttr(name xml.Name) (xml.Attr, error) {
if v.value == "" {
// return the zero value of bool
return xml.Attr{
Name: name,
Value: "false",
}, nil
}
return xml.Attr{
Name: name,
Value: v.value,
}, nil
}
// Bool returns the boolean value.
// It resolves the reference if needed.
func (v Bool) Bool() (bool, error) {
if v.value == "" {
return false, nil
}
if !IsResID(v.value) {
return strconv.ParseBool(v.value)
}
id, err := ParseResID(v.value)
if err != nil {
return false, err
}
value, err := v.table.GetResource(id, v.config)
if err != nil {
return false, err
}
ret, ok := value.(bool)
if !ok {
return false, fmt.Errorf("invalid type: %T", value)
}
return ret, nil
}
// MustBool is same as Bool, but it panics if it fails to parse the value.
func (v Bool) MustBool() bool {
ret, err := v.Bool()
if err != nil {
panic(err)
}
return ret
}
// Int32 is an integer value in XML file.
// It may be an immediate value or a reference.
type Int32 struct {
value string
table *TableFile
config *ResTableConfig
}
// WithTableFile ties TableFile to the Bool.
func (v Int32) WithTableFile(table *TableFile) Int32 {
return Int32{
value: v.value,
table: table,
config: v.config,
}
}
// WithResTableConfig ties ResTableConfig to the Bool.
func (v Int32) WithResTableConfig(config *ResTableConfig) Bool {
return Bool{
value: v.value,
table: v.table,
config: config,
}
}
func (v *Int32) inject(table *TableFile, config *ResTableConfig) {
v.table = table
v.config = config
}
// SetInt32 sets an integer value.
func (v *Int32) SetInt32(value int32) {
v.value = strconv.FormatInt(int64(value), 10)
}
// SetResID sets a boolean value with the resource id.
func (v *Int32) SetResID(resID ResID) {
v.value = resID.String()
}
// UnmarshalXMLAttr implements xml.UnmarshalerAttr.
func (v *Int32) UnmarshalXMLAttr(attr xml.Attr) error {
v.value = attr.Value
return nil
}
// MarshalXMLAttr implements xml.MarshalerAttr.
func (v Int32) MarshalXMLAttr(name xml.Name) (xml.Attr, error) {
if v.value == "" {
// return the zero value of int32
return xml.Attr{
Name: name,
Value: "0",
}, nil
}
return xml.Attr{
Name: name,
Value: v.value,
}, nil
}
// Int32 returns the integer value.
// It resolves the reference if needed.
func (v Int32) Int32() (int32, error) {
if v.value == "" {
return 0, nil
}
if !IsResID(v.value) {
v, err := strconv.ParseInt(v.value, 10, 32)
return int32(v), err
}
id, err := ParseResID(v.value)
if err != nil {
return 0, err
}
value, err := v.table.GetResource(id, v.config)
if err != nil {
return 0, err
}
ret, ok := value.(uint32)
if !ok {
return 0, fmt.Errorf("invalid type: %T", value)
}
return int32(ret), nil
}
// MustInt32 is same as Int32, but it panics if it fails to parse the value.
func (v Int32) MustInt32() int32 {
ret, err := v.Int32()
if err != nil {
panic(err)
}
return ret
}
// String is a boolean value in XML file.
// It may be an immediate value or a reference.
type String struct {
value string
table *TableFile
config *ResTableConfig
}
// WithTableFile ties TableFile to the Bool.
func (v String) WithTableFile(table *TableFile) String {
return String{
value: v.value,
table: table,
config: v.config,
}
}
// WithResTableConfig ties ResTableConfig to the Bool.
func (v String) WithResTableConfig(config *ResTableConfig) String {
return String{
value: v.value,
table: v.table,
config: config,
}
}
func (v *String) inject(table *TableFile, config *ResTableConfig) {
v.table = table
v.config = config
}
// SetString sets a string value.
func (v *String) SetString(value string) {
v.value = value
}
// SetResID sets a boolean value with the resource id.
func (v *String) SetResID(resID ResID) {
v.value = resID.String()
}
// UnmarshalXMLAttr implements xml.UnmarshalerAttr.
func (v *String) UnmarshalXMLAttr(attr xml.Attr) error {
v.value = attr.Value
return nil
}
// MarshalXMLAttr implements xml.MarshalerAttr.
func (v String) MarshalXMLAttr(name xml.Name) (xml.Attr, error) {
return xml.Attr{
Name: name,
Value: v.value,
}, nil
}
// String returns the string value.
// It resolves the reference if needed.
func (v String) String() (string, error) {
if !IsResID(v.value) {
return v.value, nil
}
id, err := ParseResID(v.value)
if err != nil {
return "", err
}
value, err := v.table.GetResource(id, v.config)
if err != nil {
return "", err
}
//todo 读取套娃
switch value.(type) {
case string:
return value.(string), nil
case uint32:
return fmt.Sprintf("%d", value.(uint32)), nil
default:
return "", nil
}
}
// MustString is same as String, but it panics if it fails to parse the value.
func (v String) MustString() string {
ret, err := v.String()
if err != nil {
panic(err)
}
return ret
}

271
pkg/android_binary/xml.go Normal file
View File

@ -0,0 +1,271 @@
package android_binary
import (
"bytes"
"encoding/binary"
"encoding/xml"
"fmt"
"io"
"reflect"
)
// XMLFile is an XML file expressed in binary format.
type XMLFile struct {
stringPool *ResStringPool
resourceMap []uint32
notPrecessedNS map[ResStringPoolRef]ResStringPoolRef
namespaces map[ResStringPoolRef]ResStringPoolRef
xmlBuffer bytes.Buffer
}
// ResXMLTreeNode is basic XML tree node.
type ResXMLTreeNode struct {
Header ResChunkHeader
LineNumber uint32
Comment ResStringPoolRef
}
// ResXMLTreeNamespaceExt is extended XML tree node for namespace start/end nodes.
type ResXMLTreeNamespaceExt struct {
Prefix ResStringPoolRef
URI ResStringPoolRef
}
// ResXMLTreeAttrExt is extended XML tree node for start tags -- includes attribute.
type ResXMLTreeAttrExt struct {
NS ResStringPoolRef
Name ResStringPoolRef
AttributeStart uint16
AttributeSize uint16
AttributeCount uint16
IDIndex uint16
ClassIndex uint16
StyleIndex uint16
}
// ResXMLTreeAttribute is an attribute of start tags.
type ResXMLTreeAttribute struct {
NS ResStringPoolRef
Name ResStringPoolRef
RawValue ResStringPoolRef
TypedValue ResValue
}
// ResXMLTreeEndElementExt is extended XML tree node for element start/end nodes.
type ResXMLTreeEndElementExt struct {
NS ResStringPoolRef
Name ResStringPoolRef
}
// NewXMLFile returns a new XMLFile.
func NewXMLFile(r io.ReaderAt) (*XMLFile, error) {
f := new(XMLFile)
sr := io.NewSectionReader(r, 0, 1<<63-1)
fmt.Fprintf(&f.xmlBuffer, xml.Header)
header := new(ResChunkHeader)
if err := binary.Read(sr, binary.LittleEndian, header); err != nil {
return nil, err
}
offset := int64(header.HeaderSize)
for offset < int64(header.Size) {
chunkHeader, err := f.readChunk(r, offset)
if err != nil {
return nil, err
}
offset += int64(chunkHeader.Size)
}
return f, nil
}
// Reader returns a reader of XML file expressed in text format.
func (f *XMLFile) Reader() *bytes.Reader {
return bytes.NewReader(f.xmlBuffer.Bytes())
}
// Decode decodes XML file and stores the result in the value pointed to by v.
// To resolve the resource references, Decode also stores default TableFile and ResTableConfig in the value pointed to by v.
func (f *XMLFile) Decode(v any, table *TableFile, config *ResTableConfig) error {
decoder := xml.NewDecoder(f.Reader())
if err := decoder.Decode(v); err != nil {
return err
}
inject(reflect.ValueOf(v), table, config)
return nil
}
func (f *XMLFile) readChunk(r io.ReaderAt, offset int64) (*ResChunkHeader, error) {
sr := io.NewSectionReader(r, offset, 1<<63-1-offset)
chunkHeader := &ResChunkHeader{}
if _, err := sr.Seek(0, io.SeekStart); err != nil {
return nil, err
}
if err := binary.Read(sr, binary.LittleEndian, chunkHeader); err != nil {
return nil, err
}
var err error
if _, err := sr.Seek(0, io.SeekStart); err != nil {
return nil, err
}
switch chunkHeader.Type {
case ResStringPoolChunkType:
f.stringPool, err = readStringPool(sr)
case ResXMLStartNamespaceType:
err = f.readStartNamespace(sr)
case ResXMLEndNamespaceType:
err = f.readEndNamespace(sr)
case ResXMLStartElementType:
err = f.readStartElement(sr)
case ResXMLEndElementType:
err = f.readEndElement(sr)
}
if err != nil {
return nil, err
}
return chunkHeader, nil
}
// GetString returns a string referenced by ref.
func (f *XMLFile) GetString(ref ResStringPoolRef) string {
return f.stringPool.GetString(ref)
}
func (f *XMLFile) readStartNamespace(sr *io.SectionReader) error {
header := new(ResXMLTreeNode)
if err := binary.Read(sr, binary.LittleEndian, header); err != nil {
return err
}
if _, err := sr.Seek(int64(header.Header.HeaderSize), io.SeekStart); err != nil {
return err
}
namespace := new(ResXMLTreeNamespaceExt)
if err := binary.Read(sr, binary.LittleEndian, namespace); err != nil {
return err
}
if f.notPrecessedNS == nil {
f.notPrecessedNS = make(map[ResStringPoolRef]ResStringPoolRef)
}
f.notPrecessedNS[namespace.URI] = namespace.Prefix
if f.namespaces == nil {
f.namespaces = make(map[ResStringPoolRef]ResStringPoolRef)
}
f.namespaces[namespace.URI] = namespace.Prefix
return nil
}
func (f *XMLFile) readEndNamespace(sr *io.SectionReader) error {
header := new(ResXMLTreeNode)
if err := binary.Read(sr, binary.LittleEndian, header); err != nil {
return err
}
if _, err := sr.Seek(int64(header.Header.HeaderSize), io.SeekStart); err != nil {
return err
}
namespace := new(ResXMLTreeNamespaceExt)
if err := binary.Read(sr, binary.LittleEndian, namespace); err != nil {
return err
}
delete(f.namespaces, namespace.URI)
return nil
}
func (f *XMLFile) addNamespacePrefix(ns, name ResStringPoolRef) string {
if ns != NilResStringPoolRef {
prefix := f.GetString(f.namespaces[ns])
return fmt.Sprintf("%s:%s", prefix, f.GetString(name))
}
return f.GetString(name)
}
func (f *XMLFile) readStartElement(sr *io.SectionReader) error {
header := new(ResXMLTreeNode)
if err := binary.Read(sr, binary.LittleEndian, header); err != nil {
return err
}
if _, err := sr.Seek(int64(header.Header.HeaderSize), io.SeekStart); err != nil {
return err
}
ext := new(ResXMLTreeAttrExt)
if err := binary.Read(sr, binary.LittleEndian, ext); err != nil {
return nil
}
fmt.Fprintf(&f.xmlBuffer, "<%s", f.addNamespacePrefix(ext.NS, ext.Name))
// output XML namespaces
if f.notPrecessedNS != nil {
for uri, prefix := range f.notPrecessedNS {
fmt.Fprintf(&f.xmlBuffer, " xmlns:%s=\"", f.GetString(prefix))
xml.Escape(&f.xmlBuffer, []byte(f.GetString(uri)))
fmt.Fprint(&f.xmlBuffer, "\"")
}
f.notPrecessedNS = nil
}
// process attributes
offset := int64(ext.AttributeStart + header.Header.HeaderSize)
for i := 0; i < int(ext.AttributeCount); i++ {
if _, err := sr.Seek(offset, io.SeekStart); err != nil {
return err
}
attr := new(ResXMLTreeAttribute)
binary.Read(sr, binary.LittleEndian, attr)
var value string
if attr.RawValue != NilResStringPoolRef {
value = f.GetString(attr.RawValue)
} else {
data := attr.TypedValue.Data
switch attr.TypedValue.DataType {
case TypeNull:
value = ""
case TypeReference:
value = fmt.Sprintf("@0x%08X", data)
case TypeIntDec:
value = fmt.Sprintf("%d", data)
case TypeIntHex:
value = fmt.Sprintf("0x%08X", data)
case TypeIntBoolean:
if data != 0 {
value = "true"
} else {
value = "false"
}
default:
value = fmt.Sprintf("@0x%08X", data)
}
}
fmt.Fprintf(&f.xmlBuffer, " %s=\"", f.addNamespacePrefix(attr.NS, attr.Name))
xml.Escape(&f.xmlBuffer, []byte(value))
fmt.Fprint(&f.xmlBuffer, "\"")
offset += int64(ext.AttributeSize)
}
fmt.Fprint(&f.xmlBuffer, ">")
return nil
}
func (f *XMLFile) readEndElement(sr *io.SectionReader) error {
header := new(ResXMLTreeNode)
if err := binary.Read(sr, binary.LittleEndian, header); err != nil {
return err
}
if _, err := sr.Seek(int64(header.Header.HeaderSize), io.SeekStart); err != nil {
return err
}
ext := new(ResXMLTreeEndElementExt)
if err := binary.Read(sr, binary.LittleEndian, ext); err != nil {
return err
}
fmt.Fprintf(&f.xmlBuffer, "</%s>", f.addNamespacePrefix(ext.NS, ext.Name))
return nil
}

108
pkg/ants/pool.go Normal file
View File

@ -0,0 +1,108 @@
package ants
import (
"errors"
"fmt"
"git.bvbej.com/bvbej/base-golang/pkg/ticker"
"github.com/panjf2000/ants/v2"
"go.uber.org/zap"
"runtime/debug"
"time"
)
var _ GoroutinePool = (*goroutinePool)(nil)
type GoroutinePool interface {
run()
Submit(task func())
Stop()
Size() int
Running() int
Free() int
}
type goroutinePool struct {
pool *ants.Pool
logger *zap.Logger
ticker ticker.Ticker
step int
}
type poolLogger struct {
zap *zap.Logger
}
func (l *poolLogger) Printf(format string, args ...any) {
l.zap.Sugar().Infof(format, args)
}
func NewPool(zapLogger *zap.Logger, step int) (GoroutinePool, error) {
ttl := time.Minute * 5
options := ants.Options{
Nonblocking: true,
ExpiryDuration: ttl,
PanicHandler: func(err any) {
zapLogger.Sugar().Error(
"GoroutinePool panic",
zap.String("error", fmt.Sprintf("%+v", err)),
zap.String("stack", string(debug.Stack())),
)
},
Logger: &poolLogger{zap: zapLogger},
}
antsPool, err := ants.NewPool(step, ants.WithOptions(options))
if err != nil {
return nil, err
}
pool := &goroutinePool{
pool: antsPool,
logger: zapLogger,
ticker: ticker.New(ttl),
step: step,
}
pool.run()
return pool, nil
}
func (p *goroutinePool) run() {
p.ticker.Process(func() {
if p.Free() > p.step {
mul := p.Free() / p.step
p.pool.Tune(p.Size() - p.step*mul)
}
})
}
func (p *goroutinePool) Submit(task func()) {
if p.pool.IsClosed() {
return
}
err := p.pool.Submit(task)
if errors.Is(err, ants.ErrPoolOverload) {
p.pool.Tune(p.Size() + p.step)
p.Submit(task)
}
}
func (p *goroutinePool) Size() int {
return p.pool.Cap()
}
func (p *goroutinePool) Running() int {
return p.pool.Running()
}
func (p *goroutinePool) Free() int {
return p.pool.Free()
}
func (p *goroutinePool) Stop() {
p.ticker.Stop()
p.pool.Release()
}

97
pkg/apollo/config.go Normal file
View File

@ -0,0 +1,97 @@
package apollo
import (
"fmt"
"git.bvbej.com/bvbej/base-golang/pkg/env"
"github.com/apolloconfig/agollo/v4"
"github.com/apolloconfig/agollo/v4/component/log"
apolloConfig "github.com/apolloconfig/agollo/v4/env/config"
"github.com/apolloconfig/agollo/v4/storage"
"github.com/spf13/viper"
)
type clientConfig struct {
client agollo.Client
ac *apolloConfig.AppConfig
conf any
onChange func(event *storage.ChangeEvent)
onNewestChange func(*storage.FullChangeEvent)
}
type Option func(*clientConfig)
func WithOnChangeEvent(event func(event *storage.ChangeEvent)) Option {
return func(conf *clientConfig) {
conf.onChange = event
}
}
func WithOnNewestChangeEvent(event func(event *storage.FullChangeEvent)) Option {
return func(conf *clientConfig) {
conf.onNewestChange = event
}
}
func GetApolloConfig(appId, secret string, config any, opt ...Option) error {
var err error
namespace := env.Active().Value() + ".yaml"
c := new(clientConfig)
c.conf = config
c.ac = &apolloConfig.AppConfig{
AppID: appId,
Cluster: "dev",
IP: "https://config.bvbej.com",
NamespaceName: namespace,
IsBackupConfig: false,
Secret: secret,
MustStart: true,
}
for _, option := range opt {
option(c)
}
agollo.SetLogger(&log.DefaultLogger{})
c.client, err = agollo.StartWithConfig(func() (*apolloConfig.AppConfig, error) {
return c.ac, nil
})
if err != nil {
return fmt.Errorf("get config error:[%s]", err)
}
c.client.AddChangeListener(c)
err = c.serialization()
if err != nil {
return fmt.Errorf("unmarshal config error:[%s]", err)
}
return nil
}
func (c *clientConfig) serialization() error {
parser := viper.New()
parser.SetConfigType("yaml")
c.client.GetConfigCache(c.ac.NamespaceName).Range(func(key, value any) bool {
parser.Set(key.(string), value)
return true
})
return parser.Unmarshal(c.conf)
}
func (c *clientConfig) OnChange(event *storage.ChangeEvent) {
_ = c.serialization()
if c.onChange != nil {
c.onChange(event)
}
}
func (c *clientConfig) OnNewestChange(event *storage.FullChangeEvent) {
if c.onNewestChange != nil {
c.onNewestChange(event)
}
}

29
pkg/auth/config.go Normal file
View File

@ -0,0 +1,29 @@
package auth
import "time"
// Config authorization configuration parameters
type Config struct {
// access token expiration time, 0 means it doesn't expire
AccessTokenExp time.Duration
// refresh token expiration time, 0 means it doesn't expire
RefreshTokenExp time.Duration
// whether to generate the refreshing token
IsGenerateRefresh bool
}
// RefreshConfig refreshing token config
type RefreshConfig struct {
// whether to reset the refreshing creation time
IsResetRefreshTime bool
// whether to remove access token
IsRemoveAccess bool
// whether to remove refreshing token
IsRemoveRefreshing bool
}
// default configs
var (
DefaultAccessTokenCfg = &Config{AccessTokenExp: time.Hour * 24, RefreshTokenExp: time.Hour * 24 * 7, IsGenerateRefresh: true}
DefaultRefreshTokenCfg = &RefreshConfig{IsResetRefreshTime: true, IsRemoveAccess: true, IsRemoveRefreshing: true}
)

12
pkg/auth/error.go Normal file
View File

@ -0,0 +1,12 @@
package auth
import "errors"
var New = errors.New
var (
ErrInvalidAccessToken = errors.New("invalid access token")
ErrInvalidRefreshToken = errors.New("invalid refresh token")
ErrExpiredAccessToken = errors.New("expired access token")
ErrExpiredRefreshToken = errors.New("expired refresh token")
)

17
pkg/auth/generate.go Normal file
View File

@ -0,0 +1,17 @@
package auth
import (
"time"
)
type (
GenerateBasic struct {
UserID string
CreateAt time.Time
TokenInfo TokenInfo
}
AccessGenerate interface {
Token(data *GenerateBasic, isGenRefresh bool) (access, refresh string, err error)
}
)

97
pkg/auth/jwt_access.go Normal file
View File

@ -0,0 +1,97 @@
package auth
import (
"encoding/base64"
"errors"
"strings"
"time"
"github.com/golang-jwt/jwt/v4"
"github.com/google/uuid"
)
// JWTAccessClaims jwt claims
type JWTAccessClaims struct {
jwt.RegisteredClaims
}
// Valid claims verification
func (a *JWTAccessClaims) Valid() error {
if a.ExpiresAt.Before(time.Now()) {
return ErrInvalidAccessToken
}
return nil
}
// NewJWTAccessGenerate create to generate the jwt access token instance
func NewJWTAccessGenerate(key []byte, method jwt.SigningMethod) *JWTAccessGenerate {
return &JWTAccessGenerate{
SignedKey: key,
SignedMethod: method,
}
}
// JWTAccessGenerate generate the jwt access token
type JWTAccessGenerate struct {
SignedKey []byte
SignedMethod jwt.SigningMethod
}
// Token based on the UUID generated token
func (a *JWTAccessGenerate) Token(data *GenerateBasic, isGenRefresh bool) (string, string, error) {
claims := &JWTAccessClaims{
RegisteredClaims: jwt.RegisteredClaims{
Issuer: "BvBeJ",
Subject: data.UserID,
ExpiresAt: jwt.NewNumericDate(data.TokenInfo.GetAccessCreateAt().Add(data.TokenInfo.GetAccessExpiresIn())),
},
}
token := jwt.NewWithClaims(a.SignedMethod, claims)
var key any
if a.isEs() {
v, err := jwt.ParseECPrivateKeyFromPEM(a.SignedKey)
if err != nil {
return "", "", err
}
key = v
} else if a.isRsOrPS() {
v, err := jwt.ParseRSAPrivateKeyFromPEM(a.SignedKey)
if err != nil {
return "", "", err
}
key = v
} else if a.isHs() {
key = a.SignedKey
} else {
return "", "", errors.New("unsupported sign method")
}
access, err := token.SignedString(key)
if err != nil {
return "", "", err
}
refresh := ""
if isGenRefresh {
t := uuid.NewSHA1(uuid.Must(uuid.NewRandom()), []byte(access)).String()
refresh = base64.URLEncoding.EncodeToString([]byte(t))
refresh = strings.ToUpper(strings.TrimRight(refresh, "="))
}
return access, refresh, nil
}
func (a *JWTAccessGenerate) isEs() bool {
return strings.HasPrefix(a.SignedMethod.Alg(), "ES")
}
func (a *JWTAccessGenerate) isRsOrPS() bool {
isRs := strings.HasPrefix(a.SignedMethod.Alg(), "RS")
isPs := strings.HasPrefix(a.SignedMethod.Alg(), "PS")
return isRs || isPs
}
func (a *JWTAccessGenerate) isHs() bool {
return strings.HasPrefix(a.SignedMethod.Alg(), "HS")
}

194
pkg/auth/manager.go Normal file
View File

@ -0,0 +1,194 @@
package auth
import (
"time"
)
// NewManager create to authorization management instance
func NewManager(ag AccessGenerate, ts TokenStore) *Manager {
return &Manager{
cfg: DefaultAccessTokenCfg,
rCfg: DefaultRefreshTokenCfg,
accessGenerate: ag,
tokenStore: ts,
}
}
// SetConfig mapping the access token generate config
func (m *Manager) SetConfig(cfg *Config) {
m.cfg = cfg
}
// SetRefreshTokenConfig mapping the token refresh config
func (m *Manager) SetRefreshTokenConfig(store *RefreshConfig) {
m.rCfg = store
}
// Manager provide authorization management
type Manager struct {
cfg *Config
rCfg *RefreshConfig
accessGenerate AccessGenerate
tokenStore TokenStore
}
// GenerateAccessToken generate the access token
func (m *Manager) GenerateAccessToken(userID string) (TokenInfo, error) {
ti := NewToken()
ti.SetUserID(userID)
createAt := time.Now()
ti.SetAccessCreateAt(createAt)
// set access token expires
ti.SetAccessExpiresIn(m.cfg.AccessTokenExp)
if m.cfg.IsGenerateRefresh {
ti.SetRefreshCreateAt(createAt)
ti.SetRefreshExpiresIn(m.cfg.RefreshTokenExp)
}
td := &GenerateBasic{
UserID: userID,
CreateAt: createAt,
TokenInfo: ti,
}
av, rv, err := m.accessGenerate.Token(td, m.cfg.IsGenerateRefresh)
if err != nil {
return nil, err
}
ti.SetAccess(av)
if rv != "" {
ti.SetRefresh(rv)
}
err = m.tokenStore.Create(ti)
if err != nil {
return nil, err
}
return ti, nil
}
// RefreshAccessToken refreshing an access token
func (m *Manager) RefreshAccessToken(refresh string) (TokenInfo, error) {
ti, err := m.LoadRefreshToken(refresh)
if err != nil {
return nil, err
}
oldAccess, oldRefresh := ti.GetAccess(), ti.GetRefresh()
td := &GenerateBasic{
UserID: ti.GetUserID(),
CreateAt: time.Now(),
TokenInfo: ti,
}
ti.SetAccessCreateAt(td.CreateAt)
if v := m.cfg.AccessTokenExp; v > 0 {
ti.SetAccessExpiresIn(v)
}
if v := m.cfg.RefreshTokenExp; v > 0 {
ti.SetRefreshExpiresIn(v)
}
if m.rCfg.IsResetRefreshTime {
ti.SetRefreshCreateAt(td.CreateAt)
}
tv, rv, err := m.accessGenerate.Token(td, m.cfg.IsGenerateRefresh)
if err != nil {
return nil, err
}
ti.SetAccess(tv)
if rv != "" {
ti.SetRefresh(rv)
}
if err = m.tokenStore.Create(ti); err != nil {
return nil, err
}
if m.rCfg.IsRemoveAccess {
// remove the old access token
if err = m.tokenStore.RemoveByAccess(oldAccess); err != nil {
return nil, err
}
}
if m.rCfg.IsRemoveRefreshing && rv != "" {
// remove the old refresh token
if err = m.tokenStore.RemoveByRefresh(oldRefresh); err != nil {
return nil, err
}
}
if rv == "" {
ti.SetRefresh("")
ti.SetRefreshCreateAt(time.Now())
ti.SetRefreshExpiresIn(0)
}
return ti, nil
}
// RemoveAccessToken use the access token to delete the token information
func (m *Manager) RemoveAccessToken(access string) error {
if access == "" {
return ErrInvalidAccessToken
}
return m.tokenStore.RemoveByAccess(access)
}
// RemoveRefreshToken use the refresh token to delete the token information
func (m *Manager) RemoveRefreshToken(refresh string) error {
if refresh == "" {
return ErrInvalidAccessToken
}
return m.tokenStore.RemoveByRefresh(refresh)
}
// LoadAccessToken according to the access token for corresponding token information
func (m *Manager) LoadAccessToken(access string) (TokenInfo, error) {
if access == "" {
return nil, ErrInvalidAccessToken
}
ct := time.Now()
ti, err := m.tokenStore.GetByAccess(access)
if err != nil {
return nil, err
} else if ti == nil || ti.GetAccess() != access {
return nil, ErrInvalidAccessToken
} else if ti.GetRefresh() != "" && ti.GetRefreshExpiresIn() != 0 &&
ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(ct) {
return nil, ErrExpiredRefreshToken
} else if ti.GetAccessExpiresIn() != 0 &&
ti.GetAccessCreateAt().Add(ti.GetAccessExpiresIn()).Before(ct) {
return nil, ErrExpiredAccessToken
}
return ti, nil
}
// LoadRefreshToken according to the refresh token for corresponding token information
func (m *Manager) LoadRefreshToken(refresh string) (TokenInfo, error) {
if refresh == "" {
return nil, ErrInvalidRefreshToken
}
ti, err := m.tokenStore.GetByRefresh(refresh)
if err != nil {
return nil, err
} else if ti == nil || ti.GetRefresh() != refresh {
return nil, ErrInvalidRefreshToken
} else if ti.GetRefreshExpiresIn() != 0 && // refresh token set to not expire
ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(time.Now()) {
return nil, ErrExpiredRefreshToken
}
return ti, nil
}

334
pkg/auth/store.go Normal file
View File

@ -0,0 +1,334 @@
package auth
import (
"context"
"errors"
"fmt"
"github.com/google/uuid"
jsonIterator "github.com/json-iterator/go"
"github.com/redis/go-redis/v9"
"github.com/tidwall/buntdb"
"time"
)
var (
jsonMarshal = jsonIterator.Marshal
jsonUnmarshal = jsonIterator.Unmarshal
)
type TokenStore interface {
Create(info TokenInfo) error
RemoveByAccess(access string) error
RemoveByRefresh(refresh string) error
GetByAccess(access string) (TokenInfo, error)
GetByRefresh(refresh string) (TokenInfo, error)
}
// NewMemoryTokenStore create a token buntStore instance based on memory
func NewMemoryTokenStore() (TokenStore, error) {
return NewFileTokenStore(":memory:")
}
// NewFileTokenStore create a token buntStore instance based on file
func NewFileTokenStore(filename string) (TokenStore, error) {
db, err := buntdb.Open(filename)
if err != nil {
return nil, err
}
return &buntStore{db: db}, nil
}
// buntStore token storage based on buntdb(https://github.com/tidwall/buntdb)
type buntStore struct {
db *buntdb.DB
}
func (ts *buntStore) remove(key string) error {
err := ts.db.Update(func(tx *buntdb.Tx) error {
_, err := tx.Delete(key)
return err
})
if errors.Is(err, buntdb.ErrNotFound) {
return nil
}
return err
}
func (ts *buntStore) getData(key string) (TokenInfo, error) {
var ti TokenInfo
err := ts.db.View(func(tx *buntdb.Tx) error {
jv, err := tx.Get(key)
if err != nil {
return err
}
var tm Token
err = jsonUnmarshal([]byte(jv), &tm)
if err != nil {
return err
}
ti = &tm
return nil
})
if err != nil {
if err == buntdb.ErrNotFound {
return nil, nil
}
return nil, err
}
return ti, nil
}
func (ts *buntStore) getBasicID(key string) (string, error) {
var basicID string
err := ts.db.View(func(tx *buntdb.Tx) error {
v, err := tx.Get(key)
if err != nil {
return err
}
basicID = v
return nil
})
if err != nil {
if err == buntdb.ErrNotFound {
return "", nil
}
return "", err
}
return basicID, nil
}
// Create and buntStore the new token information
func (ts *buntStore) Create(info TokenInfo) error {
ct := time.Now()
jv, err := jsonMarshal(info)
if err != nil {
return err
}
return ts.db.Update(func(tx *buntdb.Tx) error {
basicID := uuid.Must(uuid.NewRandom()).String()
aexp := info.GetAccessExpiresIn()
rexp := aexp
expires := true
if refresh := info.GetRefresh(); refresh != "" {
rexp = info.GetRefreshCreateAt().Add(info.GetRefreshExpiresIn()).Sub(ct)
if aexp.Seconds() > rexp.Seconds() {
aexp = rexp
}
expires = info.GetRefreshExpiresIn() != 0
_, _, err = tx.Set(refresh, basicID, &buntdb.SetOptions{Expires: expires, TTL: rexp})
if err != nil {
return err
}
}
_, _, err = tx.Set(basicID, string(jv), &buntdb.SetOptions{Expires: expires, TTL: rexp})
if err != nil {
return err
}
_, _, err = tx.Set(info.GetAccess(), basicID, &buntdb.SetOptions{Expires: expires, TTL: aexp})
return err
})
}
// RemoveByAccess use the access token to delete the token information
func (ts *buntStore) RemoveByAccess(access string) error {
return ts.remove(access)
}
// RemoveByRefresh use the refresh token to delete the token information
func (ts *buntStore) RemoveByRefresh(refresh string) error {
return ts.remove(refresh)
}
// GetByAccess use the access token for token information data
func (ts *buntStore) GetByAccess(access string) (TokenInfo, error) {
basicID, err := ts.getBasicID(access)
if err != nil {
return nil, err
}
return ts.getData(basicID)
}
// GetByRefresh use the refresh token for token information data
func (ts *buntStore) GetByRefresh(refresh string) (TokenInfo, error) {
basicID, err := ts.getBasicID(refresh)
if err != nil {
return nil, err
}
return ts.getData(basicID)
}
/*------------------------------------------------------------------------------------*/
// NewRedisStoreWithCli create an instance of a redis store
func NewRedisStoreWithCli(cli *redis.Client, keyNamespace string) TokenStore {
store := &redisStore{
cli: cli,
ctx: context.TODO(),
ns: keyNamespace,
}
return store
}
// TokenStore redis token store
type redisStore struct {
cli *redis.Client
ctx context.Context
ns string
}
func (s *redisStore) wrapperKey(key string) string {
return fmt.Sprintf("%s%s", s.ns, key)
}
func (s *redisStore) checkError(result redis.Cmder) (bool, error) {
if err := result.Err(); err != nil {
if err == redis.Nil {
return true, nil
}
return false, err
}
return false, nil
}
func (s *redisStore) remove(key string) error {
result := s.cli.Del(s.ctx, s.wrapperKey(key))
_, err := s.checkError(result)
return err
}
func (s *redisStore) removeToken(tokenString string, isRefresh bool) error {
basicID, err := s.getBasicID(tokenString)
if err != nil {
return err
} else if basicID == "" {
return nil
}
err = s.remove(tokenString)
if err != nil {
return err
}
token, err := s.getToken(basicID)
if err != nil {
return err
} else if token == nil {
return nil
}
checkToken := token.GetRefresh()
if isRefresh {
checkToken = token.GetAccess()
}
result := s.cli.Exists(s.ctx, s.wrapperKey(checkToken))
if err = result.Err(); err != nil && err != redis.Nil {
return err
} else if result.Val() == 0 {
return s.remove(basicID)
}
return nil
}
func (s *redisStore) parseToken(result *redis.StringCmd) (TokenInfo, error) {
if ok, err := s.checkError(result); err != nil {
return nil, err
} else if ok {
return nil, nil
}
buf, err := result.Bytes()
if err != nil {
if err == redis.Nil {
return nil, nil
}
return nil, err
}
var token Token
if err = jsonUnmarshal(buf, &token); err != nil {
return nil, err
}
return &token, nil
}
func (s *redisStore) getToken(key string) (TokenInfo, error) {
result := s.cli.Get(s.ctx, s.wrapperKey(key))
return s.parseToken(result)
}
func (s *redisStore) parseBasicID(result *redis.StringCmd) (string, error) {
if ok, err := s.checkError(result); err != nil {
return "", err
} else if ok {
return "", nil
}
return result.Val(), nil
}
func (s *redisStore) getBasicID(token string) (string, error) {
result := s.cli.Get(s.ctx, s.wrapperKey(token))
return s.parseBasicID(result)
}
// Create and store the new token information
func (s *redisStore) Create(info TokenInfo) error {
ct := time.Now()
jv, err := jsonMarshal(info)
if err != nil {
return err
}
pipe := s.cli.TxPipeline()
basicID := uuid.Must(uuid.NewRandom()).String()
aexp := info.GetAccessExpiresIn()
rexp := aexp
if refresh := info.GetRefresh(); refresh != "" {
rexp = info.GetRefreshCreateAt().Add(info.GetRefreshExpiresIn()).Sub(ct)
if aexp.Seconds() > rexp.Seconds() {
aexp = rexp
}
pipe.Set(s.ctx, s.wrapperKey(refresh), basicID, rexp)
}
pipe.Set(s.ctx, s.wrapperKey(info.GetAccess()), basicID, aexp)
pipe.Set(s.ctx, s.wrapperKey(basicID), jv, rexp)
if _, err = pipe.Exec(s.ctx); err != nil {
return err
}
return nil
}
// RemoveByAccess Use the access token to delete the token information
func (s *redisStore) RemoveByAccess(access string) error {
return s.removeToken(access, false)
}
// RemoveByRefresh Use the refresh token to delete the token information
func (s *redisStore) RemoveByRefresh(refresh string) error {
return s.removeToken(refresh, true)
}
// GetByAccess Use the access token for token information data
func (s *redisStore) GetByAccess(access string) (TokenInfo, error) {
basicID, err := s.getBasicID(access)
if err != nil || basicID == "" {
return nil, err
}
return s.getToken(basicID)
}
// GetByRefresh Use the refresh token for token information data
func (s *redisStore) GetByRefresh(refresh string) (TokenInfo, error) {
basicID, err := s.getBasicID(refresh)
if err != nil || basicID == "" {
return nil, err
}
return s.getToken(basicID)
}

118
pkg/auth/token.go Normal file
View File

@ -0,0 +1,118 @@
package auth
import (
"time"
)
// TokenInfo the token information model interface
type TokenInfo interface {
New() TokenInfo
GetUserID() string
SetUserID(string)
GetAccess() string
SetAccess(string)
GetAccessCreateAt() time.Time
SetAccessCreateAt(time.Time)
GetAccessExpiresIn() time.Duration
SetAccessExpiresIn(time.Duration)
GetRefresh() string
SetRefresh(string)
GetRefreshCreateAt() time.Time
SetRefreshCreateAt(time.Time)
GetRefreshExpiresIn() time.Duration
SetRefreshExpiresIn(time.Duration)
}
// NewToken create to token model instance
func NewToken() *Token {
return &Token{}
}
// Token token model
type Token struct {
UserID string `bson:"UserID"`
Access string `bson:"Access"`
AccessCreateAt time.Time `bson:"AccessCreateAt"`
AccessExpiresIn time.Duration `bson:"AccessExpiresIn"`
Refresh string `bson:"Refresh"`
RefreshCreateAt time.Time `bson:"RefreshCreateAt"`
RefreshExpiresIn time.Duration `bson:"RefreshExpiresIn"`
}
// New create to token model instance
func (t *Token) New() TokenInfo {
return NewToken()
}
// GetUserID the user id
func (t *Token) GetUserID() string {
return t.UserID
}
// SetUserID the user id
func (t *Token) SetUserID(userID string) {
t.UserID = userID
}
// GetAccess access Token
func (t *Token) GetAccess() string {
return t.Access
}
// SetAccess access Token
func (t *Token) SetAccess(access string) {
t.Access = access
}
// GetAccessCreateAt create Time
func (t *Token) GetAccessCreateAt() time.Time {
return t.AccessCreateAt
}
// SetAccessCreateAt create Time
func (t *Token) SetAccessCreateAt(createAt time.Time) {
t.AccessCreateAt = createAt
}
// GetAccessExpiresIn the lifetime in seconds of the access token
func (t *Token) GetAccessExpiresIn() time.Duration {
return t.AccessExpiresIn
}
// SetAccessExpiresIn the lifetime in seconds of the access token
func (t *Token) SetAccessExpiresIn(exp time.Duration) {
t.AccessExpiresIn = exp
}
// GetRefresh refresh Token
func (t *Token) GetRefresh() string {
return t.Refresh
}
// SetRefresh refresh Token
func (t *Token) SetRefresh(refresh string) {
t.Refresh = refresh
}
// GetRefreshCreateAt create Time
func (t *Token) GetRefreshCreateAt() time.Time {
return t.RefreshCreateAt
}
// SetRefreshCreateAt create Time
func (t *Token) SetRefreshCreateAt(createAt time.Time) {
t.RefreshCreateAt = createAt
}
// GetRefreshExpiresIn the lifetime in seconds of the refresh token
func (t *Token) GetRefreshExpiresIn() time.Duration {
return t.RefreshExpiresIn
}
// SetRefreshExpiresIn the lifetime in seconds of the refresh token
func (t *Token) SetRefreshExpiresIn(exp time.Duration) {
t.RefreshExpiresIn = exp
}

39
pkg/bcrypt/password.go Normal file
View File

@ -0,0 +1,39 @@
package bcrypt
import (
"fmt"
"golang.org/x/crypto/bcrypt"
)
var _ Password = (*password)(nil)
type Password interface {
i()
Generate(pwd string) string
Validate(pwd, hash string) bool
}
type password struct {
cost int
}
func (p *password) i() {}
func NewPassword(cost int) (Password, error) {
if cost < bcrypt.MinCost || cost > bcrypt.MaxCost {
return nil, fmt.Errorf("cost out of range")
}
return &password{
cost: cost,
}, nil
}
func (p *password) Generate(pwd string) string {
hash, _ := bcrypt.GenerateFromPassword([]byte(pwd), p.cost)
return string(hash)
}
func (p *password) Validate(pwd, hash string) bool {
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(pwd))
return err == nil
}

23
pkg/browser/browser.go Normal file
View File

@ -0,0 +1,23 @@
package browser
import (
"fmt"
"os/exec"
"runtime"
)
var commands = map[string]string{
"windows": "start",
"darwin": "open",
"linux": "xdg-open",
}
func Open(uri string) error {
run, ok := commands[runtime.GOOS]
if !ok {
return fmt.Errorf("don't know how to open things on %s platform", runtime.GOOS)
}
cmd := exec.Command(run, uri)
return cmd.Start()
}

482
pkg/cache/redis.go vendored Normal file
View File

@ -0,0 +1,482 @@
package cache
import (
"context"
"errors"
"fmt"
"time"
"git.bvbej.com/bvbej/base-golang/pkg/time_parse"
"git.bvbej.com/bvbej/base-golang/pkg/trace"
"github.com/redis/go-redis/v9"
)
type Option func(*option)
type Trace = trace.T
type option struct {
Trace *trace.Trace
Redis *trace.Redis
}
type RedisConfig struct {
Addr string `yaml:"addr"`
Pass string `yaml:"pass"`
DB int `yaml:"db"`
MaxRetries int `yaml:"maxRetries"` // 最大重试次数
PoolSize int `yaml:"poolSize"` // Redis连接池大小
MinIdleConn int `yaml:"minIdleConn"` // 最小空闲连接数
}
func newOption() *option {
return &option{}
}
var _ Repo = (*cacheRepo)(nil)
type Repo interface {
i()
Client() *redis.Client
Set(key, value string, ttl time.Duration, options ...Option) error
Get(key string, options ...Option) (string, error)
TTL(key string) (time.Duration, error)
Expire(key string, ttl time.Duration) bool
ExpireAt(key string, ttl time.Time) bool
Del(key string, options ...Option) bool
Exists(keys ...string) bool
Incr(key string, options ...Option) (int64, error)
Decr(key string, options ...Option) (int64, error)
HGet(key, field string, options ...Option) (string, error)
HSet(key, field, value string, options ...Option) error
HDel(key, field string, options ...Option) error
HGetAll(key string, options ...Option) (map[string]string, error)
HIncrBy(key, field string, incr int64, options ...Option) (int64, error)
HIncrByFloat(key, field string, incr float64, options ...Option) (float64, error)
LPush(key, value string, options ...Option) error
LLen(key string, options ...Option) (int64, error)
BRPop(key string, timeout time.Duration, options ...Option) (string, error)
Close() error
}
type cacheRepo struct {
client *redis.Client
ctx context.Context
}
func New(cfg RedisConfig) (Repo, error) {
client := redis.NewClient(&redis.Options{
Addr: cfg.Addr,
Password: cfg.Pass,
DB: cfg.DB,
MaxRetries: cfg.MaxRetries,
PoolSize: cfg.PoolSize,
MinIdleConns: cfg.MinIdleConn,
})
ctx := context.TODO()
if err := client.Ping(ctx).Err(); err != nil {
return nil, errors.Join(err, errors.New("ping redis err"))
}
return &cacheRepo{
client: client,
ctx: ctx,
}, nil
}
func WithTrace(t Trace) Option {
return func(opt *option) {
if t != nil {
opt.Trace = t.(*trace.Trace)
opt.Redis = new(trace.Redis)
}
}
}
func (c *cacheRepo) i() {}
func (c *cacheRepo) Client() *redis.Client {
return c.client
}
func (c *cacheRepo) Set(key, value string, ttl time.Duration, options ...Option) error {
ts := time.Now()
opt := newOption()
defer func() {
if opt.Trace != nil {
opt.Redis.Timestamp = time_parse.CSTLayoutString()
opt.Redis.Handle = "set"
opt.Redis.Key = key
opt.Redis.Value = value
opt.Redis.TTL = ttl.Minutes()
opt.Redis.CostSeconds = time.Since(ts).Seconds()
opt.Trace.AppendRedis(opt.Redis)
}
}()
for _, f := range options {
f(opt)
}
if err := c.client.Set(c.ctx, key, value, ttl).Err(); err != nil {
return errors.Join(err, fmt.Errorf("redis set key: %s err", key))
}
return nil
}
func (c *cacheRepo) Get(key string, options ...Option) (string, error) {
ts := time.Now()
opt := newOption()
defer func() {
if opt.Trace != nil {
opt.Redis.Timestamp = time_parse.CSTLayoutString()
opt.Redis.Handle = "get"
opt.Redis.Key = key
opt.Redis.CostSeconds = time.Since(ts).Seconds()
opt.Trace.AppendRedis(opt.Redis)
}
}()
for _, f := range options {
f(opt)
}
value, err := c.client.Get(c.ctx, key).Result()
if err != nil {
return "", errors.Join(err, fmt.Errorf("redis get key: %s err", key))
}
return value, nil
}
func (c *cacheRepo) TTL(key string) (time.Duration, error) {
ttl, err := c.client.TTL(c.ctx, key).Result()
if err != nil {
return -1, errors.Join(err, fmt.Errorf("redis get key: %s err", key))
}
return ttl, nil
}
func (c *cacheRepo) Expire(key string, ttl time.Duration) bool {
ok, _ := c.client.Expire(c.ctx, key, ttl).Result()
return ok
}
func (c *cacheRepo) ExpireAt(key string, ttl time.Time) bool {
ok, _ := c.client.ExpireAt(c.ctx, key, ttl).Result()
return ok
}
func (c *cacheRepo) Exists(keys ...string) bool {
if len(keys) == 0 {
return true
}
value, _ := c.client.Exists(c.ctx, keys...).Result()
return value > 0
}
func (c *cacheRepo) Del(key string, options ...Option) bool {
ts := time.Now()
opt := newOption()
defer func() {
if opt.Trace != nil {
opt.Redis.Timestamp = time_parse.CSTLayoutString()
opt.Redis.Handle = "del"
opt.Redis.Key = key
opt.Redis.CostSeconds = time.Since(ts).Seconds()
opt.Trace.AppendRedis(opt.Redis)
}
}()
for _, f := range options {
f(opt)
}
if key == "" {
return true
}
value, _ := c.client.Del(c.ctx, key).Result()
return value > 0
}
func (c *cacheRepo) Incr(key string, options ...Option) (int64, error) {
ts := time.Now()
opt := newOption()
defer func() {
if opt.Trace != nil {
opt.Redis.Timestamp = time_parse.CSTLayoutString()
opt.Redis.Handle = "incr"
opt.Redis.Key = key
opt.Redis.CostSeconds = time.Since(ts).Seconds()
opt.Trace.AppendRedis(opt.Redis)
}
}()
for _, f := range options {
f(opt)
}
value, err := c.client.Incr(c.ctx, key).Result()
if err != nil {
return 0, errors.Join(err, fmt.Errorf("redis incr key: %s err", key))
}
return value, nil
}
func (c *cacheRepo) Decr(key string, options ...Option) (int64, error) {
ts := time.Now()
opt := newOption()
defer func() {
if opt.Trace != nil {
opt.Redis.Timestamp = time_parse.CSTLayoutString()
opt.Redis.Handle = "decr"
opt.Redis.Key = key
opt.Redis.CostSeconds = time.Since(ts).Seconds()
opt.Trace.AppendRedis(opt.Redis)
}
}()
for _, f := range options {
f(opt)
}
value, err := c.client.Decr(c.ctx, key).Result()
if err != nil {
return 0, errors.Join(err, fmt.Errorf("redis decr key: %s err", key))
}
return value, nil
}
func (c *cacheRepo) HGet(key, field string, options ...Option) (string, error) {
ts := time.Now()
opt := newOption()
defer func() {
if opt.Trace != nil {
opt.Redis.Timestamp = time_parse.CSTLayoutString()
opt.Redis.Handle = "hash get"
opt.Redis.Key = key
opt.Redis.Value = field
opt.Redis.CostSeconds = time.Since(ts).Seconds()
opt.Trace.AppendRedis(opt.Redis)
}
}()
for _, f := range options {
f(opt)
}
value, err := c.client.HGet(c.ctx, key, field).Result()
if err != nil {
return "", errors.Join(err, fmt.Errorf("redis hget key: %s field: %s err", key, field))
}
return value, nil
}
func (c *cacheRepo) HSet(key, field, value string, options ...Option) error {
ts := time.Now()
opt := newOption()
defer func() {
if opt.Trace != nil {
opt.Redis.Timestamp = time_parse.CSTLayoutString()
opt.Redis.Handle = "hash set"
opt.Redis.Key = key
opt.Redis.Value = field + "/" + value
opt.Redis.CostSeconds = time.Since(ts).Seconds()
opt.Trace.AppendRedis(opt.Redis)
}
}()
for _, f := range options {
f(opt)
}
if err := c.client.HSet(c.ctx, key, field, value).Err(); err != nil {
return errors.Join(err, fmt.Errorf("redis hset key: %s field: %s err", key, field))
}
return nil
}
func (c *cacheRepo) HDel(key, field string, options ...Option) error {
ts := time.Now()
opt := newOption()
defer func() {
if opt.Trace != nil {
opt.Redis.Timestamp = time_parse.CSTLayoutString()
opt.Redis.Handle = "hash del"
opt.Redis.Key = key
opt.Redis.Value = field
opt.Redis.CostSeconds = time.Since(ts).Seconds()
opt.Trace.AppendRedis(opt.Redis)
}
}()
for _, f := range options {
f(opt)
}
if err := c.client.HDel(c.ctx, key, field).Err(); err != nil {
return errors.Join(err, fmt.Errorf("redis hdel key: %s field: %s err", key, field))
}
return nil
}
func (c *cacheRepo) HGetAll(key string, options ...Option) (map[string]string, error) {
ts := time.Now()
opt := newOption()
defer func() {
if opt.Trace != nil {
opt.Redis.Timestamp = time_parse.CSTLayoutString()
opt.Redis.Handle = "hash get all"
opt.Redis.Key = key
opt.Redis.CostSeconds = time.Since(ts).Seconds()
opt.Trace.AppendRedis(opt.Redis)
}
}()
for _, f := range options {
f(opt)
}
value, err := c.client.HGetAll(c.ctx, key).Result()
if err != nil {
return nil, errors.Join(err, fmt.Errorf("redis hget all key: %s err", key))
}
return value, nil
}
func (c *cacheRepo) HIncrBy(key, field string, incr int64, options ...Option) (int64, error) {
ts := time.Now()
opt := newOption()
defer func() {
if opt.Trace != nil {
opt.Redis.Timestamp = time_parse.CSTLayoutString()
opt.Redis.Handle = "hash incr int64"
opt.Redis.Key = key
opt.Redis.Value = fmt.Sprintf("field:%s incr:%d", field, incr)
opt.Redis.CostSeconds = time.Since(ts).Seconds()
opt.Trace.AppendRedis(opt.Redis)
}
}()
for _, f := range options {
f(opt)
}
value, err := c.client.HIncrBy(c.ctx, key, field, incr).Result()
if err != nil {
return 0, errors.Join(err, fmt.Errorf("redis hash incr int64 key: %s err", key))
}
return value, nil
}
func (c *cacheRepo) HIncrByFloat(key, field string, incr float64, options ...Option) (float64, error) {
ts := time.Now()
opt := newOption()
defer func() {
if opt.Trace != nil {
opt.Redis.Timestamp = time_parse.CSTLayoutString()
opt.Redis.Handle = "hash incr float64"
opt.Redis.Key = key
opt.Redis.Value = fmt.Sprintf("field:%s incr:%d", field, incr)
opt.Redis.CostSeconds = time.Since(ts).Seconds()
opt.Trace.AppendRedis(opt.Redis)
}
}()
for _, f := range options {
f(opt)
}
value, err := c.client.HIncrByFloat(c.ctx, key, field, incr).Result()
if err != nil {
return 0, errors.Join(err, fmt.Errorf("redis hash incr float64 key: %s err", key))
}
return value, nil
}
func (c *cacheRepo) LPush(key, value string, options ...Option) error {
ts := time.Now()
opt := newOption()
defer func() {
if opt.Trace != nil {
opt.Redis.Timestamp = time_parse.CSTLayoutString()
opt.Redis.Handle = "list push"
opt.Redis.Key = key
opt.Redis.Value = value
opt.Redis.CostSeconds = time.Since(ts).Seconds()
opt.Trace.AppendRedis(opt.Redis)
}
}()
for _, f := range options {
f(opt)
}
_, err := c.client.LPush(c.ctx, key, value).Result()
if err != nil {
return errors.Join(err, fmt.Errorf("redis list push key: %s value: %s err", key, value))
}
return nil
}
func (c *cacheRepo) LLen(key string, options ...Option) (int64, error) {
ts := time.Now()
opt := newOption()
defer func() {
if opt.Trace != nil {
opt.Redis.Timestamp = time_parse.CSTLayoutString()
opt.Redis.Handle = "list len"
opt.Redis.Key = key
opt.Redis.CostSeconds = time.Since(ts).Seconds()
opt.Trace.AppendRedis(opt.Redis)
}
}()
for _, f := range options {
f(opt)
}
value, err := c.client.LLen(c.ctx, key).Result()
if err != nil {
return 0, errors.Join(err, fmt.Errorf("redis list len key: %s err", key))
}
return value, nil
}
func (c *cacheRepo) BRPop(key string, timeout time.Duration, options ...Option) (string, error) {
ts := time.Now()
opt := newOption()
defer func() {
if opt.Trace != nil {
opt.Redis.Timestamp = time_parse.CSTLayoutString()
opt.Redis.Handle = "list brpop"
opt.Redis.Key = key
opt.Redis.TTL = timeout.Seconds()
opt.Redis.CostSeconds = time.Since(ts).Seconds()
opt.Trace.AppendRedis(opt.Redis)
}
}()
for _, f := range options {
f(opt)
}
value, err := c.client.BRPop(c.ctx, timeout, key).Result()
if err != nil {
return "", errors.Join(err, fmt.Errorf("redis list len key: %s err", key))
}
return value[1], nil
}
func (c *cacheRepo) Close() error {
return c.client.Close()
}

97
pkg/captcha/base64.go Normal file
View File

@ -0,0 +1,97 @@
package captcha
import (
"fmt"
"git.bvbej.com/bvbej/base-golang/pkg/cache"
"github.com/mojocn/base64Captcha"
"go.uber.org/zap"
"strings"
"time"
)
var _ base64Captcha.Store = (*store)(nil)
type store struct {
cache cache.Repo
ttl time.Duration
logger *zap.Logger
ns string
prefix string
}
func (s *store) Set(id string, value string) error {
err := s.cache.Set(fmt.Sprintf("%s%s%s", s.ns, s.prefix, id), value, s.ttl)
if err != nil {
return err
}
return nil
}
func (s *store) Get(id string, clear bool) string {
value, err := s.cache.Get(fmt.Sprintf("%s%s%s", s.ns, s.prefix, id))
if err == nil && clear {
s.cache.Del(fmt.Sprintf("%s%s%s", s.ns, s.prefix, id))
}
return value
}
func (s *store) Verify(id, answer string, clear bool) bool {
value := s.Get(id, clear)
if value == "" || answer == "" {
return false
}
return strings.ToLower(value) == strings.ToLower(answer)
}
func NewStore(cache cache.Repo, ttl time.Duration, namespace string) base64Captcha.Store {
return &store{
cache: cache,
ttl: ttl,
ns: namespace,
prefix: "captcha:base64:",
}
}
var _ Captcha = (*captcha)(nil)
type Captcha interface {
Generate() (id, b64s, answer string, err error)
Verify(id, value string) bool
}
type captcha struct {
driver base64Captcha.Driver
store base64Captcha.Store
}
func NewStringCaptcha(store base64Captcha.Store, height, width, length int) Captcha {
conf := &base64Captcha.DriverString{
Height: height,
Width: width,
NoiseCount: length,
ShowLineOptions: base64Captcha.OptionShowHollowLine,
Length: length,
Source: "ABCDEFGHIJKMNPQRSTUVWXYZabcdefghijkmnpqrstuvwxyz0123456789",
}
return &captcha{
driver: conf.ConvertFonts(),
store: store,
}
}
func NewDigitCaptcha(store base64Captcha.Store, height, width, length int) Captcha {
conf := base64Captcha.NewDriverDigit(height, width, length, 0.7, height)
return &captcha{
driver: conf,
store: store,
}
}
func (c *captcha) Generate() (id, b64s, answer string, err error) {
newCaptcha := base64Captcha.NewCaptcha(c.driver, c.store)
return newCaptcha.Generate()
}
func (c *captcha) Verify(id, value string) bool {
return c.store.Verify(id, value, true)
}

119
pkg/cidr/calc.go Normal file
View File

@ -0,0 +1,119 @@
package cidr
import (
"bytes"
"fmt"
"math"
"net"
"sort"
)
// SuperNetting 合并网段
func SuperNetting(ns []string) (*cidr, error) {
num := len(ns)
if num < 1 || (num&(num-1)) != 0 {
return nil, fmt.Errorf("子网数量必须是2的次方")
}
mask := ""
var cidrs []*cidr
for _, n := range ns {
// 检查子网CIDR有效性
c, err := ParseCIDR(n)
if err != nil {
return nil, fmt.Errorf("网段%v格式错误", n)
}
cidrs = append(cidrs, c)
// TODO 暂只考虑相同子网掩码的网段合并
if len(mask) == 0 {
mask = c.Mask()
} else if c.Mask() != mask {
return nil, fmt.Errorf("子网掩码不一致")
}
}
AscSortCIDRs(cidrs)
// 检查网段是否连续
var network net.IP
for _, c := range cidrs {
if len(network) > 0 {
if !network.Equal(c.ipNet.IP) {
return nil, fmt.Errorf("必须是连续的网段")
}
}
network = net.ParseIP(c.Broadcast())
IncrIP(network)
}
// 子网掩码左移,得到共同的父网段
c := cidrs[0]
ones, bits := c.MaskSize()
ones = ones - int(math.Log2(float64(num)))
c.ipNet.Mask = net.CIDRMask(ones, bits)
c.ipNet.IP.Mask(c.ipNet.Mask)
return c, nil
}
// IncrIP IP地址自增
func IncrIP(ip net.IP) {
for i := len(ip) - 1; i >= 0; i-- {
ip[i]++
if ip[i] > 0 {
break
}
}
}
// DecrIP IP地址自减
func DecrIP(ip net.IP) {
length := len(ip)
for i := length - 1; i >= 0; i-- {
ip[length-1]--
if ip[length-1] < 0xFF {
break
}
for j := 1; j < length; j++ {
ip[length-j-1]--
if ip[length-j-1] < 0xFF {
return
}
}
}
}
// Compare 比较IP大小 a等于b返回0; a大于b返回+1; a小于b返回-1
func Compare(a, b net.IP) int {
return bytes.Compare(a, b)
}
// AscSortCIDRs 升序
func AscSortCIDRs(cs []*cidr) {
sort.Slice(cs, func(i, j int) bool {
if n := bytes.Compare(cs[i].ipNet.IP, cs[j].ipNet.IP); n != 0 {
return n < 0
}
if n := bytes.Compare(cs[i].ipNet.Mask, cs[j].ipNet.Mask); n != 0 {
return n < 0
}
return false
})
}
// DescSortCIDRs 降序
func DescSortCIDRs(cs []*cidr) {
sort.Slice(cs, func(i, j int) bool {
if n := bytes.Compare(cs[i].ipNet.IP, cs[j].ipNet.IP); n != 0 {
return n >= 0
}
if n := bytes.Compare(cs[i].ipNet.Mask, cs[j].ipNet.Mask); n != 0 {
return n >= 0
}
return false
})
}

194
pkg/cidr/ip.go Normal file
View File

@ -0,0 +1,194 @@
package cidr
import (
"encoding/hex"
"fmt"
"math"
"math/big"
"net"
)
// 裂解子网的方式
const (
MethodSubnetNum = 0 // 基于子网数量
MethodHostNum = 1 // 基于主机数量
)
var _ CIDR = (*cidr)(nil)
type CIDR interface {
CIDR() string
IP() string
Network() string
Broadcast() string
Mask() string
MaskSize() (int, int)
IPRange() (string, string)
IPCount() *big.Int
IsIPv4() bool
IsIPv6() bool
Equal(string) bool
Contains(string) bool
ForEachIP(func(string) error) error
ForEachIPBeginWith(string, func(string) error) error
SubNetting(method, num int) ([]*cidr, error)
}
type cidr struct {
ip net.IP
ipNet *net.IPNet
}
// ParseCIDR 解析CIDR网段
func ParseCIDR(s string) (*cidr, error) {
i, n, err := net.ParseCIDR(s)
if err != nil {
return nil, err
}
return &cidr{ip: i, ipNet: n}, nil
}
// Equal 判断网段是否相等
func (c *cidr) Equal(ns string) bool {
c2, err := ParseCIDR(ns)
if err != nil {
return false
}
return c.ipNet.IP.Equal(c2.ipNet.IP) /* && c.ipNet.IP.Equal(c2.ip) */
}
// IsIPv4 判断是否IPv4
func (c *cidr) IsIPv4() bool {
_, bits := c.ipNet.Mask.Size()
return bits/8 == net.IPv4len
}
// IsIPv6 判断是否IPv6
func (c *cidr) IsIPv6() bool {
_, bits := c.ipNet.Mask.Size()
return bits/8 == net.IPv6len
}
// Contains 判断IP是否包含在网段中
func (c *cidr) Contains(ip string) bool {
return c.ipNet.Contains(net.ParseIP(ip))
}
// CIDR 根据子网掩码长度校准后的CIDR
func (c *cidr) CIDR() string {
return c.ipNet.String()
}
// IP CIDR字符串中的IP部分
func (c *cidr) IP() string {
return c.ip.String()
}
// Network 网络号
func (c *cidr) Network() string {
return c.ipNet.IP.String()
}
// MaskSize 子网掩码位数
func (c *cidr) MaskSize() (ones, bits int) {
ones, bits = c.ipNet.Mask.Size()
return
}
// Mask 子网掩码
func (c *cidr) Mask() string {
mask, _ := hex.DecodeString(c.ipNet.Mask.String())
return net.IP([]byte(mask)).String()
}
// Broadcast 广播地址(网段最后一个IP)
func (c *cidr) Broadcast() string {
mask := c.ipNet.Mask
bcst := make(net.IP, len(c.ipNet.IP))
copy(bcst, c.ipNet.IP)
for i := 0; i < len(mask); i++ {
ipIdx := len(bcst) - i - 1
bcst[ipIdx] = c.ipNet.IP[ipIdx] | ^mask[len(mask)-i-1]
}
return bcst.String()
}
// IPRange 起始IP、结束IP
func (c *cidr) IPRange() (start, end string) {
return c.Network(), c.Broadcast()
}
// IPCount IP数量
func (c *cidr) IPCount() *big.Int {
ones, bits := c.ipNet.Mask.Size()
return big.NewInt(0).Lsh(big.NewInt(1), uint(bits-ones))
}
// ForEachIP 遍历网段下所有IP
func (c *cidr) ForEachIP(iterator func(ip string) error) error {
next := make(net.IP, len(c.ipNet.IP))
copy(next, c.ipNet.IP)
for c.ipNet.Contains(next) {
if err := iterator(next.String()); err != nil {
return err
}
IncrIP(next)
}
return nil
}
// ForEachIPBeginWith 从指定IP开始遍历网段下后续的IP
func (c *cidr) ForEachIPBeginWith(beginIP string, iterator func(ip string) error) error {
next := net.ParseIP(beginIP)
for c.ipNet.Contains(next) {
if err := iterator(next.String()); err != nil {
return err
}
IncrIP(next)
}
return nil
}
// SubNetting 裂解网段
func (c *cidr) SubNetting(method, num int) ([]*cidr, error) {
if num < 1 || (num&(num-1)) != 0 {
return nil, fmt.Errorf("裂解数量必须是2的次方")
}
newOnes := int(math.Log2(float64(num)))
ones, bits := c.MaskSize()
switch method {
default:
return nil, fmt.Errorf("不支持的裂解方式")
case MethodSubnetNum:
newOnes = ones + newOnes
// 如果子网的掩码长度大于父网段的长度,则无法裂解
if newOnes > bits {
return nil, nil
}
case MethodHostNum:
newOnes = bits - newOnes
// 如果子网的掩码长度小于等于父网段的掩码长度,则无法裂解
if newOnes <= ones {
return nil, nil
}
// 主机数量转换为子网数量
num = int(math.Pow(float64(2), float64(newOnes-ones)))
}
var cidrs []*cidr
network := make(net.IP, len(c.ipNet.IP))
copy(network, c.ipNet.IP)
for i := 0; i < num; i++ {
cidr, _ := ParseCIDR(fmt.Sprintf("%v/%v", network.String(), newOnes))
cidrs = append(cidrs, cidr)
// 广播地址的下一个IP即为下一段的网络号
network = net.ParseIP(cidr.Broadcast())
IncrIP(network)
}
return cidrs, nil
}

360
pkg/cmap/cmap.go Normal file
View File

@ -0,0 +1,360 @@
package cmap
import (
"encoding/json"
"fmt"
"sync"
)
var ShardCount = 32
type Stringer interface {
fmt.Stringer
comparable
}
// ConcurrentMap A "thread" safe map of type string:Anything.
// To avoid lock bottlenecks this map is dived to several (ShardCount) map shards.
type ConcurrentMap[K comparable, V any] struct {
shards []*ConcurrentMapShared[K, V]
sharding func(key K) uint32
}
// ConcurrentMapShared A "thread" safe string to anything map.
type ConcurrentMapShared[K comparable, V any] struct {
items map[K]V
sync.RWMutex // Read Write mutex, guards access to internal map.
}
func create[K comparable, V any](sharding func(key K) uint32) ConcurrentMap[K, V] {
m := ConcurrentMap[K, V]{
sharding: sharding,
shards: make([]*ConcurrentMapShared[K, V], ShardCount),
}
for i := 0; i < ShardCount; i++ {
m.shards[i] = &ConcurrentMapShared[K, V]{items: make(map[K]V)}
}
return m
}
// New Creates a new concurrent map.
func New[V any]() ConcurrentMap[string, V] {
return create[string, V](fnv32)
}
// NewStringer Creates a new concurrent map.
func NewStringer[K Stringer, V any]() ConcurrentMap[K, V] {
return create[K, V](strfnv32[K])
}
// NewWithCustomShardingFunction Creates a new concurrent map.
func NewWithCustomShardingFunction[K comparable, V any](sharding func(key K) uint32) ConcurrentMap[K, V] {
return create[K, V](sharding)
}
// GetShard returns shard under given key
func (m ConcurrentMap[K, V]) GetShard(key K) *ConcurrentMapShared[K, V] {
return m.shards[uint(m.sharding(key))%uint(ShardCount)]
}
func (m ConcurrentMap[K, V]) MSet(data map[K]V) {
for key, value := range data {
shard := m.GetShard(key)
shard.Lock()
shard.items[key] = value
shard.Unlock()
}
}
// Set Sets the given value under the specified key.
func (m ConcurrentMap[K, V]) Set(key K, value V) {
// Get map shard.
shard := m.GetShard(key)
shard.Lock()
shard.items[key] = value
shard.Unlock()
}
// UpsertCb Callback to return new element to be inserted into the map
// It is called while lock is held, therefore it MUST NOT
// try to access other keys in same map, as it can lead to deadlock since
// Go sync.RWLock is not reentrant
type UpsertCb[V any] func(exist bool, valueInMap V, newValue V) V
// Upsert Insert or Update - updates existing element or inserts a new one using UpsertCb
func (m ConcurrentMap[K, V]) Upsert(key K, value V, cb UpsertCb[V]) (res V) {
shard := m.GetShard(key)
shard.Lock()
v, ok := shard.items[key]
res = cb(ok, v, value)
shard.items[key] = res
shard.Unlock()
return res
}
// SetIfAbsent Sets the given value under the specified key if no value was associated with it.
func (m ConcurrentMap[K, V]) SetIfAbsent(key K, value V) bool {
// Get map shard.
shard := m.GetShard(key)
shard.Lock()
_, ok := shard.items[key]
if !ok {
shard.items[key] = value
}
shard.Unlock()
return !ok
}
// Get retrieves an element from map under given key.
func (m ConcurrentMap[K, V]) Get(key K) (V, bool) {
// Get shard
shard := m.GetShard(key)
shard.RLock()
// Get item from shard.
val, ok := shard.items[key]
shard.RUnlock()
return val, ok
}
// Count returns the number of elements within the map.
func (m ConcurrentMap[K, V]) Count() int {
count := 0
for i := 0; i < ShardCount; i++ {
shard := m.shards[i]
shard.RLock()
count += len(shard.items)
shard.RUnlock()
}
return count
}
// Has Looks up an item under specified key
func (m ConcurrentMap[K, V]) Has(key K) bool {
// Get shard
shard := m.GetShard(key)
shard.RLock()
// See if element is within shard.
_, ok := shard.items[key]
shard.RUnlock()
return ok
}
// Remove removes an element from the map.
func (m ConcurrentMap[K, V]) Remove(key K) {
// Try to get shard.
shard := m.GetShard(key)
shard.Lock()
delete(shard.items, key)
shard.Unlock()
}
// RemoveCb is a callback executed in a map.RemoveCb() call, while Lock is held
// If returns true, the element will be removed from the map
type RemoveCb[K any, V any] func(key K, v V, exists bool) bool
// RemoveCb locks the shard containing the key, retrieves its current value and calls the callback with those params
// If callback returns true and element exists, it will remove it from the map
// Returns the value returned by the callback (even if element was not present in the map)
func (m ConcurrentMap[K, V]) RemoveCb(key K, cb RemoveCb[K, V]) bool {
// Try to get shard.
shard := m.GetShard(key)
shard.Lock()
v, ok := shard.items[key]
remove := cb(key, v, ok)
if remove && ok {
delete(shard.items, key)
}
shard.Unlock()
return remove
}
// Pop removes an element from the map and returns it
func (m ConcurrentMap[K, V]) Pop(key K) (v V, exists bool) {
// Try to get shard.
shard := m.GetShard(key)
shard.Lock()
v, exists = shard.items[key]
delete(shard.items, key)
shard.Unlock()
return v, exists
}
// IsEmpty checks if map is empty.
func (m ConcurrentMap[K, V]) IsEmpty() bool {
return m.Count() == 0
}
// Tuple Used by the Iter & IterBuffered functions to wrap two variables together over a channel,
type Tuple[K comparable, V any] struct {
Key K
Val V
}
// IterBuffered returns a buffered iterator which could be used in a for range loop.
func (m ConcurrentMap[K, V]) IterBuffered() <-chan Tuple[K, V] {
chans := snapshot(m)
total := 0
for _, c := range chans {
total += cap(c)
}
ch := make(chan Tuple[K, V], total)
go fanIn(chans, ch)
return ch
}
// Clear removes all items from map.
func (m ConcurrentMap[K, V]) Clear() {
for item := range m.IterBuffered() {
m.Remove(item.Key)
}
}
// Returns an array of channels that contains elements in each shard,
// which likely takes a snapshot of `m`.
// It returns once the size of each buffered channel is determined,
// before all the channels are populated using goroutines.
func snapshot[K comparable, V any](m ConcurrentMap[K, V]) (chans []chan Tuple[K, V]) {
//When you access map items before initializing.
if len(m.shards) == 0 {
panic(`cmap.ConcurrentMap is not initialized. Should run New() before usage.`)
}
chans = make([]chan Tuple[K, V], ShardCount)
wg := sync.WaitGroup{}
wg.Add(ShardCount)
// Foreach shard.
for index, shard := range m.shards {
go func(index int, shard *ConcurrentMapShared[K, V]) {
// Foreach key, value pair.
shard.RLock()
chans[index] = make(chan Tuple[K, V], len(shard.items))
wg.Done()
for key, val := range shard.items {
chans[index] <- Tuple[K, V]{key, val}
}
shard.RUnlock()
close(chans[index])
}(index, shard)
}
wg.Wait()
return chans
}
// fanIn reads elements from channels `chans` into channel `out`
func fanIn[K comparable, V any](chans []chan Tuple[K, V], out chan Tuple[K, V]) {
wg := sync.WaitGroup{}
wg.Add(len(chans))
for _, ch := range chans {
go func(ch chan Tuple[K, V]) {
for t := range ch {
out <- t
}
wg.Done()
}(ch)
}
wg.Wait()
close(out)
}
// Items returns all items as map[string]V
func (m ConcurrentMap[K, V]) Items() map[K]V {
tmp := make(map[K]V)
// Insert items to temporary map.
for item := range m.IterBuffered() {
tmp[item.Key] = item.Val
}
return tmp
}
// IterCb Iterator callbacalled for every key,value found in
// maps. RLock is held for all calls for a given shard
// therefore callback sess consistent view of a shard,
// but not across the shards
type IterCb[K comparable, V any] func(key K, v V)
// IterCb Callback based iterator, cheapest way to read
// all elements in a map.
func (m ConcurrentMap[K, V]) IterCb(fn IterCb[K, V]) {
for idx := range m.shards {
shard := (m.shards)[idx]
shard.RLock()
for key, value := range shard.items {
fn(key, value)
}
shard.RUnlock()
}
}
// Keys returns all keys as []string
func (m ConcurrentMap[K, V]) Keys() []K {
count := m.Count()
ch := make(chan K, count)
go func() {
// Foreach shard.
wg := sync.WaitGroup{}
wg.Add(ShardCount)
for _, shard := range m.shards {
go func(shard *ConcurrentMapShared[K, V]) {
// Foreach key, value pair.
shard.RLock()
for key := range shard.items {
ch <- key
}
shard.RUnlock()
wg.Done()
}(shard)
}
wg.Wait()
close(ch)
}()
// Generate keys
keys := make([]K, 0, count)
for k := range ch {
keys = append(keys, k)
}
return keys
}
// MarshalJSON Reviles ConcurrentMap "private" variables to json marshal.
func (m ConcurrentMap[K, V]) MarshalJSON() ([]byte, error) {
// Create a temporary map, which will hold all item spread across shards.
tmp := make(map[K]V)
// Insert items to temporary map.
for item := range m.IterBuffered() {
tmp[item.Key] = item.Val
}
return json.Marshal(tmp)
}
func strfnv32[K fmt.Stringer](key K) uint32 {
return fnv32(key.String())
}
func fnv32(key string) uint32 {
hash := uint32(2166136261)
const prime32 = uint32(16777619)
keyLength := len(key)
for i := 0; i < keyLength; i++ {
hash *= prime32
hash ^= uint32(key[i])
}
return hash
}
// UnmarshalJSON Reverse process of Marshal.
func (m *ConcurrentMap[K, V]) UnmarshalJSON(b []byte) (err error) {
tmp := make(map[K]V)
// Unmarshal into a single map.
if err := json.Unmarshal(b, &tmp); err != nil {
return err
}
// foreach key,value pair in temporary map insert into our concurrent map.
for key, val := range tmp {
m.Set(key, val)
}
return nil
}

View File

@ -0,0 +1,47 @@
//go:build darwin
// +build darwin
package color
import (
"fmt"
"math/rand"
"strconv"
)
var _ = RandomColor()
// RandomColor generates a random color.
func RandomColor() string {
return fmt.Sprintf("#%s", strconv.FormatInt(int64(rand.Intn(16777216)), 16))
}
// Yellow ...
func Yellow(msg string) string {
return fmt.Sprintf("\x1b[33m%s\x1b[0m", msg)
}
// Red ...
func Red(msg string) string {
return fmt.Sprintf("\x1b[31m%s\x1b[0m", msg)
}
// Redf ...
func Redf(msg string, arg any) string {
return fmt.Sprintf("\x1b[31m%s\x1b[0m %+v\n", msg, arg)
}
// Blue ...
func Blue(msg string) string {
return fmt.Sprintf("\x1b[34m%s\x1b[0m", msg)
}
// Green ...
func Green(msg string) string {
return fmt.Sprintf("\x1b[32m%s\x1b[0m", msg)
}
// Greenf ...
func Greenf(msg string, arg any) string {
return fmt.Sprintf("\x1b[32m%s\x1b[0m %+v\n", msg, arg)
}

47
pkg/color/string_linux.go Normal file
View File

@ -0,0 +1,47 @@
//go:build linux
// +build linux
package color
import (
"fmt"
"math/rand"
"strconv"
)
var _ = RandomColor()
// RandomColor generates a random color.
func RandomColor() string {
return fmt.Sprintf("#%s", strconv.FormatInt(int64(rand.Intn(16777216)), 16))
}
// Yellow ...
func Yellow(msg string) string {
return fmt.Sprintf("\x1b[33m%s\x1b[0m", msg)
}
// Red ...
func Red(msg string) string {
return fmt.Sprintf("\x1b[31m%s\x1b[0m", msg)
}
// Redf ...
func Redf(msg string, arg any) string {
return fmt.Sprintf("\x1b[31m%s\x1b[0m %+v\n", msg, arg)
}
// Blue ...
func Blue(msg string) string {
return fmt.Sprintf("\x1b[34m%s\x1b[0m", msg)
}
// Green ...
func Green(msg string) string {
return fmt.Sprintf("\x1b[32m%s\x1b[0m", msg)
}
// Greenf ...
func Greenf(msg string, arg any) string {
return fmt.Sprintf("\x1b[32m%s\x1b[0m %+v\n", msg, arg)
}

View File

@ -0,0 +1,47 @@
//go:build windows
// +build windows
package color
import (
"fmt"
"math/rand"
"strconv"
)
var _ = RandomColor()
// RandomColor generates a random color.
func RandomColor() string {
return fmt.Sprintf("#%s", strconv.FormatInt(int64(rand.Intn(16777216)), 16))
}
// Yellow ...
func Yellow(msg string) string {
return fmt.Sprintf("\033[33m%s\033[0m", msg)
}
// Red ...
func Red(msg string) string {
return fmt.Sprintf("\033[31m%s\033[0m", msg)
}
// Redf ...
func Redf(msg string, arg any) string {
return fmt.Sprintf("\033[31m%s\033[0m %+v\n", msg, arg)
}
// Blue ...
func Blue(msg string) string {
return fmt.Sprintf("\033[34m%s\033[0m", msg)
}
// Green ...
func Green(msg string) string {
return fmt.Sprintf("\033[32m%s\033[0m", msg)
}
// Greenf ...
func Greenf(msg string, arg any) string {
return fmt.Sprintf("\033[32m%s\033[0m %+v\n", msg, arg)
}

38
pkg/compress/compress.go Normal file
View File

@ -0,0 +1,38 @@
package compress
import (
"bytes"
"compress/zlib"
"io"
)
var _ Compress = (*compress)(nil)
type Compress interface {
DoZlibCompress(src []byte) []byte
DoZlibUnCompress(compressSrc []byte) []byte
}
type compress struct{}
func New() Compress {
return &compress{}
}
// DoZlibCompress 进行zlib压缩
func (c *compress) DoZlibCompress(src []byte) []byte {
var in bytes.Buffer
w := zlib.NewWriter(&in)
_, _ = w.Write(src)
_ = w.Close()
return in.Bytes()
}
// DoZlibUnCompress 进行zlib解压缩
func (c *compress) DoZlibUnCompress(compressSrc []byte) []byte {
b := bytes.NewReader(compressSrc)
var out bytes.Buffer
r, _ := zlib.NewReader(b)
_, _ = io.Copy(&out, r)
return out.Bytes()
}

40
pkg/crontab/crontab.go Normal file
View File

@ -0,0 +1,40 @@
package crontab
import (
"github.com/robfig/cron/v3"
)
var _ Crontab = (*crontab)(nil)
type Crontab interface {
i()
AddFunc(spec string, cmd func()) (entryID cron.EntryID, err error)
Entries() []cron.Entry
Stop()
}
type crontab struct {
cron *cron.Cron
}
func New() Crontab {
return &crontab{
cron: cron.New(),
}
}
func (c *crontab) i() {}
func (c *crontab) AddFunc(spec string, cmd func()) (entryID cron.EntryID, err error) {
entryID, err = c.cron.AddFunc(spec, cmd)
c.cron.Start()
return
}
func (c *crontab) Stop() {
c.cron.Stop()
}
func (c *crontab) Entries() []cron.Entry {
return c.cron.Entries()
}

77
pkg/database/mongo.go Normal file
View File

@ -0,0 +1,77 @@
package database
import (
"context"
"fmt"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/mongo/readpref"
"time"
)
var _ MongoDB = (*mongoDB)(nil)
type MongoDB interface {
i()
GetDB() *mongo.Database
Close() error
}
type MongoDBConfig struct {
Addr string `yaml:"addr"`
User string `yaml:"user"`
Pass string `yaml:"pass"`
Name string `yaml:"name"`
Timeout time.Duration `yaml:"timeout"`
}
type mongoDB struct {
client *mongo.Client
db *mongo.Database
timeout time.Duration
}
func (m *mongoDB) i() {}
func NewMongoDB(cfg MongoDBConfig) (MongoDB, error) {
timeout := cfg.Timeout * time.Second
connectCtx, connectCancelFunc := context.WithTimeout(context.Background(), timeout)
defer connectCancelFunc()
var auth string
if len(cfg.User) > 0 && len(cfg.Pass) > 0 {
auth = fmt.Sprintf("%s:%s@", cfg.User, cfg.Pass)
}
client, err := mongo.Connect(connectCtx, options.Client().ApplyURI(
fmt.Sprintf("mongodb://%s%s", auth, cfg.Addr),
))
if err != nil {
return nil, err
}
pingCtx, pingCancelFunc := context.WithTimeout(context.Background(), timeout)
defer pingCancelFunc()
err = client.Ping(pingCtx, readpref.Primary())
if err != nil {
return nil, err
}
return &mongoDB{
client: client,
db: client.Database(cfg.Name),
timeout: timeout,
}, nil
}
func (m *mongoDB) GetDB() *mongo.Database {
return m.db
}
func (m *mongoDB) Close() error {
disconnectCtx, disconnectCancelFunc := context.WithTimeout(context.Background(), m.timeout)
defer disconnectCancelFunc()
err := m.client.Disconnect(disconnectCtx)
if err != nil {
return err
}
return nil
}

247
pkg/database/mysql.go Normal file
View File

@ -0,0 +1,247 @@
package database
import (
"errors"
"fmt"
"log"
"os"
"time"
"git.bvbej.com/bvbej/base-golang/pkg/time_parse"
"git.bvbej.com/bvbej/base-golang/pkg/trace"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
const (
callBackBeforeName = "core:before"
callBackAfterName = "core:after"
startTime = "_start_time"
traceCtxName = "_trace_ctx_name"
)
var _ MysqlRepo = (*mysqlRepo)(nil)
type MysqlRepo interface {
i()
GetRead(options ...Option) *gorm.DB
GetWrite(options ...Option) *gorm.DB
Close() error
}
type MySQLConfig struct {
Read struct {
Addr string `yaml:"addr"`
User string `yaml:"user"`
Pass string `yaml:"pass"`
Name string `yaml:"name"`
} `yaml:"read"`
Write struct {
Addr string `yaml:"addr"`
User string `yaml:"user"`
Pass string `yaml:"pass"`
Name string `yaml:"name"`
} `yaml:"write"`
Base struct {
MaxOpenConn int `yaml:"maxOpenConn"` //最大连接数
MaxIdleConn int `yaml:"maxIdleConn"` //最大空闲连接数
ConnMaxLifeTime time.Duration `yaml:"connMaxLifeTime"` //最大连接超时(分钟)
} `yaml:"base"`
}
type mysqlRepo struct {
read *gorm.DB
write *gorm.DB
}
func NewMysql(cfg MySQLConfig) (MysqlRepo, error) {
dbr, err := dbConnect(cfg.Read.User, cfg.Read.Pass, cfg.Read.Addr, cfg.Read.Name,
cfg.Base.MaxOpenConn, cfg.Base.MaxIdleConn, cfg.Base.ConnMaxLifeTime)
if err != nil {
return nil, err
}
dbw, err := dbConnect(cfg.Write.User, cfg.Write.Pass, cfg.Write.Addr, cfg.Write.Name,
cfg.Base.MaxOpenConn, cfg.Base.MaxIdleConn, cfg.Base.ConnMaxLifeTime)
if err != nil {
return nil, err
}
return &mysqlRepo{
read: dbr,
write: dbw,
}, nil
}
func (d *mysqlRepo) i() {}
func (d *mysqlRepo) GetRead(options ...Option) *gorm.DB {
opt := newOption()
for _, f := range options {
f(opt)
}
db := d.read
if opt.Trace != nil {
db.InstanceSet(traceCtxName, opt.Trace)
}
return db
}
func (d *mysqlRepo) GetWrite(options ...Option) *gorm.DB {
opt := newOption()
for _, f := range options {
f(opt)
}
db := d.write
if opt.Trace != nil {
db.InstanceSet(traceCtxName, opt.Trace)
}
return db
}
func (d *mysqlRepo) Close() (err error) {
rdb, err1 := d.read.DB()
if err1 != nil {
err = errors.Join(err1)
}
err2 := rdb.Close()
if err2 != nil {
err = errors.Join(err2)
}
wdb, err3 := d.write.DB()
if err3 != nil {
err = errors.Join(err3)
}
err4 := wdb.Close()
if err4 != nil {
err = errors.Join(err4)
}
return err
}
func dbConnect(user, pass, addr, dbName string, maxOpenConn, maxIdleConn int, connMaxLifeTime time.Duration) (*gorm.DB, error) {
dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=%t&loc=%s",
user,
pass,
addr,
dbName,
true,
"Local")
// 日志配置
newLogger := logger.New(
log.New(os.Stdout, "\r\n", log.LstdFlags),
logger.Config{
SlowThreshold: time.Second, // 慢SQL阈值
Colorful: true, // 彩色打印
IgnoreRecordNotFoundError: true, // 忽略记录未找到错误
LogLevel: logger.Error, // 日志级别
},
)
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{
NamingStrategy: schema.NamingStrategy{
SingularTable: true,
},
Logger: newLogger,
})
if err != nil {
return nil, errors.Join(err, fmt.Errorf("[db connection failed] Database name: %s", dbName))
}
db.Set("gorm:table_options", "CHARSET=utf8mb4")
sqlDB, err := db.DB()
if err != nil {
return nil, err
}
// 设置连接池 用于设置最大打开的连接数默认值为0表示不限制.设置最大的连接数可以避免并发太高导致连接mysql出现too many connections的错误。
sqlDB.SetMaxOpenConns(maxOpenConn)
// 设置最大连接数 用于设置闲置的连接数.设置闲置的连接数则当开启的一个连接使用完成后可以放在池里等候下一次使用。
sqlDB.SetMaxIdleConns(maxIdleConn)
// 设置最大连接超时
sqlDB.SetConnMaxLifetime(time.Minute * connMaxLifeTime)
// 使用插件
err = db.Use(&TracePlugin{})
if err != nil {
return nil, err
}
return db, nil
}
/***************************************************************/
type TracePlugin struct{}
func (op *TracePlugin) Name() string {
return "TracePlugin"
}
func (op *TracePlugin) Initialize(db *gorm.DB) (err error) {
// 开始前
_ = db.Callback().Create().Before("gorm:before_create").Register(callBackBeforeName, before)
_ = db.Callback().Query().Before("gorm:query").Register(callBackBeforeName, before)
_ = db.Callback().Delete().Before("gorm:before_delete").Register(callBackBeforeName, before)
_ = db.Callback().Update().Before("gorm:setup_reflect_value").Register(callBackBeforeName, before)
_ = db.Callback().Row().Before("gorm:row").Register(callBackBeforeName, before)
_ = db.Callback().Raw().Before("gorm:raw").Register(callBackBeforeName, before)
// 结束后
_ = db.Callback().Create().After("gorm:after_create").Register(callBackAfterName, after)
_ = db.Callback().Query().After("gorm:after_query").Register(callBackAfterName, after)
_ = db.Callback().Delete().After("gorm:after_delete").Register(callBackAfterName, after)
_ = db.Callback().Update().After("gorm:after_update").Register(callBackAfterName, after)
_ = db.Callback().Row().After("gorm:row").Register(callBackAfterName, after)
_ = db.Callback().Raw().After("gorm:raw").Register(callBackAfterName, after)
return
}
func before(db *gorm.DB) {
db.InstanceSet(startTime, time.Now())
}
func after(db *gorm.DB) {
_traceCtx, isExist := db.InstanceGet(traceCtxName)
if !isExist {
return
}
_trace, ok := _traceCtx.(trace.T)
if !ok {
return
}
_ts, isExist := db.InstanceGet(startTime)
if !isExist {
return
}
ts, ok := _ts.(time.Time)
if !ok {
return
}
sql := db.Dialector.Explain(db.Statement.SQL.String(), db.Statement.Vars...)
sqlInfo := new(trace.SQL)
sqlInfo.Timestamp = time_parse.CSTLayoutString()
sqlInfo.SQL = sql
sqlInfo.Stack = utils.FileWithLineNum()
sqlInfo.Rows = db.Statement.RowsAffected
sqlInfo.CostSeconds = time.Since(ts).Seconds()
_trace.AppendSQL(sqlInfo)
}

88
pkg/database/tool.go Normal file
View File

@ -0,0 +1,88 @@
package database
import (
"database/sql"
"database/sql/driver"
"encoding/json"
"fmt"
"gorm.io/gorm"
"reflect"
)
type NullTime sql.NullTime
func (n *NullTime) Scan(value any) error {
return (*sql.NullTime)(n).Scan(value)
}
func (n NullTime) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.Time, nil
}
func (n NullTime) MarshalJSON() ([]byte, error) {
if n.Valid {
return json.Marshal(n.Time)
}
return json.Marshal(nil)
}
func (n *NullTime) UnmarshalJSON(b []byte) error {
if string(b) == "null" {
n.Valid = false
return nil
}
err := json.Unmarshal(b, &n.Time)
if err == nil {
n.Valid = true
}
return err
}
/*-----------------------------------------------------------*/
type PaginateList struct {
Page int64 `json:"page"`
Size int64 `json:"size"`
Total int64 `json:"total"`
List any `json:"list"`
}
func Paginate(db *gorm.DB, model any, page, size int64) (*PaginateList, error) {
ptr := reflect.ValueOf(model)
if ptr.Kind() != reflect.Ptr {
return nil, fmt.Errorf("model must be pointer")
}
var total int64
err := db.Model(model).Count(&total).Error
if err != nil {
return &PaginateList{
Page: page,
Size: size,
Total: total,
List: make([]any, 0),
}, err
}
offset := size * (page - 1)
err = db.Limit(int(size)).Offset(int(offset)).Find(model).Error
if err != nil {
return &PaginateList{
Page: page,
Size: size,
Total: total,
List: make([]any, 0),
}, err
}
return &PaginateList{
Page: page,
Size: size,
Total: total,
List: model,
}, nil
}

23
pkg/database/trace.go Normal file
View File

@ -0,0 +1,23 @@
package database
import "git.bvbej.com/bvbej/base-golang/pkg/trace"
type Trace = trace.T
type Option func(*option)
func WithTrace(t Trace) Option {
return func(opt *option) {
if t != nil {
opt.Trace = t.(*trace.Trace)
}
}
}
func newOption() *option {
return &option{}
}
type option struct {
Trace *trace.Trace
}

26
pkg/ddm/README.md Normal file
View File

@ -0,0 +1,26 @@
## DDM
动态数据掩码Dynamic Data Masking简称为DDM能够防止把敏感数据暴露给未经授权的用户。
| 类型 | 要求 | 示例 | 说明
| ---- | ---- | ---- | ----
| 手机号 | 前 3 后 4 | 132****7986 | 定长 11 位数字
| 邮箱地址 | 前 1 后 1 | l**w@gmail.com | 仅对 @ 之前的邮箱名称进行掩码
| 姓名 | 隐姓 | *鸿章 | 将姓氏隐藏
| 密码 | 不输出 | ****** |
| 银行卡卡号 | 前 6 后 4 | 622888******5676 | 银行卡卡号最多 19 位数字
| 身份证号 | 前 1 后 1 | 1******7 | 定长 18 位
#### 代码示例
```
// 返回值
type message struct {
Email ddm.Email `json:"email"`
}
msg := new(message)
msg.Email = ddm.Email("xinliangnote@163.com")
...
```

35
pkg/ddm/benchmark.go Normal file
View File

@ -0,0 +1,35 @@
package ddm
import (
"github.com/mritd/chinaid"
)
type BType uint8
const (
BMobile BType = iota
BIDNo
BName
BBankNo
BEmail
BAddress
)
func Benchmark(bType BType) string {
switch bType {
case BMobile:
return chinaid.Mobile()
case BIDNo:
return chinaid.IDNo()
case BEmail:
return chinaid.Email()
case BAddress:
return chinaid.Address()
case BName:
return chinaid.Name()
case BBankNo:
return chinaid.BankNo()
default:
return ""
}
}

62
pkg/ddm/mark.go Normal file
View File

@ -0,0 +1,62 @@
package ddm
import (
"fmt"
"strings"
)
func (m Mobile) MarshalJSON() ([]byte, error) {
if len(m) != 11 {
return []byte(`"` + m + `"`), nil
}
v := fmt.Sprintf("%s****%s", m[:3], m[len(m)-4:])
return []byte(`"` + v + `"`), nil
}
func (bc BankCard) MarshalJSON() ([]byte, error) {
if len(bc) > 19 || len(bc) < 16 {
return []byte(`"` + bc + `"`), nil
}
v := fmt.Sprintf("%s******%s", bc[:6], bc[len(bc)-4:])
return []byte(`"` + v + `"`), nil
}
func (card IDCard) MarshalJSON() ([]byte, error) {
if len(card) != 18 {
return []byte(`"` + card + `"`), nil
}
v := fmt.Sprintf("%s******%s", card[:1], card[len(card)-1:])
return []byte(`"` + v + `"`), nil
}
func (name IDName) MarshalJSON() ([]byte, error) {
if len(name) < 1 {
return []byte(`""`), nil
}
nameRune := []rune(name)
v := fmt.Sprintf("*%s", string(nameRune[1:]))
return []byte(`"` + v + `"`), nil
}
func (pw PassWord) MarshalJSON() ([]byte, error) {
v := "******"
return []byte(`"` + v + `"`), nil
}
func (e Email) MarshalJSON() ([]byte, error) {
if !strings.Contains(string(e), "@") {
return []byte(`"` + e + `"`), nil
}
split := strings.Split(string(e), "@")
if len(split[0]) < 1 || len(split[1]) < 1 {
return []byte(`"` + e + `"`), nil
}
v := fmt.Sprintf("%s***%s", split[0][:1], split[0][len(split[0])-1:])
return []byte(`"` + v + "@" + split[1] + `"`), nil
}

19
pkg/ddm/type.go Normal file
View File

@ -0,0 +1,19 @@
package ddm
// 手机号 132****7986
type Mobile string
// 银行卡号 622888******5676
type BankCard string
// 身份证号 1******7
type IDCard string
// 姓名 *鸿章
type IDName string
// 密码 ******
type PassWord string
// 邮箱 l***w@gmail.com
type Email string

View File

@ -0,0 +1,264 @@
package downloader
import (
"errors"
"git.bvbej.com/bvbej/base-golang/pkg/downloader/base"
"git.bvbej.com/bvbej/base-golang/pkg/downloader/controller"
"git.bvbej.com/bvbej/base-golang/pkg/downloader/fetcher"
"git.bvbej.com/bvbej/base-golang/pkg/downloader/protocol/http"
"git.bvbej.com/bvbej/base-golang/pkg/downloader/util"
"github.com/google/uuid"
"net/url"
"strings"
"sync"
"time"
)
type Listener func(event *Event)
type TaskInfo struct {
ID string
Res *base.Resource
Opts *base.Options
Status base.Status
Progress *Progress
fetcher fetcher.Fetcher
timer *util.Timer
locker *sync.Mutex
}
type Progress struct {
// 下载耗时(纳秒)
Used int64
// 每秒下载字节数
Speed int64
// 已下载的字节数
Downloaded int64
}
type downloader struct {
*controller.DefaultController
fetchBuilders map[string]func() fetcher.Fetcher
task *TaskInfo
listener Listener
finished bool
finishedCh chan error
}
func newDownloader(f func() (protocols []string, builder func() fetcher.Fetcher), options ...controller.Option) *downloader {
d := &downloader{
DefaultController: controller.NewController(options...),
finishedCh: make(chan error, 1),
}
d.fetchBuilders = make(map[string]func() fetcher.Fetcher)
protocols, builder := f()
for _, p := range protocols {
d.fetchBuilders[strings.ToUpper(p)] = builder
}
return d
}
func (d *downloader) buildFetcher(URL string) (fetcher.Fetcher, error) {
parseURL, err := url.Parse(URL)
if err != nil {
return nil, err
}
if fetchBuilder, ok := d.fetchBuilders[strings.ToUpper(parseURL.Scheme)]; ok {
fetched := fetchBuilder()
fetched.Setup(d.DefaultController)
return fetched, nil
}
return nil, errors.New("unsupported protocol")
}
func (d *downloader) Resolve(req *base.Request) (*base.Resource, error) {
fetched, err := d.buildFetcher(req.URL)
if err != nil {
return nil, err
}
return fetched.Resolve(req)
}
func (d *downloader) Create(res *base.Resource, opts *base.Options) (err error) {
fetched, err := d.buildFetcher(res.Req.URL)
if err != nil {
return
}
if !res.Range || opts.Connections < 1 {
opts.Connections = 1
}
err = fetched.Create(res, opts)
if err != nil {
return
}
task := &TaskInfo{
ID: uuid.New().String(),
Res: res,
Opts: opts,
Status: base.DownloadStatusStart,
Progress: &Progress{},
fetcher: fetched,
timer: &util.Timer{},
locker: new(sync.Mutex),
}
d.task = task
task.timer.Start()
d.emit(EventKeyStart)
err = fetched.Start()
if err != nil {
return
}
go func() {
err = fetched.Wait()
if err != nil {
d.emit(EventKeyError, err)
task.Status = base.DownloadStatusError
} else {
task.Progress.Used = task.timer.Used()
if task.Res.TotalSize == 0 {
task.Res.TotalSize = task.fetcher.Progress().TotalDownloaded()
}
used := task.Progress.Used / int64(time.Second)
if used == 0 {
used = 1
}
task.Progress.Speed = task.Res.TotalSize / used
task.Progress.Downloaded = task.Res.TotalSize
d.emit(EventKeyDone)
task.Status = base.DownloadStatusDone
}
d.finished = true
d.emit(EventKeyFinally, err)
d.finishedCh <- err
}()
// 每秒统计一次下载速度
go func() {
for !d.finished {
if d.task.Status == base.DownloadStatusPause {
continue
}
current := d.task.fetcher.Progress().TotalDownloaded()
d.task.Progress.Used = d.task.timer.Used()
d.task.Progress.Speed = current - d.task.Progress.Downloaded
d.task.Progress.Downloaded = current
d.emit(EventKeyProgress)
time.Sleep(time.Second)
}
}()
return
}
func (d *downloader) Pause() error {
d.task.locker.Lock()
defer d.task.locker.Unlock()
d.task.timer.Pause()
err := d.task.fetcher.Pause()
if err != nil {
return err
}
d.emit(EventKeyPause)
d.task.Status = base.DownloadStatusPause
return nil
}
func (d *downloader) Continue() error {
d.task.locker.Lock()
defer d.task.locker.Unlock()
d.task.timer.Continue()
err := d.task.fetcher.Continue()
if err != nil {
return err
}
d.emit(EventKeyContinue)
d.task.Status = base.DownloadStatusStart
return nil
}
func (d *downloader) Listener(fn Listener) {
d.listener = fn
}
func (d *downloader) emit(eventKey EventKey, errs ...error) {
if d.listener != nil {
var err error
if len(errs) > 0 {
err = errs[0]
}
d.listener(&Event{
Key: eventKey,
Task: d.task,
Err: err,
})
}
}
var _ Boot = (*boot)(nil)
type Boot interface {
URL(url string) Boot
Extra(extra any) Boot
Listener(listener Listener) Boot
Create(opts *base.Options) <-chan error
}
type boot struct {
url string
extra any
listener Listener
downloader *downloader
}
func (b *boot) resolve() (*base.Resource, error) {
return b.downloader.Resolve(&base.Request{
URL: b.url,
Extra: b.extra,
})
}
func (b *boot) URL(url string) Boot {
b.url = url
return b
}
func (b *boot) Extra(extra any) Boot {
b.extra = extra
return b
}
func (b *boot) Listener(listener Listener) Boot {
b.listener = listener
return b
}
func (b *boot) Create(opts *base.Options) <-chan error {
res, err := b.resolve()
if err != nil {
b.downloader.finishedCh <- err
return b.downloader.finishedCh
}
b.downloader.Listener(b.listener)
err = b.downloader.Create(res, opts)
if err != nil {
b.downloader.finishedCh <- err
return b.downloader.finishedCh
}
return b.downloader.finishedCh
}
// New 一个文件对应一个实例
func New(options ...controller.Option) Boot {
return &boot{
downloader: newDownloader(http.FetcherBuilder, options...),
}
}

View File

@ -0,0 +1,45 @@
package downloader
import (
"fmt"
"git.bvbej.com/bvbej/base-golang/pkg/downloader/base"
"git.bvbej.com/bvbej/base-golang/pkg/downloader/controller"
"git.bvbej.com/bvbej/base-golang/tool"
"golang.org/x/net/proxy"
"net/http"
"net/url"
"runtime"
"testing"
"time"
)
func TestNewDownloader(t *testing.T) {
dialer, err := proxy.SOCKS5("tcp", "127.0.0.1:1080", nil, proxy.Direct)
parse, _ := url.Parse(`socks5://127.0.0.1:1080`)
if err != nil {
t.Fatal(err, dialer)
}
err = <-New(
controller.WithDialer(dialer), // todo: use dialer proxy
controller.WithCookie(nil),
controller.WithProxy(http.ProxyURL(parse)), // todo: use http client proxy
controller.WithTimeout(time.Second*3),
).URL("http://10.0.1.34/com.tencent.tmgp.jxqy.apk").
Listener(func(event *Event) {
if event.Key == EventKeyFinally {
fmt.Println("下载完成!")
}
if event.Key == EventKeyProgress {
fmt.Printf("下载速度:%s/s 已下载:%s 已用时:%s \n",
tool.ByteFmt(event.Task.Progress.Speed),
tool.ByteFmt(event.Task.Progress.Downloaded),
time.Duration(event.Task.Progress.Used),
)
}
}).
Create(&base.Options{
Connections: runtime.NumCPU(),
})
t.Log(err)
}

19
pkg/downloader/event.go Normal file
View File

@ -0,0 +1,19 @@
package downloader
type EventKey string
const (
EventKeyStart EventKey = "start"
EventKeyPause EventKey = "pause"
EventKeyContinue EventKey = "continue"
EventKeyProgress EventKey = "progress"
EventKeyError EventKey = "error"
EventKeyDone EventKey = "done"
EventKeyFinally EventKey = "finally"
)
type Event struct {
Key EventKey
Task *TaskInfo
Err error
}

340
pkg/duration_fmt/fmt.go Normal file
View File

@ -0,0 +1,340 @@
package duration_fmt
import (
"errors"
"fmt"
"regexp"
"strconv"
"strings"
"time"
)
var (
//units, _ = DefaultUnitsCoder.Decode("year,week,day,hour,minute,second,millisecond,microsecond")
units, _ = DefaultUnitsCoder.Decode("年,星期,天,小时,分钟,秒,毫秒,微秒")
unitsShort = []string{"y", "w", "d", "h", "m", "s", "ms", "µs"}
)
// Durafmt holds the parsed duration and the original input duration.
type Durafmt struct {
duration time.Duration
input string // Used as reference.
limitN int // Non-zero to limit only first N elements to output.
limitUnit string // Non-empty to limit max unit
}
// LimitToUnit sets the output format, you will not have unit bigger than the UNIT specified. UNIT = "" means no restriction.
func (d *Durafmt) LimitToUnit(unit string) *Durafmt {
d.limitUnit = unit
return d
}
// LimitFirstN sets the output format, outputing only first N elements. n == 0 means no limit.
func (d *Durafmt) LimitFirstN(n int) *Durafmt {
d.limitN = n
return d
}
func (d *Durafmt) Duration() time.Duration {
return d.duration
}
// Truncate sets precision
func (d *Durafmt) Truncate(unit time.Duration) *Durafmt {
d.duration = d.duration.Truncate(unit)
return d
}
// Parse creates a new *Durafmt struct, returns error if input is invalid.
func Parse(dinput time.Duration) *Durafmt {
input := dinput.String()
return &Durafmt{dinput, input, 0, ""}
}
// ParseShort creates a new *Durafmt struct, short form, returns error if input is invalid.
// It's shortcut for `Parse(dur).LimitFirstN(1)`
func ParseShort(dinput time.Duration) *Durafmt {
input := dinput.String()
return &Durafmt{dinput, input, 1, ""}
}
// ParseString creates a new *Durafmt struct from a string.
// returns an error if input is invalid.
func ParseString(input string) (*Durafmt, error) {
if input == "0" || input == "-0" {
return nil, errors.New("durafmt: missing unit in duration " + input)
}
duration, err := time.ParseDuration(input)
if err != nil {
return nil, err
}
return &Durafmt{duration, input, 0, ""}, nil
}
// ParseStringShort creates a new *Durafmt struct from a string, short form
// returns an error if input is invalid.
// It's shortcut for `ParseString(durStr)` and then calling `LimitFirstN(1)`
func ParseStringShort(input string) (*Durafmt, error) {
if input == "0" || input == "-0" {
return nil, errors.New("durafmt: missing unit in duration " + input)
}
duration, err := time.ParseDuration(input)
if err != nil {
return nil, err
}
return &Durafmt{duration, input, 1, ""}, nil
}
// String parses d *Durafmt into a human readable duration with default units.
func (d *Durafmt) String() string {
return d.Format(units)
}
// Format parses d *Durafmt into a human readable duration with units.
func (d *Durafmt) Format(units Units) string {
var duration string
// Check for minus durations.
if string(d.input[0]) == "-" {
duration += "-"
d.duration = -d.duration
}
var microseconds int64
var milliseconds int64
var seconds int64
var minutes int64
var hours int64
var days int64
var weeks int64
var years int64
var shouldConvert = false
remainingSecondsToConvert := int64(d.duration / time.Microsecond)
// Convert duration.
if d.limitUnit == "" {
shouldConvert = true
}
if d.limitUnit == "years" || shouldConvert {
years = remainingSecondsToConvert / (365 * 24 * 3600 * 1000000)
remainingSecondsToConvert -= years * 365 * 24 * 3600 * 1000000
shouldConvert = true
}
if d.limitUnit == "weeks" || shouldConvert {
weeks = remainingSecondsToConvert / (7 * 24 * 3600 * 1000000)
remainingSecondsToConvert -= weeks * 7 * 24 * 3600 * 1000000
shouldConvert = true
}
if d.limitUnit == "days" || shouldConvert {
days = remainingSecondsToConvert / (24 * 3600 * 1000000)
remainingSecondsToConvert -= days * 24 * 3600 * 1000000
shouldConvert = true
}
if d.limitUnit == "hours" || shouldConvert {
hours = remainingSecondsToConvert / (3600 * 1000000)
remainingSecondsToConvert -= hours * 3600 * 1000000
shouldConvert = true
}
if d.limitUnit == "minutes" || shouldConvert {
minutes = remainingSecondsToConvert / (60 * 1000000)
remainingSecondsToConvert -= minutes * 60 * 1000000
shouldConvert = true
}
if d.limitUnit == "seconds" || shouldConvert {
seconds = remainingSecondsToConvert / 1000000
remainingSecondsToConvert -= seconds * 1000000
shouldConvert = true
}
if d.limitUnit == "milliseconds" || shouldConvert {
milliseconds = remainingSecondsToConvert / 1000
remainingSecondsToConvert -= milliseconds * 1000
}
microseconds = remainingSecondsToConvert
// Create a map of the converted duration time.
durationMap := []int64{
microseconds,
milliseconds,
seconds,
minutes,
hours,
days,
weeks,
years,
}
// Construct duration string.
for i, u := range units.Units() {
v := durationMap[7-i]
strval := strconv.FormatInt(v, 10)
switch {
// add to the duration string if v > 1.
case v > 1:
duration += strval + " " + u.Plural + " "
// remove the plural 's', if v is 1.
case v == 1:
duration += strval + " " + u.Singular + " "
// omit any value with 0s or 0.
case d.duration.String() == "0" || d.duration.String() == "0s":
pattern := fmt.Sprintf("^-?0%s$", unitsShort[i])
isMatch, err := regexp.MatchString(pattern, d.input)
if err != nil {
return ""
}
if isMatch {
duration += strval + " " + u.Plural
}
// omit any value with 0.
case v == 0:
continue
}
}
// trim any remaining spaces.
duration = strings.TrimSpace(duration)
// if more than 2 spaces present return the first 2 strings
// if short version is requested
if d.limitN > 0 {
parts := strings.Split(duration, " ")
if len(parts) > d.limitN*2 {
duration = strings.Join(parts[:d.limitN*2], " ")
}
}
return duration
}
func (d *Durafmt) InternationalString() string {
var duration string
// Check for minus durations.
if string(d.input[0]) == "-" {
duration += "-"
d.duration = -d.duration
}
var microseconds int64
var milliseconds int64
var seconds int64
var minutes int64
var hours int64
var days int64
var weeks int64
var years int64
var shouldConvert = false
remainingSecondsToConvert := int64(d.duration / time.Microsecond)
// Convert duration.
if d.limitUnit == "" {
shouldConvert = true
}
if d.limitUnit == "years" || shouldConvert {
years = remainingSecondsToConvert / (365 * 24 * 3600 * 1000000)
remainingSecondsToConvert -= years * 365 * 24 * 3600 * 1000000
shouldConvert = true
}
if d.limitUnit == "weeks" || shouldConvert {
weeks = remainingSecondsToConvert / (7 * 24 * 3600 * 1000000)
remainingSecondsToConvert -= weeks * 7 * 24 * 3600 * 1000000
shouldConvert = true
}
if d.limitUnit == "days" || shouldConvert {
days = remainingSecondsToConvert / (24 * 3600 * 1000000)
remainingSecondsToConvert -= days * 24 * 3600 * 1000000
shouldConvert = true
}
if d.limitUnit == "hours" || shouldConvert {
hours = remainingSecondsToConvert / (3600 * 1000000)
remainingSecondsToConvert -= hours * 3600 * 1000000
shouldConvert = true
}
if d.limitUnit == "minutes" || shouldConvert {
minutes = remainingSecondsToConvert / (60 * 1000000)
remainingSecondsToConvert -= minutes * 60 * 1000000
shouldConvert = true
}
if d.limitUnit == "seconds" || shouldConvert {
seconds = remainingSecondsToConvert / 1000000
remainingSecondsToConvert -= seconds * 1000000
shouldConvert = true
}
if d.limitUnit == "milliseconds" || shouldConvert {
milliseconds = remainingSecondsToConvert / 1000
remainingSecondsToConvert -= milliseconds * 1000
}
microseconds = remainingSecondsToConvert
// Create a map of the converted duration time.
durationMap := map[string]int64{
"µs": microseconds,
"ms": milliseconds,
"s": seconds,
"m": minutes,
"h": hours,
"d": days,
"w": weeks,
"y": years,
}
// Construct duration string.
for i := range units.Units() {
u := unitsShort[i]
v := durationMap[u]
strval := strconv.FormatInt(v, 10)
switch {
// add to the duration string if v > 0.
case v > 0:
duration += strval + " " + u + " "
// omit any value with 0.
case d.duration.String() == "0":
pattern := fmt.Sprintf("^-?0%s$", unitsShort[i])
isMatch, err := regexp.MatchString(pattern, d.input)
if err != nil {
return ""
}
if isMatch {
duration += strval + " " + u
}
// omit any value with 0.
case v == 0:
continue
}
}
// trim any remaining spaces.
duration = strings.TrimSpace(duration)
// if more than 2 spaces present return the first 2 strings
// if short version is requested
if d.limitN > 0 {
parts := strings.Split(duration, " ")
if len(parts) > d.limitN*2 {
duration = strings.Join(parts[:d.limitN*2], " ")
}
}
return duration
}
func (d *Durafmt) TrimSpace() string {
return strings.Replace(d.String(), " ", "", -1)
}

107
pkg/duration_fmt/units.go Normal file
View File

@ -0,0 +1,107 @@
package duration_fmt
import (
"fmt"
"strings"
)
// DefaultUnitsCoder default units coder using `":"` as PluralSep and `","` as UnitsSep
var DefaultUnitsCoder = UnitsCoder{":", ","}
// Unit the pair of singular and plural units
type Unit struct {
Singular, Plural string
}
// Units duration units
type Units struct {
Year, Week, Day, Hour, Minute,
Second, Millisecond, Microsecond Unit
}
// Units return a slice of units
func (u Units) Units() []Unit {
return []Unit{u.Year, u.Week, u.Day, u.Hour, u.Minute,
u.Second, u.Millisecond, u.Microsecond}
}
// UnitsCoder the units encoder and decoder
type UnitsCoder struct {
// PluralSep char to sep singular and plural pair.
// Example with char `":"`: `"year:year"` (english) or `"mês:meses"` (portuguese)
PluralSep,
// UnitsSep char to sep units (singular and plural pairs).
// Example with char `","`: `"year:year,week:weeks"` (english) or `"mês:meses,semana:semanas"` (portuguese)
UnitsSep string
}
// Encode encodes input Units to string
// Examples with `UnitsCoder{PluralSep: ":", UnitsSep = ","}`
// - singular and plural pair units: `"year:wers,week:weeks,day:days,hour:hours,minute:minutes,second:seconds,millisecond:millliseconds,microsecond:microsseconds"`
func (coder UnitsCoder) Encode(units Units) string {
var pairs = make([]string, 8)
for i, u := range units.Units() {
pairs[i] = u.Singular + coder.PluralSep + u.Plural
}
return strings.Join(pairs, coder.UnitsSep)
}
// Decode decodes input string to Units.
// The input must follow the following formats:
// - Unit format (singular and plural pair)
// - must singular (the plural receives 's' character as suffix)
// - singular and plural: separated by `PluralSep` char
// Example with char `":"`: `"year:year"` (english) or `"mês:meses"` (portuguese)
// - Units format (pairs of Year, Week, Day, Hour, Minute,
// Second, Millisecond and Microsecond units) separated by `UnitsSep` char
// - Examples with `UnitsCoder{PluralSep: ":", UnitsSep = ","}`
// - must singular units: `"year,week,day,hour,minute,second,millisecond,microsecond"`
// - mixed units: `"year,week:weeks,day,hour,minute:minutes,second,millisecond,microsecond"`
// - singular and plural pair units: `"year:wers,week:weeks,day:days,hour:hours,minute:minutes,second:seconds,millisecond:millliseconds,microsecond:microsseconds"`
func (coder UnitsCoder) Decode(s string) (units Units, err error) {
parts := strings.Split(s, coder.UnitsSep)
if len(parts) != 8 {
err = fmt.Errorf("bad parts length")
return units, err
}
var parse = func(name, part string, u *Unit) bool {
ps := strings.Split(part, coder.PluralSep)
switch len(ps) {
case 1:
u.Singular, u.Plural = ps[0], ps[0]
case 2:
u.Singular, u.Plural = ps[0], ps[1]
default:
err = fmt.Errorf("bad unit %q pair length", name)
return false
}
return true
}
if !parse("Year", parts[0], &units.Year) {
return units, err
}
if !parse("Week", parts[1], &units.Week) {
return units, err
}
if !parse("Day", parts[2], &units.Day) {
return units, err
}
if !parse("Hour", parts[3], &units.Hour) {
return units, err
}
if !parse("Minute", parts[4], &units.Minute) {
return units, err
}
if !parse("Second", parts[5], &units.Second) {
return units, err
}
if !parse("Millisecond", parts[6], &units.Millisecond) {
return units, err
}
if !parse("Microsecond", parts[7], &units.Microsecond) {
return units, err
}
return units, err
}

80
pkg/env/env.go vendored Normal file
View File

@ -0,0 +1,80 @@
package env
import (
"fmt"
"strings"
)
var (
active Environment
dev Environment = &environment{value: "dev"} //开发环境
fat Environment = &environment{value: "fat"} //测试环境
uat Environment = &environment{value: "uat"} //预上线环境
pro Environment = &environment{value: "pro"} //正式环境
)
var _ Environment = (*environment)(nil)
// Environment 环境配置
type Environment interface {
Value() string
IsDev() bool
IsFat() bool
IsUat() bool
IsPro() bool
t()
}
type environment struct {
value string
}
func (e *environment) Value() string {
return e.value
}
func (e *environment) IsDev() bool {
return e.value == "dev"
}
func (e *environment) IsFat() bool {
return e.value == "fat"
}
func (e *environment) IsUat() bool {
return e.value == "uat"
}
func (e *environment) IsPro() bool {
return e.value == "pro"
}
func (e *environment) t() {}
func Set(env string) error {
var err error
switch strings.ToLower(strings.TrimSpace(env)) {
case "dev":
active = dev
case "fat":
active = fat
case "uat":
active = uat
case "pro":
active = pro
default:
err = fmt.Errorf("'%s' cannot be found, or it is illegal. enum:dev(开发环境),fat(测试环境),uat(预上线环境),pro(正式环境)", env)
}
return err
}
// Active 当前配置的env
func Active() Environment {
if active == nil {
fmt.Println("Warning: environment not set. The default 'dev' will be used.")
active = dev
}
return active
}

74
pkg/errno/errno.go Normal file
View File

@ -0,0 +1,74 @@
package errno
import (
"encoding/json"
)
var _ Error = (*err)(nil)
type Error interface {
// WithErr 设置错误信息
WithErr(err error) Error
// GetBusinessCode 获取 Business Code
GetBusinessCode() int
// GetHttpCode 获取 HTTP Code
GetHttpCode() int
// GetMsg 获取 Msg
GetMsg() string
// GetErr 获取错误信息
GetErr() error
// ToString 返回 JSON 格式的错误详情
ToString() string
}
type err struct {
HttpCode int // HTTP Code
BusinessCode int // Business Code
Message string // 描述信息
Err error // 错误信息
}
func NewError(httpCode, businessCode int, msg string) Error {
return &err{
HttpCode: httpCode,
BusinessCode: businessCode,
Message: msg,
}
}
func (e *err) WithErr(err error) Error {
e.Err = err
return e
}
func (e *err) GetHttpCode() int {
return e.HttpCode
}
func (e *err) GetBusinessCode() int {
return e.BusinessCode
}
func (e *err) GetMsg() string {
return e.Message
}
func (e *err) GetErr() error {
return e.Err
}
// ToString 返回 JSON 格式的错误详情
func (e *err) ToString() string {
err := &struct {
HttpCode int `json:"http_code"`
BusinessCode int `json:"business_code"`
Message string `json:"message"`
}{
HttpCode: e.HttpCode,
BusinessCode: e.BusinessCode,
Message: e.Message,
}
raw, _ := json.Marshal(err)
return string(raw)
}

111
pkg/excel/export.go Normal file
View File

@ -0,0 +1,111 @@
package excel
import (
"github.com/xuri/excelize/v2"
"strconv"
)
const maxCharCount = 26
var DefaultColumnWidth float64 = 20
type ColumnOption struct {
Field string
Comment string
Width float64
}
func Export(sheetName, filepath string, columns []ColumnOption, rows []map[string]any) error {
f := excelize.NewFile()
sheetIndex, err := f.NewSheet(sheetName)
if err != nil {
return err
}
_ = f.DeleteSheet("Sheet1")
_ = f.SetColWidth(sheetName, "A", string(byte('A'+len(columns)-1)), DefaultColumnWidth)
contentStyle, _ := f.NewStyle(&excelize.Style{
Alignment: &excelize.Alignment{
Horizontal: "center",
Vertical: "center",
WrapText: true,
},
})
titleStyle, _ := f.NewStyle(&excelize.Style{
Alignment: &excelize.Alignment{
Horizontal: "center",
Vertical: "center",
WrapText: true,
},
Font: &excelize.Font{
Bold: true,
Size: 14,
},
})
maxColumnRowNameLen := 1 + len(strconv.Itoa(len(rows)))
columnCount := len(columns)
if columnCount > maxCharCount {
maxColumnRowNameLen++
} else if columnCount > maxCharCount*maxCharCount {
maxColumnRowNameLen += 2
}
//标题
type columnItem struct {
RowName []byte
FieldName string
}
columnNames := make([]columnItem, 0)
for index, column := range columns {
columnName := getColumnName(index, maxColumnRowNameLen)
if column.Width > 0 {
_ = f.SetColWidth(sheetName, string(columnName), string(columnName), column.Width)
}
columnNames = append(columnNames, columnItem{FieldName: column.Field, RowName: columnName})
rowName := getColumnRowName(columnName, 1)
err := f.SetCellValue(sheetName, rowName, column.Comment)
if err != nil {
return err
}
_ = f.SetCellStyle(sheetName, rowName, rowName, titleStyle)
}
//正文
for rowIndex, row := range rows {
for _, item := range columnNames {
rowName := getColumnRowName(item.RowName, rowIndex+2)
err := f.SetCellValue(sheetName, rowName, row[item.FieldName])
if err != nil {
return err
}
_ = f.SetCellStyle(sheetName, rowName, rowName, contentStyle)
}
}
f.SetActiveSheet(sheetIndex)
err = f.SaveAs(filepath)
if err != nil {
return err
}
return nil
}
func getColumnName(column, maxColumnRowNameLen int) []byte {
const A = 'A'
if column < maxCharCount {
slice := make([]byte, 0, maxColumnRowNameLen)
return append(slice, byte(A+column))
} else {
return append(getColumnName(column/maxCharCount-1, maxColumnRowNameLen), byte(A+column%maxCharCount))
}
}
func getColumnRowName(columnName []byte, rowIndex int) (columnRowName string) {
l := len(columnName)
columnName = strconv.AppendInt(columnName, int64(rowIndex), 10)
columnRowName = string(columnName)
columnName = columnName[:l]
return
}

189
pkg/file/file.go Normal file
View File

@ -0,0 +1,189 @@
package file
import (
"bytes"
"fmt"
"io"
"os"
)
var (
buffSize = 1 << 20
)
// ReadLineFromEnd --
type ReadLineFromEnd struct {
f *os.File
fileSize int
bwr *bytes.Buffer
lineBuff []byte
swapBuff []byte
isFirst bool
}
// Exists
func IsExists(path string) (os.FileInfo, bool) {
f, err := os.Stat(path)
return f, err == nil || os.IsExist(err)
}
// NewReadLineFromEnd
func NewReadLineFromEnd(filename string) (rd *ReadLineFromEnd, err error) {
f, err := os.Open(filename)
if err != nil {
return nil, err
}
info, err := f.Stat()
if err != nil {
return nil, err
}
if info.IsDir() {
return nil, fmt.Errorf("not file")
}
fileSize := int(info.Size())
rd = &ReadLineFromEnd{
f: f,
fileSize: fileSize,
bwr: bytes.NewBuffer([]byte{}),
lineBuff: make([]byte, 0),
swapBuff: make([]byte, buffSize),
isFirst: true,
}
return rd, nil
}
// ReadLine 结尾包含'\n'
func (c *ReadLineFromEnd) ReadLine() (line []byte, err error) {
var ok bool
for {
ok, err = c.buff()
if err != nil {
return nil, err
}
if ok {
break
}
}
line, err = c.bwr.ReadBytes('\n')
if err == io.EOF && c.fileSize > 0 {
err = nil
}
return line, err
}
// Close --
func (c *ReadLineFromEnd) Close() (err error) {
return c.f.Close()
}
func (c *ReadLineFromEnd) buff() (ok bool, err error) {
if c.fileSize == 0 {
return true, nil
}
if c.bwr.Len() >= buffSize {
return true, nil
}
offset := 0
if c.fileSize > buffSize {
offset = c.fileSize - buffSize
}
_, err = c.f.Seek(int64(offset), 0)
if err != nil {
return false, err
}
n, err := c.f.Read(c.swapBuff)
if err != nil && err != io.EOF {
return false, err
}
if c.fileSize < n {
n = c.fileSize
}
if n == 0 {
return true, nil
}
for {
m := bytes.LastIndex(c.swapBuff[:n], []byte{'\n'})
if m == -1 {
break
}
if m < n-1 {
err = c.writeLine(c.swapBuff[m+1 : n])
if err != nil {
return false, err
}
ok = true
} else if m == n-1 && !c.isFirst {
err = c.writeLine(nil)
if err != nil {
return false, err
}
ok = true
}
n = m
if n == 0 {
break
}
}
if n > 0 {
reverseBytes(c.swapBuff[:n])
c.lineBuff = append(c.lineBuff, c.swapBuff[:n]...)
}
if offset == 0 {
err = c.writeLine(nil)
if err != nil {
return false, err
}
ok = true
}
c.fileSize = offset
if c.isFirst {
c.isFirst = false
}
return ok, nil
}
func (c *ReadLineFromEnd) writeLine(b []byte) (err error) {
if len(b) > 0 {
_, err = c.bwr.Write(b)
if err != nil {
return err
}
}
if len(c.lineBuff) > 0 {
reverseBytes(c.lineBuff)
_, err = c.bwr.Write(c.lineBuff)
if err != nil {
return err
}
c.lineBuff = c.lineBuff[:0]
}
_, err = c.bwr.Write([]byte{'\n'})
if err != nil {
return err
}
return nil
}
func reverseBytes(b []byte) {
n := len(b)
if n <= 1 {
return
}
for i := 0; i < n; i++ {
k := n - 1
if k != i {
b[i], b[k] = b[k], b[i]
}
n--
}
}

98
pkg/grpclient/client.go Normal file
View File

@ -0,0 +1,98 @@
package grpclient
import (
"context"
"errors"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
)
var (
defaultDialTimeout = time.Second * 2
)
type Option func(*option)
type option struct {
credential credentials.TransportCredentials
keepalive *keepalive.ClientParameters
dialTimeout time.Duration
unaryInterceptor grpc.UnaryClientInterceptor
}
// WithCredential setup credential for tls
func WithCredential(credential credentials.TransportCredentials) Option {
return func(opt *option) {
opt.credential = credential
}
}
// WithKeepAlive setup keepalive parameters
func WithKeepAlive(keepalive *keepalive.ClientParameters) Option {
return func(opt *option) {
opt.keepalive = keepalive
}
}
// WithDialTimeout set up the dial timeout
func WithDialTimeout(timeout time.Duration) Option {
return func(opt *option) {
opt.dialTimeout = timeout
}
}
func WithUnaryInterceptor(unaryInterceptor grpc.UnaryClientInterceptor) Option {
return func(opt *option) {
opt.unaryInterceptor = unaryInterceptor
}
}
func New(target string, options ...Option) (*grpc.ClientConn, error) {
if target == "" {
return nil, errors.New("target required")
}
opt := new(option)
for _, f := range options {
f(opt)
}
kacp := defaultKeepAlive
if opt.keepalive != nil {
kacp = opt.keepalive
}
dialTimeout := defaultDialTimeout
if opt.dialTimeout > 0 {
dialTimeout = opt.dialTimeout
}
dialOptions := []grpc.DialOption{
grpc.WithBlock(),
grpc.WithKeepaliveParams(*kacp),
}
if opt.unaryInterceptor != nil {
dialOptions = append(dialOptions, grpc.WithUnaryInterceptor(opt.unaryInterceptor))
}
if opt.credential == nil {
dialOptions = append(dialOptions, grpc.WithTransportCredentials(insecure.NewCredentials()))
} else {
dialOptions = append(dialOptions, grpc.WithTransportCredentials(opt.credential))
}
ctx, cancel := context.WithTimeout(context.Background(), dialTimeout)
defer cancel()
conn, err := grpc.DialContext(ctx, target, dialOptions...)
if err != nil {
return nil, err
}
return conn, nil
}

View File

@ -0,0 +1,15 @@
package grpclient
import (
"time"
"google.golang.org/grpc/keepalive"
)
var (
defaultKeepAlive = &keepalive.ClientParameters{
Time: 10 * time.Second,
Timeout: time.Second,
PermitWithoutStream: true,
}
)

25
pkg/hash/hash.go Normal file
View File

@ -0,0 +1,25 @@
package hash
var _ Hash = (*hash)(nil)
type Hash interface {
i()
// hashids
HashidsEncode(params []int) (string, error)
HashidsDecode(hash string) ([]int, error)
}
type hash struct {
secret string
length int
}
func New(secret string, length int) Hash {
return &hash{
secret: secret,
length: length,
}
}
func (h *hash) i() {}

41
pkg/hash/hash_hashids.go Normal file
View File

@ -0,0 +1,41 @@
package hash
import (
"github.com/speps/go-hashids"
)
func (h *hash) HashidsEncode(params []int) (string, error) {
hd := hashids.NewData()
hd.Salt = h.secret
hd.MinLength = h.length
hashID, err := hashids.NewWithData(hd)
if err != nil {
return "", err
}
hashStr, err := hashID.Encode(params)
if err != nil {
return "", err
}
return hashStr, nil
}
func (h *hash) HashidsDecode(hash string) ([]int, error) {
hd := hashids.NewData()
hd.Salt = h.secret
hd.MinLength = h.length
hashID, err := hashids.NewWithData(hd)
if err != nil {
return nil, err
}
ids, err := hashID.DecodeWithError(hash)
if err != nil {
return nil, err
}
return ids, nil
}

55
pkg/hmac/hmac.go Normal file
View File

@ -0,0 +1,55 @@
package hmac
import (
baseHmac "crypto/hmac"
"crypto/sha1"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
)
var _ HMAC = (*hmac)(nil)
type HMAC interface {
i()
Sha1ToString(data string) string
Sha1ToBase64String(data string) string
Sha256ToString(data string) string
Sha256ToBase64String(data string) string
}
type hmac struct {
secret string
}
func New(secret string) HMAC {
return &hmac{secret: secret}
}
func (m *hmac) i() {}
func (m *hmac) Sha256ToString(data string) string {
h := baseHmac.New(sha256.New, []byte(m.secret))
h.Write([]byte(data))
return hex.EncodeToString(h.Sum(nil))
}
func (m *hmac) Sha256ToBase64String(data string) string {
h := baseHmac.New(sha256.New, []byte(m.secret))
h.Write([]byte(data))
return base64.StdEncoding.EncodeToString(h.Sum(nil))
}
func (m *hmac) Sha1ToString(data string) string {
h := baseHmac.New(sha1.New, []byte(m.secret))
h.Write([]byte(data))
return hex.EncodeToString(h.Sum(nil))
}
func (m *hmac) Sha1ToBase64String(data string) string {
h := baseHmac.New(sha1.New, []byte(m.secret))
h.Write([]byte(data))
return base64.StdEncoding.EncodeToString(h.Sum(nil))
}

29
pkg/httpclient/alarm.go Normal file
View File

@ -0,0 +1,29 @@
package httpclient
import (
"bufio"
"bytes"
"go.uber.org/zap"
)
// Verify parse the body and verify that it is correct
type AlarmVerify func(body []byte) (shouldAlarm bool)
type AlarmObject interface {
Send(subject, body string) error
}
func onFailedAlarm(title string, raw []byte, logger *zap.Logger, alarmObject AlarmObject) {
buf := bytes.NewBuffer(nil)
scanner := bufio.NewScanner(bytes.NewReader(raw))
for scanner.Scan() {
buf.WriteString(scanner.Text())
buf.WriteString(" <br/>")
}
if err := alarmObject.Send(title, buf.String()); err != nil && logger != nil {
logger.Error("calls failed alarm err", zap.Error(err))
}
}

395
pkg/httpclient/client.go Normal file
View File

@ -0,0 +1,395 @@
package httpclient
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/spf13/cast"
"net/http"
httpURL "net/url"
"time"
"git.bvbej.com/bvbej/base-golang/pkg/trace"
)
const (
// DefaultTTL 一次http请求最长执行1分钟
DefaultTTL = time.Minute
)
// Get 请求
func Get(url string, form httpURL.Values, options ...Option) (body []byte, err error) {
return withoutBody(http.MethodGet, url, form, options...)
}
// Delete delete 请求
func Delete(url string, form httpURL.Values, options ...Option) (body []byte, err error) {
return withoutBody(http.MethodDelete, url, form, options...)
}
func withoutBody(method, url string, form httpURL.Values, options ...Option) (body []byte, err error) {
if url == "" {
return nil, errors.New("url required")
}
if len(form) > 0 {
if url, err = addFormValuesIntoURL(url, form); err != nil {
return
}
}
ts := time.Now()
opt := getOption()
defer func() {
if opt.trace != nil {
opt.dialog.Success = err == nil
opt.dialog.CostSeconds = time.Since(ts).Seconds()
opt.trace.AppendDialog(opt.dialog)
}
releaseOption(opt)
}()
for _, f := range options {
f(opt)
}
opt.header["Content-Type"] = []string{"application/x-www-form-urlencoded; charset=utf-8"}
if opt.trace != nil {
opt.header[trace.Header] = []string{opt.trace.ID()}
}
ttl := opt.ttl
if ttl <= 0 {
ttl = DefaultTTL
}
ctx, cancel := context.WithTimeout(context.Background(), ttl)
defer cancel()
if opt.dialog != nil {
decodedURL, _ := httpURL.QueryUnescape(url)
opt.dialog.Request = &trace.Request{
TTL: ttl.String(),
Method: method,
DecodedURL: decodedURL,
Header: opt.header,
}
}
retryTimes := opt.retryTimes
if retryTimes <= 0 {
retryTimes = DefaultRetryTimes
}
retryDelay := opt.retryDelay
if retryDelay <= 0 {
retryDelay = DefaultRetryDelay
}
var httpCode int
defer func() {
if opt.alarmObject == nil {
return
}
if opt.alarmVerify != nil && !opt.alarmVerify(body) && err == nil {
return
}
info := &struct {
TraceID string `json:"trace_id"`
Request struct {
Method string `json:"method"`
URL string `json:"url"`
} `json:"request"`
Response struct {
HTTPCode int `json:"http_code"`
Body string `json:"body"`
} `json:"response"`
Error string `json:"error"`
}{}
if opt.trace != nil {
info.TraceID = opt.trace.ID()
}
info.Request.Method = method
info.Request.URL = url
info.Response.HTTPCode = httpCode
info.Response.Body = string(body)
info.Error = ""
if err != nil {
info.Error = fmt.Sprintf("%+v", err)
}
raw, _ := json.MarshalIndent(info, "", " ")
onFailedAlarm(opt.alarmTitle, raw, opt.logger, opt.alarmObject)
}()
for k := 0; k < retryTimes; k++ {
body, httpCode, err = doHTTP(ctx, method, url, nil, opt)
if shouldRetry(ctx, httpCode) || (opt.retryVerify != nil && opt.retryVerify(body)) {
time.Sleep(retryDelay)
continue
}
return
}
return
}
// PostForm post form 请求
func PostForm(url string, form httpURL.Values, options ...Option) (body []byte, err error) {
return withFormBody(http.MethodPost, url, form, options...)
}
// PostJSON post json 请求
func PostJSON(url string, raw json.RawMessage, options ...Option) (body []byte, err error) {
return withJSONBody(http.MethodPost, url, raw, options...)
}
// PutForm put form 请求
func PutForm(url string, form httpURL.Values, options ...Option) (body []byte, err error) {
return withFormBody(http.MethodPut, url, form, options...)
}
// PutJSON put json 请求
func PutJSON(url string, raw json.RawMessage, options ...Option) (body []byte, err error) {
return withJSONBody(http.MethodPut, url, raw, options...)
}
// PatchFrom patch form 请求
func PatchFrom(url string, form httpURL.Values, options ...Option) (body []byte, err error) {
return withFormBody(http.MethodPatch, url, form, options...)
}
// PatchJSON patch json 请求
func PatchJSON(url string, raw json.RawMessage, options ...Option) (body []byte, err error) {
return withJSONBody(http.MethodPatch, url, raw, options...)
}
func withFormBody(method, url string, form httpURL.Values, options ...Option) (body []byte, err error) {
if url == "" {
return nil, errors.New("url required")
}
if len(form) == 0 {
return nil, errors.New("form required")
}
ts := time.Now()
opt := getOption()
defer func() {
if opt.trace != nil {
opt.dialog.Success = err == nil
opt.dialog.CostSeconds = time.Since(ts).Seconds()
opt.trace.AppendDialog(opt.dialog)
}
releaseOption(opt)
}()
for _, f := range options {
f(opt)
}
opt.header["Content-Type"] = []string{"application/x-www-form-urlencoded; charset=utf-8"}
if opt.trace != nil {
opt.header[trace.Header] = []string{opt.trace.ID()}
}
ttl := opt.ttl
if ttl <= 0 {
ttl = DefaultTTL
}
ctx, cancel := context.WithTimeout(context.Background(), ttl)
defer cancel()
formValue := form.Encode()
if opt.dialog != nil {
decodedURL, _ := httpURL.QueryUnescape(url)
opt.dialog.Request = &trace.Request{
TTL: ttl.String(),
Method: method,
DecodedURL: decodedURL,
Header: opt.header,
Body: formValue,
}
}
retryTimes := opt.retryTimes
if retryTimes <= 0 {
retryTimes = DefaultRetryTimes
}
retryDelay := opt.retryDelay
if retryDelay <= 0 {
retryDelay = DefaultRetryDelay
}
var httpCode int
defer func() {
if opt.alarmObject == nil {
return
}
if opt.alarmVerify != nil && !opt.alarmVerify(body) && err == nil {
return
}
info := &struct {
TraceID string `json:"trace_id"`
Request struct {
Method string `json:"method"`
URL string `json:"url"`
} `json:"request"`
Response struct {
HTTPCode int `json:"http_code"`
Body string `json:"body"`
} `json:"response"`
Error string `json:"error"`
}{}
if opt.trace != nil {
info.TraceID = opt.trace.ID()
}
info.Request.Method = method
info.Request.URL = url
info.Response.HTTPCode = httpCode
info.Response.Body = string(body)
info.Error = ""
if err != nil {
info.Error = fmt.Sprintf("%+v", err)
}
raw, _ := json.MarshalIndent(info, "", " ")
onFailedAlarm(opt.alarmTitle, raw, opt.logger, opt.alarmObject)
}()
for k := 0; k < retryTimes; k++ {
body, httpCode, err = doHTTP(ctx, method, url, []byte(formValue), opt)
if shouldRetry(ctx, httpCode) || (opt.retryVerify != nil && opt.retryVerify(body)) {
time.Sleep(retryDelay)
continue
}
return
}
return
}
func withJSONBody(method, url string, raw json.RawMessage, options ...Option) (body []byte, err error) {
if url == "" {
return nil, errors.New("url required")
}
if len(raw) == 0 {
return nil, errors.New("raw required")
}
ts := time.Now()
opt := getOption()
defer func() {
if opt.trace != nil {
opt.dialog.Success = err == nil
opt.dialog.CostSeconds = time.Since(ts).Seconds()
opt.trace.AppendDialog(opt.dialog)
}
releaseOption(opt)
}()
for _, f := range options {
f(opt)
}
opt.header["Content-Type"] = []string{"application/json; charset=utf-8"}
if opt.trace != nil {
opt.header[trace.Header] = []string{opt.trace.ID()}
}
ttl := opt.ttl
if ttl <= 0 {
ttl = DefaultTTL
}
ctx, cancel := context.WithTimeout(context.Background(), ttl)
defer cancel()
if opt.dialog != nil {
decodedURL, _ := httpURL.QueryUnescape(url)
opt.dialog.Request = &trace.Request{
TTL: ttl.String(),
Method: method,
DecodedURL: decodedURL,
Header: opt.header,
Body: cast.ToString(raw),
}
}
retryTimes := opt.retryTimes
if retryTimes <= 0 {
retryTimes = DefaultRetryTimes
}
retryDelay := opt.retryDelay
if retryDelay <= 0 {
retryDelay = DefaultRetryDelay
}
var httpCode int
defer func() {
if opt.alarmObject == nil {
return
}
if opt.alarmVerify != nil && !opt.alarmVerify(body) && err == nil {
return
}
info := &struct {
TraceID string `json:"trace_id"`
Request struct {
Method string `json:"method"`
URL string `json:"url"`
} `json:"request"`
Response struct {
HTTPCode int `json:"http_code"`
Body string `json:"body"`
} `json:"response"`
Error string `json:"error"`
}{}
if opt.trace != nil {
info.TraceID = opt.trace.ID()
}
info.Request.Method = method
info.Request.URL = url
info.Response.HTTPCode = httpCode
info.Response.Body = string(body)
info.Error = ""
if err != nil {
info.Error = fmt.Sprintf("%+v", err)
}
raw, _ := json.MarshalIndent(info, "", " ")
onFailedAlarm(opt.alarmTitle, raw, opt.logger, opt.alarmObject)
}()
for k := 0; k < retryTimes; k++ {
body, httpCode, err = doHTTP(ctx, method, url, raw, opt)
if shouldRetry(ctx, httpCode) || (opt.retryVerify != nil && opt.retryVerify(body)) {
time.Sleep(retryDelay)
continue
}
return
}
return
}

46
pkg/httpclient/error.go Normal file
View File

@ -0,0 +1,46 @@
package httpclient
var _ ReplyErr = (*replyErr)(nil)
// ReplyErr 错误响应,当 resp.StatusCode != http.StatusOK 时用来包装返回的 httpcode 和 body 。
type ReplyErr interface {
error
StatusCode() int
Body() []byte
}
type replyErr struct {
err error
statusCode int
body []byte
}
func (r *replyErr) Error() string {
return r.err.Error()
}
func (r *replyErr) StatusCode() int {
return r.statusCode
}
func (r *replyErr) Body() []byte {
return r.body
}
func newReplyErr(statusCode int, body []byte, err error) ReplyErr {
return &replyErr{
statusCode: statusCode,
body: body,
err: err,
}
}
// ToReplyErr 尝试将 err 转换为 ReplyErr
func ToReplyErr(err error) (ReplyErr, bool) {
if err == nil {
return nil, false
}
e, ok := err.(ReplyErr)
return e, ok
}

138
pkg/httpclient/option.go Normal file
View File

@ -0,0 +1,138 @@
package httpclient
import (
"sync"
"time"
"git.bvbej.com/bvbej/base-golang/pkg/trace"
"go.uber.org/zap"
)
var (
cache = &sync.Pool{
New: func() any {
return &option{
header: make(map[string][]string),
}
},
}
)
// Mock 定义接口Mock数据
type Mock func() (body []byte)
// Option 自定义设置http请求
type Option func(*option)
type basicAuth struct {
username string
password string
}
type option struct {
ttl time.Duration
basicAuth *basicAuth
header map[string][]string
trace *trace.Trace
dialog *trace.Dialog
logger *zap.Logger
retryTimes int
retryDelay time.Duration
retryVerify RetryVerify
alarmTitle string
alarmObject AlarmObject
alarmVerify AlarmVerify
mock Mock
}
func (o *option) reset() {
o.ttl = 0
o.basicAuth = nil
o.header = make(map[string][]string)
o.trace = nil
o.dialog = nil
o.logger = nil
o.retryTimes = 0
o.retryDelay = 0
o.retryVerify = nil
o.alarmTitle = ""
o.alarmObject = nil
o.alarmVerify = nil
o.mock = nil
}
func getOption() *option {
return cache.Get().(*option)
}
func releaseOption(opt *option) {
opt.reset()
cache.Put(opt)
}
// WithTTL 本次http请求最长执行时间
func WithTTL(ttl time.Duration) Option {
return func(opt *option) {
opt.ttl = ttl
}
}
// WithHeader 设置http header可以调用多次设置多对key-value
func WithHeader(key, value string) Option {
return func(opt *option) {
opt.header[key] = []string{value}
}
}
// WithBasicAuth 设置基础认证权限
func WithBasicAuth(username, password string) Option {
return func(opt *option) {
opt.basicAuth = &basicAuth{
username: username,
password: password,
}
}
}
// WithTrace 设置trace信息
func WithTrace(t trace.T) Option {
return func(opt *option) {
if t != nil {
opt.trace = t.(*trace.Trace)
opt.dialog = new(trace.Dialog)
}
}
}
// WithLogger 设置logger以便打印关键日志
func WithLogger(logger *zap.Logger) Option {
return func(opt *option) {
opt.logger = logger
}
}
// WithMock 设置 mock 数据
func WithMock(m Mock) Option {
return func(opt *option) {
opt.mock = m
}
}
// WithOnFailedAlarm 设置告警通知
func WithOnFailedAlarm(alarmTitle string, alarmObject AlarmObject, alarmVerify AlarmVerify) Option {
return func(opt *option) {
opt.alarmTitle = alarmTitle
opt.alarmObject = alarmObject
opt.alarmVerify = alarmVerify
}
}
// WithOnFailedRetry 设置失败重试
func WithOnFailedRetry(retryTimes int, retryDelay time.Duration, retryVerify RetryVerify) Option {
return func(opt *option) {
opt.retryTimes = retryTimes
opt.retryDelay = retryDelay
opt.retryVerify = retryVerify
}
}

44
pkg/httpclient/retry.go Normal file
View File

@ -0,0 +1,44 @@
package httpclient
import (
"context"
"net/http"
"time"
)
const (
// DefaultRetryTimes 如果请求失败最多重试3次
DefaultRetryTimes = 3
// DefaultRetryDelay 在重试前延迟等待100毫秒
DefaultRetryDelay = time.Millisecond * 100
)
// Verify parse the body and verify that it is correct
type RetryVerify func(body []byte) (shouldRetry bool)
func shouldRetry(ctx context.Context, httpCode int) bool {
select {
case <-ctx.Done():
return false
default:
}
switch httpCode {
case
_StatusReadRespErr,
_StatusDoReqErr,
http.StatusRequestTimeout,
http.StatusLocked,
http.StatusTooEarly,
http.StatusTooManyRequests,
http.StatusServiceUnavailable,
http.StatusGatewayTimeout:
return true
default:
return false
}
}

147
pkg/httpclient/util.go Normal file
View File

@ -0,0 +1,147 @@
package httpclient
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"time"
"git.bvbej.com/bvbej/base-golang/pkg/trace"
"go.uber.org/zap"
)
const (
// _StatusReadRespErr read resp body err, should re-call doHTTP again.
_StatusReadRespErr = -204
// _StatusDoReqErr do req err, should re-call doHTTP again.
_StatusDoReqErr = -500
)
var defaultClient = &http.Client{
Transport: &http.Transport{
DisableKeepAlives: true,
DisableCompression: true,
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
MaxIdleConns: 100,
MaxConnsPerHost: 100,
MaxIdleConnsPerHost: 100,
},
}
func doHTTP(ctx context.Context, method, url string, payload []byte, opt *option) ([]byte, int, error) {
ts := time.Now()
if mock := opt.mock; mock != nil {
if opt.dialog != nil {
opt.dialog.AppendResponse(&trace.Response{
HttpCode: http.StatusOK,
HttpCodeMsg: http.StatusText(http.StatusOK),
Body: string(mock()),
CostSeconds: time.Since(ts).Seconds(),
})
}
return mock(), http.StatusOK, nil
}
req, err := http.NewRequestWithContext(ctx, method, url, bytes.NewReader(payload))
if err != nil {
return nil, -1, errors.Join(err, fmt.Errorf("new request [%s %s] err", method, url))
}
for key, value := range opt.header {
req.Header.Set(key, value[0])
}
if opt.basicAuth != nil {
req.SetBasicAuth(opt.basicAuth.username, opt.basicAuth.password)
}
resp, err := defaultClient.Do(req)
if err != nil {
err = errors.Join(err, fmt.Errorf("do request [%s %s] err", method, url))
if opt.dialog != nil {
opt.dialog.AppendResponse(&trace.Response{
Body: err.Error(),
CostSeconds: time.Since(ts).Seconds(),
})
}
if opt.logger != nil {
opt.logger.Warn("doHTTP got err", zap.Error(err))
}
return nil, _StatusDoReqErr, err
}
defer func() {
_ = resp.Body.Close()
}()
body, err := io.ReadAll(resp.Body)
if err != nil {
err = errors.Join(err, fmt.Errorf("read resp body from [%s %s] err", method, url))
if opt.dialog != nil {
opt.dialog.AppendResponse(&trace.Response{
Body: err.Error(),
CostSeconds: time.Since(ts).Seconds(),
})
}
if opt.logger != nil {
opt.logger.Warn("doHTTP got err", zap.Error(err))
}
return nil, _StatusReadRespErr, err
}
defer func() {
if opt.dialog != nil {
opt.dialog.AppendResponse(&trace.Response{
Header: resp.Header,
HttpCode: resp.StatusCode,
HttpCodeMsg: resp.Status,
Body: string(body), // unsafe
CostSeconds: time.Since(ts).Seconds(),
})
}
}()
if resp.StatusCode != http.StatusOK {
return nil, resp.StatusCode, newReplyErr(
resp.StatusCode,
body,
fmt.Errorf("do [%s %s] return code: %d message: %s", method, url, resp.StatusCode, string(body)),
)
}
return body, http.StatusOK, nil
}
// addFormValuesIntoURL append url.Values into url string
func addFormValuesIntoURL(rawURL string, form url.Values) (string, error) {
if rawURL == "" {
return "", errors.New("rawURL required")
}
if len(form) == 0 {
return "", errors.New("form required")
}
target, err := url.Parse(rawURL)
if err != nil {
return "", errors.Join(err, fmt.Errorf("parse rawURL `%s` err", rawURL))
}
urlValues := target.Query()
for key, values := range form {
for _, value := range values {
urlValues.Add(key, value)
}
}
target.RawQuery = urlValues.Encode()
return target.String(), nil
}

71
pkg/limiter/rate.go Normal file
View File

@ -0,0 +1,71 @@
package limiter
import (
"git.bvbej.com/bvbej/base-golang/pkg/ticker"
"golang.org/x/time/rate"
"sync"
"time"
)
var _ RateLimiter = (*rateLimiter)(nil)
type item struct {
lastTime time.Time
limiter *rate.Limiter
}
type RateLimiter interface {
set(key string) *rate.Limiter
get(key string) *rate.Limiter
Allow(key string) bool
}
type rateLimiter struct {
limit rate.Limit
burst int
list *sync.Map
recycle ticker.Ticker
}
func NewRateLimiter(limit rate.Limit, burst int) RateLimiter {
list := new(sync.Map)
t := ticker.New(time.Minute)
t.Process(func() {
list.Range(func(key, value any) bool {
if value.(*item).lastTime.Before(time.Now().Add(-time.Hour)) {
list.Delete(key)
}
return true
})
})
return &rateLimiter{
list: list,
limit: limit,
recycle: t,
burst: burst,
}
}
func (i *rateLimiter) set(key string) *rate.Limiter {
store := &item{
lastTime: time.Now(),
limiter: rate.NewLimiter(i.limit, i.burst),
}
i.list.Store(key, store)
return store.limiter
}
func (i *rateLimiter) get(key string) *rate.Limiter {
value, ok := i.list.Load(key)
if !ok {
return i.set(key)
}
value.(*item).lastTime = time.Now()
i.list.Store(key, value)
return value.(*item).limiter
}
func (i *rateLimiter) Allow(key string) bool {
limiter := i.get(key)
return limiter.Allow()
}

47
pkg/lock/lock.go Normal file
View File

@ -0,0 +1,47 @@
package lock
import (
"sync"
"sync/atomic"
)
var _ Locker = (*locker)(nil)
type Locker interface {
condition() bool
Lock()
Unlock()
}
type locker struct {
lock *sync.Mutex
cond *sync.Cond
v *atomic.Bool
}
func NewLocker() Locker {
lock := new(sync.Mutex)
return &locker{
lock: lock,
cond: sync.NewCond(lock),
v: new(atomic.Bool),
}
}
func (l *locker) condition() bool {
return l.v.Load()
}
func (l *locker) Lock() {
l.cond.L.Lock()
for l.condition() {
l.cond.Wait()
}
l.v.Store(true)
}
func (l *locker) Unlock() {
l.v.Store(false)
l.cond.L.Unlock()
l.cond.Signal()
}

275
pkg/logger/logger.go Normal file
View File

@ -0,0 +1,275 @@
package logger
import (
"io"
"os"
"path/filepath"
"time"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"gopkg.in/natefinch/lumberjack.v2"
)
const (
// DefaultLevel the default log level
DefaultLevel = zapcore.InfoLevel
// DefaultTimeLayout the default time layout;
DefaultTimeLayout = time.DateTime
)
// Option custom setup config
type Option func(*option)
type option struct {
level zapcore.Level
fields map[string]string
file io.Writer
timeLayout string
disableConsole bool
disableCaller bool
}
// WithDebugLevel only greater than 'level' will output
func WithDebugLevel() Option {
return func(opt *option) {
opt.level = zapcore.DebugLevel
}
}
// WithInfoLevel only greater than 'level' will output
func WithInfoLevel() Option {
return func(opt *option) {
opt.level = zapcore.InfoLevel
}
}
// WithWarnLevel only greater than 'level' will output
func WithWarnLevel() Option {
return func(opt *option) {
opt.level = zapcore.WarnLevel
}
}
// WithErrorLevel only greater than 'level' will output
func WithErrorLevel() Option {
return func(opt *option) {
opt.level = zapcore.ErrorLevel
}
}
// WithField add some field(s) to log
func WithField(key, value string) Option {
return func(opt *option) {
opt.fields[key] = value
}
}
// WithFileP write log to some file
func WithFileP(file string) Option {
dir := filepath.Dir(file)
if err := os.MkdirAll(dir, 0766); err != nil {
panic(err)
}
f, err := os.OpenFile(file, os.O_CREATE|os.O_APPEND|os.O_RDWR, 0766)
if err != nil {
panic(err)
}
return func(opt *option) {
opt.file = zapcore.Lock(f)
}
}
// WithFileRotationP write log to some file with rotation
func WithFileRotationP(file string) Option {
dir := filepath.Dir(file)
if err := os.MkdirAll(dir, 0766); err != nil {
panic(err)
}
return func(opt *option) {
opt.file = &lumberjack.Logger{ // concurrent-safed
Filename: file, // 文件路径
MaxSize: 128, // 单个文件最大尺寸,默认单位 M
MaxBackups: 300, // 最多保留 300 个备份
MaxAge: 30, // 最大时间,默认单位 day
LocalTime: true, // 使用本地时间
Compress: true, // 是否压缩 disabled by default
}
}
}
// WithTimeLayout custom time format
func WithTimeLayout(timeLayout string) Option {
return func(opt *option) {
opt.timeLayout = timeLayout
}
}
// WithDisableConsole WithEnableConsole write log to os.Stdout or os.Stderr
func WithDisableConsole() Option {
return func(opt *option) {
opt.disableConsole = true
}
}
// WithDisableCaller disable caller field
func WithDisableCaller() Option {
return func(opt *option) {
opt.disableCaller = true
}
}
// NewJSONLogger return a json-encoder zap logger,
func NewJSONLogger(opts ...Option) (*zap.Logger, error) {
opt := &option{level: DefaultLevel, fields: make(map[string]string)}
for _, f := range opts {
f(opt)
}
timeLayout := DefaultTimeLayout
if opt.timeLayout != "" {
timeLayout = opt.timeLayout
}
// similar to zap.NewProductionEncoderConfig()
encoderConfig := zapcore.EncoderConfig{
TimeKey: "time",
LevelKey: "level",
NameKey: "logger", // used by logger.Named(key); optional; useless
CallerKey: "caller",
MessageKey: "msg",
StacktraceKey: "stacktrace", // use by zap.AddStacktrace; optional; useless
LineEnding: zapcore.DefaultLineEnding,
EncodeLevel: zapcore.LowercaseLevelEncoder, // 小写编码器
EncodeTime: func(t time.Time, enc zapcore.PrimitiveArrayEncoder) {
enc.AppendString(t.Format(timeLayout))
},
EncodeDuration: zapcore.MillisDurationEncoder,
EncodeCaller: zapcore.ShortCallerEncoder, // 全路径编码器
}
jsonEncoder := zapcore.NewJSONEncoder(encoderConfig)
// lowPriority usd by info\debug\warn
lowPriority := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool {
return lvl >= opt.level && lvl < zapcore.ErrorLevel
})
// highPriority usd by error\panic\fatal
highPriority := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool {
return lvl >= opt.level && lvl >= zapcore.ErrorLevel
})
stdout := zapcore.Lock(os.Stdout) // lock for concurrent safe
stderr := zapcore.Lock(os.Stderr) // lock for concurrent safe
core := zapcore.NewTee()
if !opt.disableConsole {
core = zapcore.NewTee(
zapcore.NewCore(jsonEncoder,
zapcore.NewMultiWriteSyncer(stdout),
lowPriority,
),
zapcore.NewCore(jsonEncoder,
zapcore.NewMultiWriteSyncer(stderr),
highPriority,
),
)
}
if opt.file != nil {
core = zapcore.NewTee(core,
zapcore.NewCore(jsonEncoder,
zapcore.AddSync(opt.file),
zap.LevelEnablerFunc(func(lvl zapcore.Level) bool {
return lvl >= opt.level
}),
),
)
}
logger := zap.New(core,
zap.WithCaller(!opt.disableCaller),
zap.ErrorOutput(stderr),
)
for key, value := range opt.fields {
logger = logger.WithOptions(zap.Fields(zapcore.Field{Key: key, Type: zapcore.StringType, String: value}))
}
return logger, nil
}
var _ Meta = (*meta)(nil)
// Meta key-value
type Meta interface {
Key() string
Value() any
meta()
}
type meta struct {
key string
value any
}
func (m *meta) Key() string {
return m.key
}
func (m *meta) Value() any {
return m.value
}
func (m *meta) meta() {}
// NewMeta create meat
func NewMeta(key string, value any) Meta {
return &meta{key: key, value: value}
}
// WrapMeta wrap meta to zap fields
func WrapMeta(err error, metas ...Meta) (fields []zap.Field) {
capacity := len(metas) + 1 // namespace meta
if err != nil {
capacity++
}
fields = make([]zap.Field, 0, capacity)
if err != nil {
fields = append(fields, zap.Error(err))
}
fields = append(fields, zap.Namespace("meta"))
for _, meta := range metas {
fields = append(fields, zap.Any(meta.Key(), meta.Value()))
}
return
}
// RestyClientLogger use by resty.Client
func RestyClientLogger(logger *zap.Logger) *RestyClientLog {
return &RestyClientLog{logger: logger}
}
type RestyClientLog struct {
logger *zap.Logger
}
func (r *RestyClientLog) Errorf(format string, v ...any) {
r.logger.Sugar().Errorf(format, v)
}
func (r *RestyClientLog) Warnf(format string, v ...any) {
r.logger.Sugar().Warnf(format, v)
}
func (r *RestyClientLog) Debugf(format string, v ...any) {
r.logger.Sugar().Debugf(format, v)
}

36
pkg/mail/mail.go Normal file
View File

@ -0,0 +1,36 @@
package mail
import (
"gopkg.in/gomail.v2"
)
type Options struct {
MailHost string
MailPort int
MailUser string // 发件人
MailPass string // 发件人密码
MailTo []string // 多个收件人
Subject string // 邮件主题
Body string // 邮件内容
}
func Send(o *Options) error {
m := gomail.NewMessage()
//设置发件人
m.SetHeader("From", o.MailUser)
//设置发送给多个用户
m.SetHeader("To", o.MailTo...)
//设置邮件主题
m.SetHeader("Subject", o.Subject)
//设置邮件正文
m.SetBody("text/html", o.Body)
d := gomail.NewDialer(o.MailHost, o.MailPort, o.MailUser, o.MailPass)
return d.DialAndSend(m)
}

28
pkg/md5/md5.go Normal file
View File

@ -0,0 +1,28 @@
package md5
import (
cryptoMD5 "crypto/md5"
"encoding/hex"
)
var _ MD5 = (*md5)(nil)
type MD5 interface {
i()
// Encrypt 加密
Encrypt(encryptStr string) string
}
type md5 struct{}
func New() MD5 {
return &md5{}
}
func (m *md5) i() {}
func (m *md5) Encrypt(encryptStr string) string {
s := cryptoMD5.New()
s.Write([]byte(encryptStr))
return hex.EncodeToString(s.Sum(nil))
}

51
pkg/metrics/metrics.go Normal file
View File

@ -0,0 +1,51 @@
package metrics
import (
"github.com/prometheus/client_golang/prometheus"
"github.com/spf13/cast"
)
const (
namespace = "bvbej"
subsystem = "api"
)
// metricsRequestsTotal metrics for request total 计数器Counter
var metricsRequestsTotal = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: namespace,
Subsystem: subsystem,
Name: "requests_total",
Help: "request(ms) total",
},
[]string{"method", "path"},
)
// metricsRequestsCost metrics for requests cost 累积直方图Histogram
var metricsRequestsCost = prometheus.NewHistogramVec(
prometheus.HistogramOpts{
Namespace: namespace,
Subsystem: subsystem,
Name: "requests_cost",
Help: "request(ms) cost milliseconds",
},
[]string{"method", "path", "success"},
)
func init() {
prometheus.MustRegister(metricsRequestsTotal, metricsRequestsCost)
}
// RecordMetrics 记录指标
func RecordMetrics(method, uri string, success bool, costSeconds float64) {
metricsRequestsTotal.With(prometheus.Labels{
"method": method,
"path": uri,
}).Inc()
metricsRequestsCost.With(prometheus.Labels{
"method": method,
"path": uri,
"success": cast.ToString(success),
}).Observe(costSeconds)
}

406
pkg/mux/context.go Normal file
View File

@ -0,0 +1,406 @@
package mux
import (
"bytes"
stdCtx "context"
"io"
"net/http"
"net/url"
"strings"
"sync"
"git.bvbej.com/bvbej/base-golang/pkg/errno"
"git.bvbej.com/bvbej/base-golang/pkg/trace"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
"go.uber.org/zap"
)
type HandlerFunc func(c Context)
type Trace = trace.T
const (
_Alias = "_alias_"
_TraceName = "_trace_"
_LoggerName = "_logger_"
_BodyName = "_body_"
_PayloadName = "_payload_"
_GraphPayloadName = "_graph_payload_"
_AbortErrorName = "_abort_error_"
_Auth = "_auth_"
)
var contextPool = &sync.Pool{
New: func() any {
return new(context)
},
}
func newContext(ctx *gin.Context) Context {
getContext := contextPool.Get().(*context)
getContext.ctx = ctx
return getContext
}
func releaseContext(ctx Context) {
c := ctx.(*context)
c.ctx = nil
contextPool.Put(c)
}
var _ Context = (*context)(nil)
type Context interface {
init()
Context() *gin.Context
// ShouldBindQuery 反序列化 query
// tag: `form:"xxx"` (注:不要写成 query)
ShouldBindQuery(obj any) error
// ShouldBindPostForm 反序列化 x-www-from-urlencoded
// tag: `form:"xxx"`
ShouldBindPostForm(obj any) error
// ShouldBindForm 同时反序列化 form-data;
// tag: `form:"xxx"`
ShouldBindForm(obj any) error
// ShouldBindJSON 反序列化 post-json
// tag: `json:"xxx"`
ShouldBindJSON(obj any) error
// ShouldBindURI 反序列化 path 参数(如路由路径为 /user/:name)
// tag: `uri:"xxx"`
ShouldBindURI(obj any) error
// Redirect 重定向
Redirect(code int, location string)
// Trace 获取 Trace 对象
Trace() Trace
setTrace(trace Trace)
disableTrace()
// Logger 获取 Logger 对象
Logger() *zap.Logger
setLogger(logger *zap.Logger)
// Payload 正确返回
Payload(payload any)
getPayload() any
// GraphPayload GraphQL返回值 与 api 返回结构不同
GraphPayload(payload any)
getGraphPayload() any
// HTML 返回界面
HTML(name string, obj any)
// AbortWithError 错误返回
AbortWithError(err errno.Error)
abortError() errno.Error
// Header 获取 Header 对象
Header() http.Header
// GetHeader 获取 Header
GetHeader(key string) string
// SetHeader 设置 Header
SetHeader(key, value string)
Auth() any
SetAuth(auth any)
// Authorization 获取请求认证信息
Authorization() string
// Platform 平台标识
Platform() string
// Alias 设置路由别名 for metrics uri
Alias() string
setAlias(path string)
// RequestInputParams 获取所有参数
RequestInputParams() url.Values
// RequestQueryParams 获取 Query 参数
RequestQueryParams() url.Values
// RequestPostFormParams 获取 PostForm 参数
RequestPostFormParams() url.Values
// Request 获取 Request 对象
Request() *http.Request
// RawData 获取 Request.Body
RawData() []byte
// Method 获取 Request.Method
Method() string
// Host 获取 Request.Host
Host() string
// Path 获取 请求的路径 Request.URL.Path (不附带 querystring)
Path() string
// URI 获取 unescape 后的 Request.URL.RequestURI()
URI() string
// RequestContext 获取请求的 context (当 client 关闭后,会自动 canceled)
RequestContext() StdContext
// ResponseWriter 获取 ResponseWriter 对象
ResponseWriter() gin.ResponseWriter
}
type context struct {
ctx *gin.Context
}
type StdContext struct {
stdCtx.Context
Trace
*zap.Logger
}
func (c *context) init() {
body, err := c.ctx.GetRawData()
if err != nil {
panic(err)
}
c.ctx.Set(_BodyName, body) // cache body是为了trace使用
c.ctx.Request.Body = io.NopCloser(bytes.NewBuffer(body)) // re-construct req body
}
func (c *context) Context() *gin.Context {
return c.ctx
}
// ShouldBindQuery 反序列化querystring
// tag: `form:"xxx"` (注不要写成query)
func (c *context) ShouldBindQuery(obj any) error {
return c.ctx.ShouldBindWith(obj, binding.Query)
}
// ShouldBindPostForm 反序列化 postform (querystring 会被忽略)
// tag: `form:"xxx"`
func (c *context) ShouldBindPostForm(obj any) error {
return c.ctx.ShouldBindWith(obj, binding.FormPost)
}
// ShouldBindForm 同时反序列化querystring和postform;
// 当querystring和postform存在相同字段时postform优先使用。
// tag: `form:"xxx"`
func (c *context) ShouldBindForm(obj any) error {
return c.ctx.ShouldBindWith(obj, binding.Form)
}
// ShouldBindJSON 反序列化postjson
// tag: `json:"xxx"`
func (c *context) ShouldBindJSON(obj any) error {
return c.ctx.ShouldBindWith(obj, binding.JSON)
}
// ShouldBindURI 反序列化path参数(如路由路径为 /user/:name)
// tag: `uri:"xxx"`
func (c *context) ShouldBindURI(obj any) error {
return c.ctx.ShouldBindUri(obj)
}
// Redirect 重定向
func (c *context) Redirect(code int, location string) {
c.ctx.Redirect(code, location)
}
func (c *context) Trace() Trace {
t, ok := c.ctx.Get(_TraceName)
if !ok || t == nil {
return nil
}
return t.(Trace)
}
func (c *context) setTrace(trace Trace) {
c.ctx.Set(_TraceName, trace)
}
func (c *context) disableTrace() {
c.setTrace(nil)
}
func (c *context) Logger() *zap.Logger {
logger, ok := c.ctx.Get(_LoggerName)
if !ok {
return nil
}
return logger.(*zap.Logger)
}
func (c *context) setLogger(logger *zap.Logger) {
c.ctx.Set(_LoggerName, logger)
}
func (c *context) getPayload() any {
if payload, ok := c.ctx.Get(_PayloadName); ok != false {
return payload
}
return nil
}
func (c *context) Payload(payload any) {
c.ctx.Set(_PayloadName, payload)
}
func (c *context) getGraphPayload() any {
if payload, ok := c.ctx.Get(_GraphPayloadName); ok != false {
return payload
}
return nil
}
func (c *context) GraphPayload(payload any) {
c.ctx.Set(_GraphPayloadName, payload)
}
func (c *context) HTML(name string, obj any) {
c.ctx.HTML(200, name+".html", obj)
}
func (c *context) Header() http.Header {
header := c.ctx.Request.Header
clone := make(http.Header, len(header))
for k, v := range header {
value := make([]string, len(v))
copy(value, v)
clone[k] = value
}
return clone
}
func (c *context) GetHeader(key string) string {
return c.ctx.GetHeader(key)
}
func (c *context) SetHeader(key, value string) {
c.ctx.Header(key, value)
}
func (c *context) Auth() any {
val, ok := c.ctx.Get(_Auth)
if !ok {
return nil
}
return val
}
func (c *context) SetAuth(auth any) {
c.ctx.Set(_Auth, auth)
}
func (c *context) Authorization() string {
return c.ctx.GetHeader("Authorization")
}
func (c *context) Platform() string {
return c.ctx.GetHeader("X-Platform")
}
func (c *context) AbortWithError(err errno.Error) {
if err != nil {
httpCode := err.GetHttpCode()
if httpCode == 0 {
httpCode = http.StatusInternalServerError
}
c.ctx.AbortWithStatus(httpCode)
c.ctx.Set(_AbortErrorName, err)
}
}
func (c *context) abortError() errno.Error {
err, _ := c.ctx.Get(_AbortErrorName)
return err.(errno.Error)
}
func (c *context) Alias() string {
path, ok := c.ctx.Get(_Alias)
if !ok {
return ""
}
return path.(string)
}
func (c *context) setAlias(path string) {
if path = strings.TrimSpace(path); path != "" {
c.ctx.Set(_Alias, path)
}
}
// RequestInputParams 获取所有参数
func (c *context) RequestInputParams() url.Values {
_ = c.ctx.Request.ParseForm()
return c.ctx.Request.Form
}
// RequestQueryParams 获取Query参数
func (c *context) RequestQueryParams() url.Values {
query, _ := url.ParseQuery(c.ctx.Request.URL.RawQuery)
return query
}
// RequestPostFormParams 获取 PostForm 参数
func (c *context) RequestPostFormParams() url.Values {
_ = c.ctx.Request.ParseForm()
return c.ctx.Request.PostForm
}
// Request 获取 Request
func (c *context) Request() *http.Request {
return c.ctx.Request
}
func (c *context) RawData() []byte {
body, ok := c.ctx.Get(_BodyName)
if !ok {
return nil
}
return body.([]byte)
}
// Method 请求的method
func (c *context) Method() string {
return c.ctx.Request.Method
}
// Host 请求的host
func (c *context) Host() string {
return c.ctx.Request.Host
}
// Path 请求的路径(不附带querystring)
func (c *context) Path() string {
return c.ctx.Request.URL.Path
}
// URI unescape后的uri
func (c *context) URI() string {
uri, _ := url.QueryUnescape(c.ctx.Request.URL.RequestURI())
return uri
}
// RequestContext (包装 Trace + Logger) 获取请求的 context (当client关闭后会自动canceled)
func (c *context) RequestContext() StdContext {
return StdContext{
//c.ctx.Request.Context(),
stdCtx.Background(),
c.Trace(),
c.Logger(),
}
}
// ResponseWriter 获取 ResponseWriter
func (c *context) ResponseWriter() gin.ResponseWriter {
return c.ctx.Writer
}

466
pkg/mux/core.go Normal file
View File

@ -0,0 +1,466 @@
package mux
import (
"errors"
"fmt"
"net/http"
"net/url"
"runtime/debug"
"time"
"git.bvbej.com/bvbej/base-golang/pkg/color"
"git.bvbej.com/bvbej/base-golang/pkg/env"
"git.bvbej.com/bvbej/base-golang/pkg/errno"
"git.bvbej.com/bvbej/base-golang/pkg/limiter"
"git.bvbej.com/bvbej/base-golang/pkg/trace"
"git.bvbej.com/bvbej/base-golang/pkg/validator"
"github.com/gin-contrib/pprof"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
"github.com/prometheus/client_golang/prometheus/promhttp"
cors "github.com/rs/cors/wrapper/gin"
"go.uber.org/multierr"
"go.uber.org/zap"
"golang.org/x/time/rate"
)
type Option func(*option)
type option struct {
enableCors bool
enablePProf bool
enablePrometheus bool
enableOpenBrowser string
staticDirs []string
panicNotify OnPanicNotify
recordMetrics RecordMetrics
rateLimiter limiter.RateLimiter
}
const SuccessCode = 0
type Failure struct {
ResultCode int `json:"result_code"` // 业务码
ResultInfo string `json:"result_info"` // 描述信息
}
type Success struct {
ResultCode int `json:"result_code"` // 业务码
ResultData any `json:"result_data"` //返回数据
}
/******************************************************************************/
// OnPanicNotify 发生panic时通知用
type OnPanicNotify func(ctx Context, err any, stackInfo string)
// RecordMetrics 记录prometheus指标用
// 如果使用AliasForRecordMetrics配置了别名uri将被替换为别名。
type RecordMetrics func(method, uri string, success bool, costSeconds float64)
// DisableTrace 禁用追踪链
func DisableTrace(ctx Context) {
ctx.disableTrace()
}
// WithPanicNotify 设置panic时的通知回调
func WithPanicNotify(notify OnPanicNotify) Option {
return func(opt *option) {
opt.panicNotify = notify
fmt.Println(color.Green("* [register panic notify]"))
}
}
// WithRecordMetrics 设置记录prometheus记录指标回调
func WithRecordMetrics(record RecordMetrics) Option {
return func(opt *option) {
opt.recordMetrics = record
}
}
// WithEnableCors 开启CORS
func WithEnableCors() Option {
return func(opt *option) {
opt.enableCors = true
fmt.Println(color.Green("* [register cors]"))
}
}
// WithEnableRate 开启限流
func WithEnableRate(limit rate.Limit, burst int) Option {
return func(opt *option) {
opt.rateLimiter = limiter.NewRateLimiter(limit, burst)
fmt.Println(color.Green("* [register rate]"))
}
}
// WithStaticDir 设置静态文件目录
func WithStaticDir(dirs []string) Option {
return func(opt *option) {
opt.staticDirs = dirs
fmt.Println(color.Green("* [register rate]"))
}
}
// AliasForRecordMetrics 对请求uri起个别名用于prometheus记录指标。
// 如Get /user/:username 这样的uri因为username会有非常多的情况这样记录prometheus指标会非常的不有好。
func AliasForRecordMetrics(path string) HandlerFunc {
return func(ctx Context) {
ctx.setAlias(path)
}
}
/******************************************************************************/
// RouterGroup 包装gin的RouterGroup
type RouterGroup interface {
Group(string, ...HandlerFunc) RouterGroup
IRoutes
}
var _ IRoutes = (*router)(nil)
// IRoutes 包装gin的IRoutes
type IRoutes interface {
Any(string, ...HandlerFunc)
GET(string, ...HandlerFunc)
POST(string, ...HandlerFunc)
DELETE(string, ...HandlerFunc)
PATCH(string, ...HandlerFunc)
PUT(string, ...HandlerFunc)
OPTIONS(string, ...HandlerFunc)
HEAD(string, ...HandlerFunc)
}
type router struct {
group *gin.RouterGroup
}
func (r *router) Group(relativePath string, handlers ...HandlerFunc) RouterGroup {
group := r.group.Group(relativePath, wrapHandlers(handlers...)...)
return &router{group: group}
}
func (r *router) Any(relativePath string, handlers ...HandlerFunc) {
r.group.Any(relativePath, wrapHandlers(handlers...)...)
}
func (r *router) GET(relativePath string, handlers ...HandlerFunc) {
r.group.GET(relativePath, wrapHandlers(handlers...)...)
}
func (r *router) POST(relativePath string, handlers ...HandlerFunc) {
r.group.POST(relativePath, wrapHandlers(handlers...)...)
}
func (r *router) DELETE(relativePath string, handlers ...HandlerFunc) {
r.group.DELETE(relativePath, wrapHandlers(handlers...)...)
}
func (r *router) PATCH(relativePath string, handlers ...HandlerFunc) {
r.group.PATCH(relativePath, wrapHandlers(handlers...)...)
}
func (r *router) PUT(relativePath string, handlers ...HandlerFunc) {
r.group.PUT(relativePath, wrapHandlers(handlers...)...)
}
func (r *router) OPTIONS(relativePath string, handlers ...HandlerFunc) {
r.group.OPTIONS(relativePath, wrapHandlers(handlers...)...)
}
func (r *router) HEAD(relativePath string, handlers ...HandlerFunc) {
r.group.HEAD(relativePath, wrapHandlers(handlers...)...)
}
func wrapHandlers(handlers ...HandlerFunc) []gin.HandlerFunc {
list := make([]gin.HandlerFunc, len(handlers))
for i, handler := range handlers {
fn := handler
list[i] = func(c *gin.Context) {
ctx := newContext(c)
defer releaseContext(ctx)
fn(ctx)
}
}
return list
}
/******************************************************************************/
var _ Mux = (*mux)(nil)
type Mux interface {
http.Handler
Group(relativePath string, handlers ...HandlerFunc) RouterGroup
Routes() gin.RoutesInfo
HandlerFunc(relativePath string, handlerFunc gin.HandlerFunc)
}
type mux struct {
engine *gin.Engine
}
func (m *mux) ServeHTTP(w http.ResponseWriter, req *http.Request) {
m.engine.ServeHTTP(w, req)
}
func (m *mux) Group(relativePath string, handlers ...HandlerFunc) RouterGroup {
return &router{
group: m.engine.Group(relativePath, wrapHandlers(handlers...)...),
}
}
func (m *mux) Routes() gin.RoutesInfo {
return m.engine.Routes()
}
func (m *mux) HandlerFunc(relativePath string, handlerFunc gin.HandlerFunc) {
m.engine.GET(relativePath, handlerFunc)
}
func New(logger *zap.Logger, options ...Option) (Mux, error) {
if logger == nil {
return nil, errors.New("logger required")
}
gin.SetMode(gin.ReleaseMode)
binding.Validator = validator.Validator
newMux := &mux{
engine: gin.New(),
}
fmt.Println(color.Green(fmt.Sprintf("* [register env %s]", env.Active().Value())))
// withoutLogPaths 这些请求,默认不记录日志
withoutTracePaths := map[string]bool{
"/metrics": true,
"/favicon.ico": true,
"/system/health": true,
}
opt := new(option)
for _, f := range options {
f(opt)
}
if opt.enablePProf {
pprof.Register(newMux.engine)
fmt.Println(color.Green("* [register pprof]"))
}
if opt.enablePrometheus {
newMux.engine.GET("/metrics", gin.WrapH(promhttp.Handler()))
fmt.Println(color.Green("* [register prometheus]"))
}
if opt.enableCors {
newMux.engine.Use(cors.AllowAll())
}
if opt.staticDirs != nil {
for _, dir := range opt.staticDirs {
newMux.engine.StaticFS(dir, gin.Dir(dir, false))
}
}
// recover两次防止处理时发生panic尤其是在OnPanicNotify中。
newMux.engine.Use(func(ctx *gin.Context) {
defer func() {
if err := recover(); err != nil {
logger.Error("got panic", zap.String("panic", fmt.Sprintf("%+v", err)), zap.String("stack", string(debug.Stack())))
}
}()
ctx.Next()
})
newMux.engine.Use(func(ctx *gin.Context) {
ts := time.Now()
newCtx := newContext(ctx)
defer releaseContext(newCtx)
newCtx.init()
newCtx.setLogger(logger)
if !withoutTracePaths[ctx.Request.URL.Path] {
if traceId := newCtx.GetHeader(trace.Header); traceId != "" {
newCtx.setTrace(trace.New(traceId))
} else {
newCtx.setTrace(trace.New(""))
}
}
defer func() {
if err := recover(); err != nil {
stackInfo := string(debug.Stack())
logger.Error("got panic", zap.String("panic", fmt.Sprintf("%+v", err)), zap.String("stack", stackInfo))
newCtx.AbortWithError(errno.NewError(
http.StatusInternalServerError,
http.StatusInternalServerError,
http.StatusText(http.StatusInternalServerError)),
)
if notify := opt.panicNotify; notify != nil {
notify(newCtx, err, stackInfo)
}
}
if ctx.Writer.Status() == http.StatusNotFound {
return
}
var (
response any
businessCode int
businessCodeMsg string
abortErr error
graphResponse any
)
if ctx.IsAborted() {
for i := range ctx.Errors { // gin error
multierr.AppendInto(&abortErr, ctx.Errors[i])
}
if err := newCtx.abortError(); err != nil { // customer err
multierr.AppendInto(&abortErr, err.GetErr())
response = err
businessCode = err.GetBusinessCode()
businessCodeMsg = err.GetMsg()
if x := newCtx.Trace(); x != nil {
newCtx.SetHeader(trace.Header, x.ID())
}
ctx.JSON(err.GetHttpCode(), &Failure{
ResultCode: businessCode,
ResultInfo: businessCodeMsg,
})
}
} else {
response = newCtx.getPayload()
if response != nil {
if x := newCtx.Trace(); x != nil {
newCtx.SetHeader(trace.Header, x.ID())
}
ctx.JSON(http.StatusOK, response)
}
}
graphResponse = newCtx.getGraphPayload()
if opt.recordMetrics != nil {
uri := newCtx.Path()
if alias := newCtx.Alias(); alias != "" {
uri = alias
}
opt.recordMetrics(
newCtx.Method(),
uri,
!ctx.IsAborted() && ctx.Writer.Status() == http.StatusOK,
time.Since(ts).Seconds(),
)
}
var t *trace.Trace
if x := newCtx.Trace(); x != nil {
t = x.(*trace.Trace)
} else {
return
}
decodedURL, _ := url.QueryUnescape(ctx.Request.URL.RequestURI())
t.WithRequest(&trace.Request{
TTL: "un-limit",
Method: ctx.Request.Method,
DecodedURL: decodedURL,
Header: ctx.Request.Header,
Body: string(newCtx.RawData()),
})
var responseBody any
if response != nil {
responseBody = response
}
if graphResponse != nil {
responseBody = graphResponse
}
t.WithResponse(&trace.Response{
Header: ctx.Writer.Header(),
HttpCode: ctx.Writer.Status(),
HttpCodeMsg: http.StatusText(ctx.Writer.Status()),
BusinessCode: businessCode,
BusinessCodeMsg: businessCodeMsg,
Body: responseBody,
CostSeconds: time.Since(ts).Seconds(),
})
t.Success = !ctx.IsAborted() && ctx.Writer.Status() == http.StatusOK
t.CostSeconds = time.Since(ts).Seconds()
logger.Info("core-interceptor",
zap.Any("method", ctx.Request.Method),
zap.Any("path", decodedURL),
zap.Any("http_code", ctx.Writer.Status()),
zap.Any("business_code", businessCode),
zap.Any("success", t.Success),
zap.Any("cost_seconds", t.CostSeconds),
zap.Any("trace_id", t.Identifier),
zap.Any("trace_info", t),
zap.Error(abortErr),
)
}()
ctx.Next()
})
if opt.rateLimiter != nil {
newMux.engine.Use(func(ctx *gin.Context) {
newCtx := newContext(ctx)
defer releaseContext(newCtx)
if !opt.rateLimiter.Allow(ctx.ClientIP()) {
newCtx.AbortWithError(errno.NewError(
http.StatusTooManyRequests,
http.StatusTooManyRequests,
http.StatusText(http.StatusTooManyRequests)),
)
return
}
ctx.Next()
})
}
newMux.engine.NoMethod(wrapHandlers(DisableTrace)...)
newMux.engine.NoRoute(wrapHandlers(DisableTrace)...)
system := newMux.Group("/system")
{
// 健康检查
system.GET("/health", func(ctx Context) {
resp := &struct {
Timestamp time.Time `json:"timestamp"`
Environment string `json:"environment"`
Host string `json:"host"`
Status string `json:"status"`
}{
Timestamp: time.Now(),
Environment: env.Active().Value(),
Host: ctx.Host(),
Status: "ok",
}
ctx.Payload(resp)
})
}
return newMux, nil
}

View File

@ -0,0 +1,3 @@
package observable
type Iterable <-chan any

View File

@ -0,0 +1,67 @@
package observable
// Ref: github.com/Dreamacro/clash/common/observable
import (
"errors"
"sync"
)
type Observable struct {
iterable Iterable
listener map[Subscription]*Subscriber
mux sync.Mutex
done bool
}
func (o *Observable) process() {
for item := range o.iterable {
o.mux.Lock()
for _, sub := range o.listener {
sub.Emit(item)
}
o.mux.Unlock()
}
o.close()
}
func (o *Observable) close() {
o.mux.Lock()
defer o.mux.Unlock()
o.done = true
for _, sub := range o.listener {
sub.Close()
}
}
func (o *Observable) Subscribe() (Subscription, error) {
o.mux.Lock()
defer o.mux.Unlock()
if o.done {
return nil, errors.New("observable is closed")
}
subscriber := newSubscriber()
o.listener[subscriber.Out()] = subscriber
return subscriber.Out(), nil
}
func (o *Observable) UnSubscribe(sub Subscription) {
o.mux.Lock()
defer o.mux.Unlock()
subscriber, exist := o.listener[sub]
if !exist {
return
}
delete(o.listener, sub)
subscriber.Close()
}
func NewObservable(any Iterable) *Observable {
observable := &Observable{
iterable: any,
listener: map[Subscription]*Subscriber{},
}
go observable.process()
return observable
}

View File

@ -0,0 +1,33 @@
package observable
import (
"sync"
)
type Subscription <-chan any
type Subscriber struct {
buffer chan any
once sync.Once
}
func (s *Subscriber) Emit(item any) {
s.buffer <- item
}
func (s *Subscriber) Out() Subscription {
return s.buffer
}
func (s *Subscriber) Close() {
s.once.Do(func() {
close(s.buffer)
})
}
func newSubscriber() *Subscriber {
sub := &Subscriber{
buffer: make(chan any, 200),
}
return sub
}

50
pkg/p/print.go Normal file
View File

@ -0,0 +1,50 @@
package p
import (
"fmt"
"time"
"git.bvbej.com/bvbej/base-golang/pkg/trace"
)
type Option func(*option)
type Trace = trace.T
type option struct {
Trace *trace.Trace
Debug *trace.Debug
}
func newOption() *option {
return &option{}
}
func Println(key string, value any, options ...Option) {
ts := time.Now()
opt := newOption()
defer func() {
if opt.Trace != nil {
opt.Debug.Key = key
opt.Debug.Value = value
opt.Debug.CostSeconds = time.Since(ts).Seconds()
opt.Trace.AppendDebug(opt.Debug)
}
}()
for _, f := range options {
f(opt)
}
fmt.Println(fmt.Sprintf("KEY: %s | VALUE: %v", key, value))
}
// WithTrace 设置trace信息
func WithTrace(t Trace) Option {
return func(opt *option) {
if t != nil {
opt.Trace = t.(*trace.Trace)
opt.Debug = new(trace.Debug)
}
}
}

142
pkg/proxy/io.go Normal file
View File

@ -0,0 +1,142 @@
package proxy
import (
"context"
"golang.org/x/time/rate"
"io"
"net"
"runtime/debug"
"sync"
"time"
)
const burstLimit = 1000 * 1000 * 1000
type Reader struct {
r io.Reader
limiter *rate.Limiter
ctx context.Context
}
func NewReader(r io.Reader) *Reader {
return &Reader{
r: r,
ctx: context.Background(),
}
}
func (s *Reader) SetRateLimit(bytesPerSec float64) {
s.limiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit)
s.limiter.AllowN(time.Now(), burstLimit) // spend initial burst
}
func (s *Reader) Read(p []byte) (int, error) {
if s.limiter == nil {
return s.r.Read(p)
}
n, err := s.r.Read(p)
if err != nil {
return n, err
}
if err := s.limiter.WaitN(s.ctx, n); err != nil {
return n, err
}
return n, nil
}
func ConnectHost(hostAndPort string, timeout int) (conn net.Conn, err error) {
conn, err = net.DialTimeout("tcp", hostAndPort, time.Duration(timeout)*time.Millisecond)
return
}
func CloseConn(conn net.Conn) {
if conn != nil {
_ = conn.SetDeadline(time.Now().Add(time.Millisecond))
_ = conn.Close()
}
}
func IoBind(dst io.ReadWriter, src io.ReadWriter, fn func(isSrcErr bool, err error), cfn func(count int, isPositive bool), bytesPreSec float64) {
var one = &sync.Once{}
go func() {
defer func() {
if e := recover(); e != nil {
logger.Sugar().Errorf("IoBind crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
}
}()
var err error
var isSrcErr bool
if bytesPreSec > 0 {
newReader := NewReader(src)
newReader.SetRateLimit(bytesPreSec)
_, isSrcErr, err = IoCopy(dst, newReader, func(c int) {
cfn(c, false)
})
} else {
_, isSrcErr, err = IoCopy(dst, src, func(c int) {
cfn(c, false)
})
}
if err != nil {
one.Do(func() {
fn(isSrcErr, err)
})
}
}()
go func() {
defer func() {
if e := recover(); e != nil {
logger.Sugar().Errorf("IoBind crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
}
}()
var err error
var isSrcErr bool
if bytesPreSec > 0 {
newReader := NewReader(dst)
newReader.SetRateLimit(bytesPreSec)
_, isSrcErr, err = IoCopy(src, newReader, func(c int) {
cfn(c, true)
})
} else {
_, isSrcErr, err = IoCopy(src, dst, func(c int) {
cfn(c, true)
})
}
if err != nil {
one.Do(func() {
fn(isSrcErr, err)
})
}
}()
}
func IoCopy(dst io.Writer, src io.Reader, fn ...func(count int)) (written int64, isSrcErr bool, err error) {
buf := make([]byte, 32*1024)
for {
nr, er := src.Read(buf)
if nr > 0 {
nw, ew := dst.Write(buf[0:nr])
if nw > 0 {
written += int64(nw)
if len(fn) == 1 {
fn[0](nw)
}
}
if ew != nil {
err = ew
break
}
if nr != nw {
err = io.ErrShortWrite
break
}
}
if er != nil {
err = er
isSrcErr = true
break
}
}
return written, isSrcErr, err
}

105
pkg/proxy/service.go Normal file
View File

@ -0,0 +1,105 @@
package proxy
import (
"fmt"
"go.uber.org/zap"
"net"
"runtime/debug"
"sync"
)
var (
servicesMap = new(sync.Map)
logger *zap.Logger
)
type Service struct {
TCPConn TCP
Name string
}
func (s *Service) Stop() {
servicesMap.Delete(s.Name)
s.TCPConn.Close()
}
func Run(name string, args TCPArgs, zapLogger *zap.Logger) *Service {
logger = zapLogger
service := &Service{
TCPConn: &tcp{cfg: args},
Name: name,
}
store, loaded := servicesMap.LoadOrStore(name, service)
if loaded {
service = store.(*Service)
}
go func() {
defer func() {
recoverErr := recover()
if recoverErr != nil {
logger.Sugar().Errorf("%s servcie crashed, ERR: %s\ntrace:%s", name, recoverErr, string(debug.Stack()))
}
}()
startErr := service.TCPConn.Start()
if startErr != nil {
logger.Sugar().Errorf("%s servcie fail, ERR: %s", name, startErr)
}
}()
return service
}
///////////////////////////////////////////////////////////////////////////
type Listener struct {
ip string
port int
Listener net.Listener
errAcceptHandler func(err error)
}
func NewListener(ip string, port int) Listener {
return Listener{
ip: ip,
port: port,
errAcceptHandler: func(err error) {
logger.Sugar().Errorf("accept error , ERR:%s", err)
},
}
}
func (sc *Listener) ListenTCP(fn func(conn net.Conn)) (err error) {
sc.Listener, err = net.Listen("tcp", fmt.Sprintf("%s:%d", sc.ip, sc.port))
if err == nil {
go func() {
defer func() {
if e := recover(); e != nil {
logger.Sugar().Infof("ListenTCP crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
}
}()
for {
var conn net.Conn
conn, err = sc.Listener.Accept()
if err == nil {
go func() {
defer func() {
if e := recover(); e != nil {
logger.Sugar().Infof("connection handler crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
}
}()
fn(conn)
}()
} else {
sc.errAcceptHandler(err)
break
}
}
}()
}
return
}
func (sc *Listener) CloseListen() error {
return sc.Listener.Close()
}

94
pkg/proxy/tcp.go Normal file
View File

@ -0,0 +1,94 @@
package proxy
import (
"net"
"strconv"
)
var _ TCP = (*tcp)(nil)
type TCPArgs struct {
Local string //监听地址
Parent string //被代理地址
Timeout int //拨号超时(毫秒)
OutCallback func() bool //回调 //是否允许代理
}
type TCP interface {
Start() (err error)
Close()
callback(inConn net.Conn)
outToTCP(inConn net.Conn) (err error)
}
type tcp struct {
inConn net.Conn
outConn net.Conn
listen Listener
cfg TCPArgs
}
func (s *tcp) Start() (err error) {
host, port, _ := net.SplitHostPort(s.cfg.Local)
p, _ := strconv.Atoi(port)
s.listen = NewListener(host, p)
err = s.listen.ListenTCP(s.callback)
if err != nil {
return
}
return
}
func (s *tcp) Close() {
if s.inConn != nil {
CloseConn(s.inConn)
}
if s.outConn != nil {
CloseConn(s.outConn)
}
_ = s.listen.CloseListen()
}
func (s *tcp) callback(inConn net.Conn) {
defer func() {
if err := recover(); err != nil {
logger.Sugar().Infof("conn handler crashed with err : %s", err)
}
}()
if s.cfg.OutCallback != nil {
if !s.cfg.OutCallback() {
CloseConn(inConn)
return
}
}
err := s.outToTCP(inConn)
if err != nil {
CloseConn(inConn)
}
s.inConn = inConn
}
func (s *tcp) outToTCP(inConn net.Conn) error {
outConn, err := ConnectHost(s.cfg.Parent, s.cfg.Timeout)
if err != nil {
return err
}
inAddr := inConn.RemoteAddr().String()
inLocalAddr := inConn.LocalAddr().String()
outAddr := outConn.RemoteAddr().String()
outLocalAddr := outConn.LocalAddr().String()
IoBind(inConn, outConn, func(isSrcErr bool, err error) {
CloseConn(inConn)
CloseConn(outConn)
logger.Sugar().Infof("conn %s - %s - %s -%s released", inAddr, inLocalAddr, outLocalAddr, outAddr)
}, func(n int, d bool) {}, 0)
logger.Sugar().Infof("conn %s - %s - %s -%s connected", inAddr, inLocalAddr, outLocalAddr, outAddr)
return nil
}

236
pkg/qiniu/storage.go Normal file
View File

@ -0,0 +1,236 @@
package qiniu
import (
"context"
"errors"
"fmt"
"git.bvbej.com/bvbej/base-golang/pkg/md5"
"git.bvbej.com/bvbej/base-golang/tool"
"github.com/qiniu/go-sdk/v7/auth/qbox"
"github.com/qiniu/go-sdk/v7/storage"
"github.com/tidwall/gjson"
"io"
"net/http"
"net/url"
"path"
"strings"
"time"
)
var _ QiNiu = (*qiNiu)(nil)
type QiNiu interface {
i()
SetDefaultUploadTokenTTL(ttl uint64)
GetCallbackUploadToken(ttl uint64, callbackURL string) string
GetUploadToken(ttl uint64) string
GetPrivateURL(key string, ttl uint64) string
VerifyCallback(req *http.Request) (bool, error)
UploadFile(key, localFile string) (*PutRet, error)
ResumeUploadFile(key, localFile string) (*PutRet, error)
DelFile(key string) error
TimestampSecuritySign(path string, ttl time.Duration) string
GetFileInfo(key string) *storage.FileInfo
ListFiles(prefix, delimiter, marker string, limit int) (entries []storage.ListItem, commonPrefixes []string, nextMarker string, hasNext bool, err error)
GetFileHash(path, qhash string) (hash string, err error)
}
type qiNiu struct {
mac *qbox.Mac
bucketManager *storage.BucketManager
conf *storage.Config
bucket string
domain string
securityKey string
md5 md5.MD5
uploadTokenTTL uint64
}
type PutRet struct {
Key string `json:"key"`
Hash string `json:"hash"`
Fsize string `json:"fsize"`
Fname string `json:"fname"`
Ext string `json:"ext"`
Unique string `json:"unique"`
User string `json:"user"`
}
func New(accessKey, secretKey, bucket, domain, securityKey string) QiNiu {
region, _ := storage.GetRegion(accessKey, bucket)
mac := qbox.NewMac(accessKey, secretKey)
conf := &storage.Config{
Region: region, //空间所在的存储区域
UseHTTPS: true, //是否使用https域名
UseCdnDomains: true, //上传是否使用CDN上传加速
}
return &qiNiu{
mac: mac,
bucketManager: storage.NewBucketManager(mac, conf),
bucket: bucket,
domain: domain,
securityKey: securityKey,
conf: conf,
md5: md5.New(),
uploadTokenTTL: 3600,
}
}
func (q *qiNiu) i() {}
func (q *qiNiu) SetDefaultUploadTokenTTL(ttl uint64) {
q.uploadTokenTTL = ttl
}
func (q *qiNiu) GetUploadToken(ttl uint64) string {
putPolicy := storage.PutPolicy{
Scope: q.bucket,
Expires: ttl,
}
return putPolicy.UploadToken(q.mac)
}
func (q *qiNiu) GetCallbackUploadToken(ttl uint64, callbackURL string) string {
putPolicy := storage.PutPolicy{
Scope: q.bucket,
CallbackURL: callbackURL,
CallbackBody: `{"key":"$(key)","hash":"$(etag)","fname":"$(fname)","fsize":"$(fsize)","ext":"$(ext)","unique":"$(x:unique)","user":"$(x:user)"}`,
CallbackBodyType: "application/json",
Expires: ttl,
}
return putPolicy.UploadToken(q.mac)
}
func (q *qiNiu) GetPrivateURL(key string, ttl uint64) string {
deadline := time.Now().Add(time.Second * time.Duration(ttl)).Unix()
return storage.MakePrivateURL(q.mac, q.domain, key, deadline)
}
func (q *qiNiu) VerifyCallback(req *http.Request) (bool, error) {
return q.mac.VerifyCallback(req)
}
func (q *qiNiu) UploadFile(key, localFile string) (*PutRet, error) {
upToken := q.GetUploadToken(q.uploadTokenTTL)
//构建表单上传的对象
formUploader := storage.NewFormUploader(q.conf)
//请求参数
filename := path.Base(key)
fileSuffix := path.Ext(key)
filePrefix := filename[0 : len(filename)-len(fileSuffix)]
putExtra := &storage.PutExtra{
Params: map[string]string{
"x:unique": filePrefix,
"x:user": "-",
},
}
//自定义返回body
ret := new(PutRet)
err := formUploader.PutFile(context.Background(), ret, upToken, key, localFile, putExtra)
if err != nil {
return nil, err
}
return ret, nil
}
func (q *qiNiu) ResumeUploadFile(key, localFile string) (*PutRet, error) {
upToken := q.GetUploadToken(q.uploadTokenTTL)
//构建分片上传的对象
resumeUploader := storage.NewResumeUploaderV2(q.conf)
//请求参数
filename := path.Base(key)
fileSuffix := path.Ext(key)
filePrefix := filename[0 : len(filename)-len(fileSuffix)]
putExtra := &storage.RputV2Extra{
CustomVars: map[string]string{
"x:unique": filePrefix,
"x:user": "-",
},
}
//自定义返回body
ret := new(PutRet)
err := resumeUploader.PutFile(context.Background(), ret, upToken, key, localFile, putExtra)
if err != nil {
return nil, err
}
return ret, nil
}
func (q *qiNiu) DelFile(key string) error {
err := q.bucketManager.Delete(q.bucket, key)
if err != nil {
return err
}
return nil
}
func (q *qiNiu) TimestampSecuritySign(path string, ttl time.Duration) string {
sep := "/"
path = strings.Trim(path, sep)
splits := strings.Split(path, sep)
for i, split := range splits {
splits[i] = url.QueryEscape(split)
}
path = sep + strings.Join(splits, sep)
unix := time.Now().Add(ttl).Unix()
hex := fmt.Sprintf("%x", unix)
encrypt := q.md5.Encrypt(q.securityKey + path + hex)
param := make(url.Values)
param.Set("sign", encrypt)
param.Set("t", hex)
return param.Encode()
}
func (q *qiNiu) GetFileInfo(key string) *storage.FileInfo {
fileInfo, sErr := q.bucketManager.Stat(q.bucket, key)
if sErr != nil {
return nil
}
return &fileInfo
}
func (q *qiNiu) ListFiles(prefix, delimiter, marker string, limit int) (entries []storage.ListItem,
commonPrefixes []string, nextMarker string, hasNext bool, err error) {
return q.bucketManager.ListFiles(q.bucket, prefix, delimiter, marker, limit)
}
func (q *qiNiu) GetFileHash(path, qhash string) (hash string, err error) {
if !tool.InArray(qhash, []string{"sha1", "md5", "sha256"}) {
return "", errors.New("qhash invalid")
}
sign := q.TimestampSecuritySign(path, time.Second*5)
addr := fmt.Sprintf("https://cdn.mogume.com/%s?%s&qhash/%s", path, sign, qhash)
resp, err := http.Get(addr)
if err != nil {
return "", err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return "", errors.New(resp.Status)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}
return gjson.GetBytes(body, "hash").String(), nil
}

131
pkg/rsa/rsa.go Normal file
View File

@ -0,0 +1,131 @@
package rsa
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/base64"
"encoding/pem"
)
var _ Public = (*rsaPub)(nil)
var _ Private = (*rsaPri)(nil)
type Public interface {
i()
EncryptURLEncoding(encryptStr string) (string, error)
Encrypt(encryptStr string) (string, error)
}
type Private interface {
i()
Decrypt(decryptStr string) (string, error)
DecryptURLEncoding(decryptStr string) (string, error)
}
type rsaPub struct {
PublicKey string
}
type rsaPri struct {
PrivateKey string
}
func NewPublic(publicKey string) Public {
return &rsaPub{
PublicKey: publicKey,
}
}
func NewPrivate(privateKey string) Private {
return &rsaPri{
PrivateKey: privateKey,
}
}
func (pub *rsaPub) i() {}
func (pub *rsaPub) Encrypt(encryptStr string) (string, error) {
// pem 解码
block, _ := pem.Decode([]byte(pub.PublicKey))
// x509 解码
publicKeyInterface, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return "", err
}
// 类型断言
publicKey := publicKeyInterface.(*rsa.PublicKey)
//对明文进行加密
encryptedStr, err := rsa.EncryptPKCS1v15(rand.Reader, publicKey, []byte(encryptStr))
if err != nil {
return "", err
}
//返回密文
return base64.StdEncoding.EncodeToString(encryptedStr), nil
}
func (pub *rsaPub) EncryptURLEncoding(encryptStr string) (string, error) {
// pem 解码
block, _ := pem.Decode([]byte(pub.PublicKey))
// x509 解码
publicKeyInterface, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return "", err
}
// 类型断言
publicKey := publicKeyInterface.(*rsa.PublicKey)
//对明文进行加密
encryptedStr, err := rsa.EncryptPKCS1v15(rand.Reader, publicKey, []byte(encryptStr))
if err != nil {
return "", err
}
//返回密文
return base64.URLEncoding.EncodeToString(encryptedStr), nil
}
func (pri *rsaPri) i() {}
func (pri *rsaPri) Decrypt(decryptStr string) (string, error) {
// pem 解码
block, _ := pem.Decode([]byte(pri.PrivateKey))
// X509 解码
privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return "", err
}
decryptBytes, err := base64.StdEncoding.DecodeString(decryptStr)
//对密文进行解密
decrypted, _ := rsa.DecryptPKCS1v15(rand.Reader, privateKey, decryptBytes)
//返回明文
return string(decrypted), nil
}
func (pri *rsaPri) DecryptURLEncoding(decryptStr string) (string, error) {
// pem 解码
block, _ := pem.Decode([]byte(pri.PrivateKey))
// X509 解码
privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return "", err
}
decryptBytes, err := base64.URLEncoding.DecodeString(decryptStr)
//对密文进行解密
decrypted, _ := rsa.DecryptPKCS1v15(rand.Reader, privateKey, decryptBytes)
//返回明文
return string(decrypted), nil
}

50
pkg/shutdown/shutdown.go Normal file
View File

@ -0,0 +1,50 @@
package shutdown
import (
"os"
"os/signal"
"syscall"
)
var _ Hook = (*hook)(nil)
// Hook a graceful shutdown hook, default with signals of SIGINT and SIGTERM
type Hook interface {
// WithSignals add more signals into hook
WithSignals(signals ...syscall.Signal) Hook
// Close register shutdown handles
Close(funcs ...func())
}
type hook struct {
ctx chan os.Signal
}
// NewHook create a Hook instance
func NewHook() Hook {
hook := &hook{
ctx: make(chan os.Signal, 1),
}
return hook.WithSignals(syscall.SIGINT, syscall.SIGTERM, syscall.SIGKILL, syscall.SIGQUIT)
}
func (h *hook) WithSignals(signals ...syscall.Signal) Hook {
for _, s := range signals {
signal.Notify(h.ctx, s)
}
return h
}
func (h *hook) Close(funcs ...func()) {
select {
case <-h.ctx:
}
signal.Stop(h.ctx)
for _, f := range funcs {
f()
}
}

View File

@ -0,0 +1,52 @@
package signature
import (
"net/http"
"net/url"
"time"
)
var _ Signature = (*signature)(nil)
const (
delimiter = "|"
)
// 合法的 Methods
var methods = map[string]bool{
http.MethodGet: true,
http.MethodPost: true,
http.MethodHead: true,
http.MethodPut: true,
http.MethodPatch: true,
http.MethodDelete: true,
http.MethodConnect: true,
http.MethodOptions: true,
http.MethodTrace: true,
}
type Signature interface {
i()
// Generate 生成签名
Generate(path string, method string, params url.Values) (authorization, date string, err error)
// Verify 验证签名
Verify(authorization, date string, path string, method string, params url.Values) (ok bool, err error)
}
type signature struct {
key string
secret string
ttl time.Duration
}
func New(key, secret string, ttl time.Duration) Signature {
return &signature{
key: key,
secret: secret,
ttl: ttl,
}
}
func (s *signature) i() {}

View File

@ -0,0 +1,62 @@
package signature
import (
"bytes"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"net/url"
"strings"
"git.bvbej.com/bvbej/base-golang/pkg/time_parse"
)
// Generate
// path 请求的路径 (不附带 querystring)
func (s *signature) Generate(path string, method string, params url.Values) (authorization, date string, err error) {
if path == "" {
err = errors.New("path required")
return
}
if method == "" {
err = errors.New("method required")
return
}
methodName := strings.ToUpper(method)
if !methods[methodName] {
err = errors.New("method param error")
return
}
// Date
date = time_parse.CSTLayoutString()
// Encode() 方法中自带 sorted by key
sortParamsEncode, err := url.QueryUnescape(params.Encode())
if err != nil {
err = fmt.Errorf("url QueryUnescape error [%v]", err)
return
}
// 加密字符串规则
buffer := bytes.NewBuffer(nil)
buffer.WriteString(path)
buffer.WriteString(delimiter)
buffer.WriteString(methodName)
buffer.WriteString(delimiter)
buffer.WriteString(sortParamsEncode)
buffer.WriteString(delimiter)
buffer.WriteString(date)
// 对数据进行 sha256 加密,并进行 base64 encode
hash := hmac.New(sha256.New, []byte(s.secret))
hash.Write(buffer.Bytes())
digest := base64.StdEncoding.EncodeToString(hash.Sum(nil))
authorization = fmt.Sprintf("%s %s", s.key, digest)
return
}

View File

@ -0,0 +1,73 @@
package signature
import (
"bytes"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"net/url"
"strings"
"time"
"git.bvbej.com/bvbej/base-golang/pkg/time_parse"
)
func (s *signature) Verify(authorization, date string, path string, method string, params url.Values) (ok bool, err error) {
if date == "" {
err = errors.New("date required")
return
}
if path == "" {
err = errors.New("path required")
return
}
if method == "" {
err = errors.New("method required")
return
}
methodName := strings.ToUpper(method)
if !methods[methodName] {
err = errors.New("method param error")
return
}
ts, err := time_parse.ParseCSTInLocation(date)
if err != nil {
err = errors.New("date must follow '2006-01-02 15:04:05'")
return
}
if time_parse.SubInLocation(ts) > float64(s.ttl/time.Second) {
err = fmt.Errorf("date exceeds limit [%v]", s.ttl)
return
}
// Encode() 方法中自带 sorted by key
sortParamsEncode, err := url.QueryUnescape(params.Encode())
if err != nil {
err = fmt.Errorf("url QueryUnescape error [%v]", err)
return
}
buffer := bytes.NewBuffer(nil)
buffer.WriteString(path)
buffer.WriteString(delimiter)
buffer.WriteString(methodName)
buffer.WriteString(delimiter)
buffer.WriteString(sortParamsEncode)
buffer.WriteString(delimiter)
buffer.WriteString(date)
// 对数据进行 hmac 加密,并进行 base64 encode
hash := hmac.New(sha256.New, []byte(s.secret))
hash.Write(buffer.Bytes())
digest := base64.StdEncoding.EncodeToString(hash.Sum(nil))
ok = authorization == fmt.Sprintf("%s %s", s.key, digest)
return
}

26
pkg/sse/client.html Normal file
View File

@ -0,0 +1,26 @@
<!doctype html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>Server Sent Event</title>
</head>
<body>
<div id="event-data"></div>
</body>
<script>
const $stream = new EventSource("/stream");
const $log = document.querySelector('#event-data')
$stream.addEventListener("message", function (e) {
log(e.data)
});
function log(msg) {
$log.innerHTML += `<p>消息: ${msg}</p>`
$log.scrollTop += 1000
}
</script>
</html>

118
pkg/sse/server.go Normal file
View File

@ -0,0 +1,118 @@
package sse
import (
"github.com/gin-gonic/gin"
"io"
"net/http"
"sync"
"sync/atomic"
)
var _ Server = (*event)(nil)
type Server interface {
HandlerFunc(auth func(c *gin.Context) (string, error)) gin.HandlerFunc
Push(user, name, msg string) bool
Broadcast(name, msg string)
}
type clientChan struct {
User string
Chan chan msgChan
}
type msgChan struct {
Name string
Message string
}
type event struct {
SessionList sync.Map
Count atomic.Int32
Register chan clientChan
Unregister chan string
}
func NewServer() Server {
e := &event{
SessionList: sync.Map{},
Count: atomic.Int32{},
Register: make(chan clientChan),
Unregister: make(chan string),
}
go e.listen()
return e
}
func (stream *event) listen() {
for {
select {
case client := <-stream.Register:
stream.SessionList.Store(client.User, client.Chan)
stream.Count.Add(1)
case user := <-stream.Unregister:
value, ok := stream.SessionList.Load(user)
if ok {
event := value.(chan msgChan)
close(event)
stream.SessionList.Delete(user)
stream.Count.Add(-1)
}
}
}
}
func (stream *event) HandlerFunc(auth func(c *gin.Context) (string, error)) gin.HandlerFunc {
return func(c *gin.Context) {
user, err := auth(c)
if err != nil {
c.AbortWithStatus(http.StatusBadRequest)
return
}
e := make(chan msgChan)
client := clientChan{
User: user,
Chan: e,
}
stream.Register <- client
defer func() {
stream.Unregister <- user
}()
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Stream(func(w io.Writer) bool {
if msg, ok := <-e; ok {
c.SSEvent(msg.Name, msg.Message)
return true
}
return false
})
c.Next()
}
}
func (stream *event) Push(user, name, msg string) bool {
value, ok := stream.SessionList.Load(user)
if ok {
val := value.(chan msgChan)
val <- msgChan{Name: name, Message: msg}
}
return false
}
func (stream *event) Broadcast(name, msg string) {
stream.SessionList.Range(func(user, value any) bool {
val := value.(chan msgChan)
val <- msgChan{Name: name, Message: msg}
return true
})
}

52
pkg/ticker/ticker.go Normal file
View File

@ -0,0 +1,52 @@
package ticker
import (
"context"
"time"
)
var _ Ticker = (*ticker)(nil)
type Ticker interface {
worker()
Process(fun func())
Stop()
}
type ticker struct {
ticker *time.Ticker
ctx context.Context
cancel context.CancelFunc
f func()
}
func New(d time.Duration) Ticker {
ctx, cancelFunc := context.WithCancel(context.Background())
return &ticker{
ticker: time.NewTicker(d),
ctx: ctx,
cancel: cancelFunc,
}
}
func (t *ticker) worker() {
for {
select {
case <-t.ticker.C:
t.f()
case <-t.ctx.Done():
return
}
}
}
func (t *ticker) Process(fun func()) {
t.f = fun
go t.worker()
}
func (t *ticker) Stop() {
t.ticker.Stop()
t.cancel()
}

View File

@ -0,0 +1,124 @@
package time_parse
import (
"github.com/jinzhu/now"
"math"
"net/http"
"time"
)
var (
cst = time.Local
)
func SetPRCLocation() error {
var err error
if cst, err = time.LoadLocation("Asia/Shanghai"); err != nil {
return err
}
return nil
}
func Now() *now.Now {
timeFormats := append(append(now.TimeFormats, time.DateTime), time.DateOnly)
c := &now.Config{
WeekStartDay: time.Monday,
TimeLocation: cst,
TimeFormats: timeFormats,
}
return c.With(time.Now().In(cst))
}
// RFC3339ToCSTLayout convert rfc3339 value to China standard time layout
// 2020-11-08T08:18:46+08:00 => 2020-11-08 08:18:46
func RFC3339ToCSTLayout(value string) (string, error) {
ts, err := time.Parse(time.RFC3339, value)
if err != nil {
return "", err
}
return ts.In(cst).Format(time.DateTime), nil
}
// GetMilliTimestamp 获取CST毫秒时间戳
func GetMilliTimestamp() int64 {
return time.Now().In(cst).UnixMilli()
}
// GetTimestamp 获取CST时间戳
func GetTimestamp() int64 {
return time.Now().In(cst).Unix()
}
// CSTLayoutString 格式化时间
// 返回 "2006-01-02 15:04:05" 格式的时间
func CSTLayoutString() string {
return time.Now().In(cst).Format(time.DateTime)
}
// ParseCSTInLocation 格式化时间
func ParseCSTInLocation(date string) (time.Time, error) {
return time.ParseInLocation(time.DateTime, date, cst)
}
// CSTLayoutStringToUnix 返回 unix 时间戳
// 2020-01-24 21:11:11 => 1579871471
func CSTLayoutStringToUnix(cstLayoutString string) (int64, error) {
stamp, err := time.ParseInLocation(time.DateTime, cstLayoutString, cst)
if err != nil {
return 0, err
}
return stamp.Unix(), nil
}
// GMTLayoutString 格式化时间
// 返回 "Mon, 02 Jan 2006 15:04:05 GMT" 格式的时间
func GMTLayoutString() string {
return time.Now().In(cst).Format(http.TimeFormat)
}
// ParseGMTInLocation 格式化时间
func ParseGMTInLocation(date string) (time.Time, error) {
return time.ParseInLocation(http.TimeFormat, date, cst)
}
// SubInLocation 计算时间差
func SubInLocation(ts time.Time) float64 {
return math.Abs(time.Now().In(cst).Sub(ts).Seconds())
}
// GetFirstDateOfMonth 获取传入的时间所在月份的第一天即某月第一天的0点。如传入time.Now(), 返回当前月份的第一天0点时间。
func GetFirstDateOfMonth(d time.Time) time.Time {
d = d.AddDate(0, 0, -d.Day()+1)
return GetStartDayTime(d)
}
// GetLastDateOfMonth 获取传入的时间所在月份的最后一天即某月最后一天的0点。如传入time.Now(), 返回当前月份的最后一天0点时间。
func GetLastDateOfMonth(d time.Time) time.Time {
return GetEndDayTime(GetFirstDateOfMonth(d).AddDate(0, 1, -1))
}
// GetStartDayTime 获取某一天的开始时间
func GetStartDayTime(d time.Time) time.Time {
return time.Date(d.Year(), d.Month(), d.Day(), 0, 0, 0, 0, cst)
}
// GetEndDayTime 获取某一天的结束时间
func GetEndDayTime(d time.Time) time.Time {
return GetStartDayTime(d).Add(24 * time.Hour).Add(-time.Second)
}
// GetStartYesterdayTime 获取昨天的起始时间
func GetStartYesterdayTime() time.Time {
d, _ := time.ParseDuration("-24h")
return GetStartDayTime(time.Now().Add(d))
}
// SubInLocationDays 计算两个日期相差的天数
func SubInLocationDays(t1, t2 time.Time) float64 {
t1 = time.Date(t1.Year(), t1.Month(), t1.Day(), 0, 0, 0, 0, cst)
t2 = time.Date(t2.Year(), t2.Month(), t2.Day(), 0, 0, 0, 0, cst)
return math.Abs(t1.Sub(t2).Hours() / 24)
}

20
pkg/token/README.md Normal file
View File

@ -0,0 +1,20 @@
## 与 UrlSign 对应的 PHP 加密算法
```php
// 对 params key 进行排序
ksort($params);
// 对 sortParams 进行 Encode
$sortParamsEncode = http_build_query($params);
// 加密字符串规则 path + method + sortParamsEncode + secret
$encryptStr = $path . $method . $sortParamsEncode . $secret
// 对加密字符串进行 md5
$md5Str = md5($encryptStr);
// 对 md5Str 进行 base64 encode
$tokenString = base64_encode($md5Str);
echo $tokenString;
```

36
pkg/token/token.go Normal file
View File

@ -0,0 +1,36 @@
package token
import (
"net/url"
"time"
"github.com/golang-jwt/jwt/v4"
)
var _ Token = (*token)(nil)
type Token interface {
// i 为了避免被其他包实现
i()
// JwtSign JWT 签名方式
JwtSign(userId, subject string, expireDuration time.Duration) (tokenString string, err error)
JwtParse(tokenString string) (*jwt.RegisteredClaims, error)
// UrlSign URL 签名方式,不支持解密
UrlSign(path string, method string, params url.Values) (tokenString string, err error)
}
type token struct {
secret string
domain []string
}
func New(secret string, domain ...string) Token {
return &token{
secret: secret,
domain: domain,
}
}
func (t *token) i() {}

43
pkg/token/token_jwt.go Normal file
View File

@ -0,0 +1,43 @@
package token
import (
"github.com/golang-jwt/jwt/v4"
"time"
)
func (t *token) JwtSign(userId, subject string, expireDuration time.Duration) (tokenString string, err error) {
// The token content.
// iss: Issuer签发者
// iat: Issued At签发时间用Unix时间戳表示
// exp: Expiration Time过期时间用Unix时间戳表示
// aud: Audience接收该JWT的一方
// sub: Subject该JWT的主题
// nbf: Not Before不要早于这个时间
// jti: JWT ID用于标识JWT的唯一ID
c := &jwt.RegisteredClaims{
Issuer: "BvBeJ",
Subject: subject,
Audience: jwt.ClaimStrings(t.domain),
ExpiresAt: jwt.NewNumericDate(time.Now().Add(expireDuration)),
NotBefore: jwt.NewNumericDate(time.Now()),
IssuedAt: jwt.NewNumericDate(time.Now()),
ID: userId,
}
tokenString, err = jwt.NewWithClaims(jwt.SigningMethodHS256, c).SignedString([]byte(t.secret))
return
}
func (t *token) JwtParse(tokenString string) (*jwt.RegisteredClaims, error) {
tokenClaims, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (any, error) {
return []byte(t.secret), nil
})
if tokenClaims != nil {
if c, ok := tokenClaims.Claims.(*jwt.RegisteredClaims); ok && tokenClaims.Valid {
return c, nil
}
}
return nil, err
}

48
pkg/token/token_url.go Normal file
View File

@ -0,0 +1,48 @@
package token
import (
"crypto/md5"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"net/url"
"strings"
)
// UrlSign
// path 请求的路径 (不附带 querystring)
func (t *token) UrlSign(path string, method string, params url.Values) (tokenString string, err error) {
// 合法的 Methods
methods := map[string]bool{
"get": true,
"post": true,
"put": true,
"path": true,
"delete": true,
"head": true,
"options": true,
}
methodName := strings.ToLower(method)
if !methods[methodName] {
err = errors.New("method param error")
return
}
// Encode() 方法中自带 sorted by key
sortParamsEncode := params.Encode()
// 加密字符串规则 path + method + sortParamsEncode + secret
encryptStr := fmt.Sprintf("%s%s%s%s", path, methodName, sortParamsEncode, t.secret)
// 对加密字符串进行 md5
s := md5.New()
s.Write([]byte(encryptStr))
md5Str := hex.EncodeToString(s.Sum(nil))
// 对 md5Str 进行 base64 encode
tokenString = base64.StdEncoding.EncodeToString([]byte(md5Str))
return
}

81
pkg/trace/README.md Normal file
View File

@ -0,0 +1,81 @@
## trace
一个用于开发调试的辅助工具。
可以实时显示当前页面的操作的请求信息、运行情况、SQL执行、错误提示等。
- `trace.go` 主入口文件;
- `dialog.go` 处理 third_party_requests 记录;
- `debug.go` 处理 debug 记录;
#### 数据格式
##### trace_id
当前 trace 的 ID例如938ff86be98439c6c1a7便于搜索使用。
##### request
请求信息,会包括:
- ttl 请求超时时间例如2s 或 un-limit
- method 请求方式例如GET 或 POST
- decoded_url 请求地址
- header 请求头信息
- body 请求体信息
##### response
- header 响应头信息
- body 响应提信息
- business_code 业务码例如10010
- business_code_msg 业务码信息,例如:签名错误
- http_code HTTP 状态码例如200
- http_code_msg HTTP 状态码信息例如OK
- cost_seconds 耗费时长:单位秒,比如 0.001105661
##### third_party_requests
每一个第三方 http 请求都会生成如下的一组数据,多个请求会生成多组数据。
- request同上 request 结构一致
- response同上 response 结构一致
- success是否成功true 或 false
- cost_seconds耗费时长单位秒
注意response 中的 business_code、business_code_msg 为空,因为各个第三方返回结构不同,这两个字段为空。
##### sqls
执行的 SQL 信息,多个 SQL 会记录多组数据。
- timestamp时间格式2006-01-02 15:04:05
- stack文件地址和行号
- cost_seconds执行时长单位
- sqlSQL 语句
- rows_affected影响行数
##### debugs
- key 打印的标示
- value 打印的值
```cassandraql
// 调试时,使用这个方法:
p.Print("key", "value", p.WithTrace(c.Trace()))
```
只有参数中增加了 `p.WithTrace(c.Trace())`,才会记录到 `debugs` 中。
##### success
是否成功true 或 false
```cassandraql
success = !ctx.IsAborted() && ctx.Writer.Status() == http.StatusOK
```
##### cost_seconds
耗费时长:单位秒,比如 0.001105661

7
pkg/trace/debug.go Normal file
View File

@ -0,0 +1,7 @@
package trace
type Debug struct {
Key string `json:"key"` // 标示
Value any `json:"value"` // 值
CostSeconds float64 `json:"cost_seconds"` // 执行时间(单位秒)
}

32
pkg/trace/dialog.go Normal file
View File

@ -0,0 +1,32 @@
package trace
import "sync"
var _ D = (*Dialog)(nil)
type D interface {
i()
AppendResponse(resp *Response)
}
// Dialog 内部调用其它方接口的会话信息失败时会有retry操作所以 response 会有多次。
type Dialog struct {
mux sync.Mutex
Request *Request `json:"request"` // 请求信息
Responses []*Response `json:"responses"` // 返回信息
Success bool `json:"success"` // 是否成功true 或 false
CostSeconds float64 `json:"cost_seconds"` // 执行时长(单位秒)
}
func (d *Dialog) i() {}
// AppendResponse 按转的追加response信息
func (d *Dialog) AppendResponse(resp *Response) {
if resp == nil {
return
}
d.mux.Lock()
d.Responses = append(d.Responses, resp)
d.mux.Unlock()
}

17
pkg/trace/grpc.go Normal file
View File

@ -0,0 +1,17 @@
package trace
import (
"google.golang.org/grpc/metadata"
)
type Grpc struct {
Timestamp string `json:"timestamp"` // 时间格式2006-01-02 15:04:05
Addr string `json:"addr"` // 地址
Method string `json:"method"` // 操作方法
Meta metadata.MD `json:"meta"` // Mate
Request map[string]any `json:"request"` // 请求信息
Response map[string]any `json:"response"` // 返回信息
CostSeconds float64 `json:"cost_seconds"` // 执行时间(单位秒)
Code string `json:"err_code,omitempty"` // 错误码
Message string `json:"err_message,omitempty"` // 错误信息
}

10
pkg/trace/redis.go Normal file
View File

@ -0,0 +1,10 @@
package trace
type Redis struct {
Timestamp string `json:"timestamp"` // 时间格式2006-01-02 15:04:05
Handle string `json:"handle"` // 操作SET/GET 等
Key string `json:"key"` // Key
Value string `json:"value,omitempty"` // Value
TTL float64 `json:"ttl,omitempty"` // 超时时长(单位分)
CostSeconds float64 `json:"cost_seconds"` // 执行时间(单位秒)
}

9
pkg/trace/sql.go Normal file
View File

@ -0,0 +1,9 @@
package trace
type SQL struct {
Timestamp string `json:"timestamp"` // 时间格式2006-01-02 15:04:05
Stack string `json:"stack"` // 文件地址和行号
SQL string `json:"sql"` // SQL 语句
Rows int64 `json:"rows_affected"` // 影响行数
CostSeconds float64 `json:"cost_seconds"` // 执行时长(单位秒)
}

154
pkg/trace/trace.go Normal file
View File

@ -0,0 +1,154 @@
package trace
import (
"crypto/rand"
"encoding/hex"
"sync"
)
const Header = "X-TRACE-ID"
var _ T = (*Trace)(nil)
type T interface {
i()
ID() string
WithRequest(req *Request) *Trace
WithResponse(resp *Response) *Trace
AppendDialog(dialog *Dialog) *Trace
AppendDebug(debug *Debug) *Trace
AppendSQL(sql *SQL) *Trace
AppendRedis(redis *Redis) *Trace
AppendGRPC(grpc *Grpc) *Trace
}
// Trace 记录的参数
type Trace struct {
mux sync.Mutex
Identifier string `json:"trace_id"` // 链路ID
Request *Request `json:"request"` // 请求信息
Response *Response `json:"response"` // 返回信息
ThirdPartyRequests []*Dialog `json:"third_party_requests,omitempty"` // 调用第三方接口的信息
Debugs []*Debug `json:"debugs,omitempty"` // 调试信息
SQLs []*SQL `json:"sqls,omitempty"` // 执行的 SQL 信息
Redis []*Redis `json:"redis,omitempty"` // 执行的 Redis 信息
GRPCs []*Grpc `json:"grpc,omitempty"` // 执行的 gRPC 信息
Success bool `json:"success"` // 请求结果 true or false
CostSeconds float64 `json:"cost_seconds"` // 执行时长(单位秒)
}
// Request 请求信息
type Request struct {
TTL string `json:"ttl,omitempty"` // 请求超时时间
Method string `json:"method"` // 请求方式
DecodedURL string `json:"decoded_url"` // 请求地址
Header any `json:"header"` // 请求 Header 信息
Body any `json:"body"` // 请求 Body 信息
}
// Response 响应信息
type Response struct {
Header any `json:"header"` // Header 信息
Body any `json:"body"` // Body 信息
BusinessCode int `json:"business_code,omitempty"` // 业务码
BusinessCodeMsg string `json:"business_code_msg,omitempty"` // 提示信息
HttpCode int `json:"http_code"` // HTTP 状态码
HttpCodeMsg string `json:"http_code_msg"` // HTTP 状态码信息
CostSeconds float64 `json:"cost_seconds"` // 执行时间(单位秒)
}
func New(id string) *Trace {
if id == "" {
buf := make([]byte, 10)
_, _ = rand.Read(buf)
id = hex.EncodeToString(buf)
}
return &Trace{
Identifier: id,
}
}
func (t *Trace) i() {}
// ID 唯一标识符
func (t *Trace) ID() string {
return t.Identifier
}
// WithRequest 设置request
func (t *Trace) WithRequest(req *Request) *Trace {
t.Request = req
return t
}
// WithResponse 设置response
func (t *Trace) WithResponse(resp *Response) *Trace {
t.Response = resp
return t
}
// AppendDialog 安全的追加内部调用过程dialog
func (t *Trace) AppendDialog(dialog *Dialog) *Trace {
if dialog == nil {
return t
}
t.mux.Lock()
defer t.mux.Unlock()
t.ThirdPartyRequests = append(t.ThirdPartyRequests, dialog)
return t
}
// AppendDebug 追加 debug
func (t *Trace) AppendDebug(debug *Debug) *Trace {
if debug == nil {
return t
}
t.mux.Lock()
defer t.mux.Unlock()
t.Debugs = append(t.Debugs, debug)
return t
}
// AppendSQL 追加 SQL
func (t *Trace) AppendSQL(sql *SQL) *Trace {
if sql == nil {
return t
}
t.mux.Lock()
defer t.mux.Unlock()
t.SQLs = append(t.SQLs, sql)
return t
}
// AppendRedis 追加 Redis
func (t *Trace) AppendRedis(redis *Redis) *Trace {
if redis == nil {
return t
}
t.mux.Lock()
defer t.mux.Unlock()
t.Redis = append(t.Redis, redis)
return t
}
// AppendGRPC 追加 gRPC 调用信息
func (t *Trace) AppendGRPC(grpc *Grpc) *Trace {
if grpc == nil {
return t
}
t.mux.Lock()
defer t.mux.Unlock()
t.GRPCs = append(t.GRPCs, grpc)
return t
}

228
pkg/upload/tus.go Normal file
View File

@ -0,0 +1,228 @@
package upload
import (
"context"
"crypto/sha256"
"errors"
"fmt"
"git.bvbej.com/bvbej/base-golang/pkg/color"
"git.bvbej.com/bvbej/base-golang/pkg/ticker"
"git.bvbej.com/bvbej/base-golang/pkg/token"
"github.com/rs/cors"
"github.com/tus/tusd/pkg/filestore"
tus "github.com/tus/tusd/pkg/handler"
"go.uber.org/zap"
"net/http"
"os"
"strings"
"sync"
"time"
)
var _ Server = (*server)(nil)
type Server interface {
GetUploadToken(string, string, time.Duration) string
GetFileInfo(string) (*tus.FileInfo, error)
Start(func(string, string, tus.FileInfo)) error
Stop() error
}
type server struct {
headerTokenKey string
uploading sync.Map
config Config
token token.Token
store filestore.FileStore
logger *zap.Logger
httpServer *http.Server
ctx context.Context
done context.CancelFunc
checker ticker.Ticker
completedEvent func(sha256, param string, info tus.FileInfo)
}
type Config struct {
ListenAddr string
Path string
Dir string
Secret string
DisableDownload bool
Debug bool
}
func New(conf Config, logger *zap.Logger) Server {
ctx, cancelFunc := context.WithCancel(context.Background())
return &server{
config: conf,
uploading: sync.Map{},
headerTokenKey: "Authorization",
logger: logger,
token: token.New(conf.Secret),
ctx: ctx,
done: cancelFunc,
checker: ticker.New(time.Minute),
}
}
func (s *server) GetUploadToken(sha256, param string, ttl time.Duration) string {
sign, _ := s.token.JwtSign(sha256, param, ttl)
return sign
}
func (s *server) GetFileInfo(id string) (*tus.FileInfo, error) {
upload, err := s.store.GetUpload(context.Background(), id)
if err != nil {
return nil, err
}
info, err := upload.GetInfo(context.Background())
if err != nil {
return nil, err
}
return &info, nil
}
func (s *server) Start(completedEvent func(sha256, param string, info tus.FileInfo)) error {
s.completedEvent = completedEvent
composer := tus.NewStoreComposer()
if err := os.MkdirAll(s.config.Dir, os.ModePerm); err != nil {
return err
}
s.store = filestore.New(s.config.Dir)
s.store.UseIn(composer)
handler, err := tus.NewHandler(tus.Config{
StoreComposer: composer,
BasePath: s.config.Path,
Logger: zap.NewStdLog(s.logger),
NotifyCompleteUploads: true,
NotifyTerminatedUploads: true,
DisableTermination: true,
DisableDownload: s.config.DisableDownload,
RespectForwardedHeaders: strings.Contains(s.config.ListenAddr, "127.0.0.1"),
PreUploadCreateCallback: func(hook tus.HookEvent) error {
authStr := hook.HTTPRequest.Header.Get(s.headerTokenKey)
jwtClaims, err := s.token.JwtParse(authStr)
if err == nil {
_, ok := s.uploading.Load(authStr)
if !ok {
s.uploading.Store(authStr, jwtClaims.ExpiresAt.Time)
return nil
}
return errors.New("repeated")
}
return errors.New("unauthorized")
},
PreFinishResponseCallback: func(hook tus.HookEvent) error {
authStr := hook.HTTPRequest.Header.Get(s.headerTokenKey)
jwtParse, err := s.token.JwtParse(authStr)
if err != nil {
return errors.New("token expired")
}
_, ok := s.uploading.Load(authStr)
if ok {
s.uploading.Delete(authStr)
}
upload, err := s.store.GetUpload(context.Background(), hook.Upload.ID)
if err != nil {
return err
}
info, err := upload.GetInfo(context.Background())
path, exist := info.Storage["Path"]
if err != nil || !exist {
return errors.New("file not found")
}
content, err := os.ReadFile(path)
if err != nil {
return err
}
hash := sha256.New()
hash.Write(content)
sha256Byte := hash.Sum(nil)
sha256String := fmt.Sprintf("%x", sha256Byte)
if !s.config.Debug && sha256String != strings.ToLower(jwtParse.ID) {
_ = os.Remove(path)
_ = os.Remove(path + ".info")
return errors.New("file check error")
}
return nil
},
})
if err != nil {
return err
}
go func() {
for {
select {
case event := <-handler.CompleteUploads:
authStr := event.HTTPRequest.Header.Get(s.headerTokenKey)
jwtParse, _ := s.token.JwtParse(authStr)
if s.completedEvent != nil {
go func() {
s.completedEvent(jwtParse.ID, jwtParse.Subject, event.Upload)
}()
}
case <-s.ctx.Done():
return
}
}
}()
go func() {
for {
select {
case event := <-handler.TerminatedUploads:
upload, _ := s.store.GetUpload(context.Background(), event.Upload.ID)
if upload != nil {
info, _ := upload.GetInfo(context.Background())
path, exist := info.Storage["Path"]
if exist {
_ = os.Remove(path)
_ = os.Remove(path + ".info")
}
}
case <-s.ctx.Done():
return
}
}
}()
s.checker.Process(func() {
s.uploading.Range(func(key, value any) bool {
t := value.(time.Time)
if t.Before(time.Now()) {
s.uploading.Delete(key)
}
return true
})
})
//监听服务
addr := s.config.ListenAddr
mux := http.NewServeMux()
mux.Handle(s.config.Path, http.StripPrefix(s.config.Path, handler))
s.httpServer = &http.Server{
Addr: addr,
Handler: cors.AllowAll().Handler(mux),
}
go func() {
if err = s.httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
s.logger.Sugar().Fatal("upload server startup err", zap.Error(err))
}
}()
fmt.Println(color.Green(fmt.Sprintf("* [register tusd listen %s]", addr)))
return nil
}
func (s *server) Stop() error {
s.done()
return s.httpServer.Close()
}

169
pkg/urltable/urltable.go Normal file
View File

@ -0,0 +1,169 @@
package urltable
import (
"errors"
"fmt"
"net/http"
"strings"
)
const (
empty = ""
fuzzy = "*"
omitted = "**"
delimiter = "/"
methodView = "VIEW"
)
// parse and validate pattern
func parse(pattern string) ([]string, error) {
const format = "[get, post, put, patch, delete, view]/{a-Z}+/{*}+/{**}"
if pattern = strings.TrimLeft(strings.TrimSpace(pattern), delimiter); pattern == "" {
return nil, fmt.Errorf("pattern illegal, should in format of %s", format)
}
paths := strings.Split(pattern, delimiter)
if len(paths) < 2 {
return nil, fmt.Errorf("pattern illegal, should in format of %s", format)
}
for i := range paths {
paths[i] = strings.TrimSpace(paths[i])
}
// likes get/ get/* get/**
if len(paths) == 2 && (paths[1] == empty || paths[1] == fuzzy || paths[1] == omitted) {
return nil, errors.New("illegal wildcard")
}
switch paths[0] = strings.ToUpper(paths[0]); paths[0] {
case http.MethodGet,
http.MethodPost,
http.MethodPut,
http.MethodPatch,
http.MethodDelete,
methodView:
default:
return nil, fmt.Errorf("only supports [%s %s %s %s %s %s]",
http.MethodGet, http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete, methodView)
}
for k := 1; k < len(paths); k++ {
if paths[k] == empty && k+1 != len(paths) {
return nil, errors.New("pattern contains illegal empty path")
}
if paths[k] == omitted && k+1 != len(paths) {
return nil, errors.New("pattern contains illegal omitted path")
}
}
return paths, nil
}
// Format pattern
func Format(pattern string) (string, error) {
paths, err := parse(pattern)
if err != nil {
return "", err
}
return strings.Join(paths, delimiter), nil
}
type section struct {
leaf bool
mapping map[string]*section
}
func newSection() *section {
return &section{mapping: make(map[string]*section)}
}
// Table a (thread unsafe) table to store urls
type Table struct {
size int
root *section
}
// NewTable create a table
func NewTable() *Table {
return &Table{root: newSection()}
}
// Size contains how many urls
func (t *Table) Size() int {
return t.size
}
// Append pattern
func (t *Table) Append(pattern string) error {
paths, err := parse(pattern)
if err != nil {
return err
}
insert := false
root := t.root
for i, path := range paths {
if (path == fuzzy && root.mapping[omitted] != nil) ||
(path == omitted && root.mapping[fuzzy] != nil) ||
(path != omitted && root.mapping[omitted] != nil) {
return fmt.Errorf("conflict at %s", strings.Join(paths[:i], delimiter))
}
next := root.mapping[path]
if next == nil {
next = newSection()
root.mapping[path] = next
insert = true
}
root = next
}
if insert {
t.size++
}
root.leaf = true
return nil
}
// Mapping url to pattern
func (t *Table) Mapping(url string) (string, error) {
paths, err := parse(url)
if err != nil {
return "", err
}
pattern := make([]string, 0, len(paths))
root := t.root
for _, path := range paths {
next := root.mapping[path]
if next == nil {
nextFuzzy := root.mapping[fuzzy]
nextOmitted := root.mapping[omitted]
if nextFuzzy == nil && nextOmitted == nil {
return "", nil
}
if nextOmitted != nil {
pattern = append(pattern, omitted)
return strings.Join(pattern, delimiter), nil
}
next = nextFuzzy
path = fuzzy
}
root = next
pattern = append(pattern, path)
}
if root.leaf {
return strings.Join(pattern, delimiter), nil
}
return "", nil
}

Some files were not shown because too many files have changed in this diff Show More