first commit
This commit is contained in:
		
							
								
								
									
										5
									
								
								.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@@ -0,0 +1,5 @@
 | 
			
		||||
/.idea
 | 
			
		||||
/vendor
 | 
			
		||||
 | 
			
		||||
*.yaml.json
 | 
			
		||||
*_test.go
 | 
			
		||||
							
								
								
									
										4
									
								
								README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										4
									
								
								README.md
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,4 @@
 | 
			
		||||
## Golang公共包
 | 
			
		||||
 | 
			
		||||
* pb协议文件生成命令
 | 
			
		||||
`protoc --proto_path=pkg/websocket/codec/protobuf/protocol --go_out=pkg/websocket/codec/protobuf/protocol --go_opt=paths=source_relative base.proto`
 | 
			
		||||
							
								
								
									
										119
									
								
								go.mod
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										119
									
								
								go.mod
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,119 @@
 | 
			
		||||
module git.bvbej.com/bvbej/base-golang
 | 
			
		||||
 | 
			
		||||
go 1.22.4
 | 
			
		||||
 | 
			
		||||
require (
 | 
			
		||||
	github.com/apolloconfig/agollo/v4 v4.4.0
 | 
			
		||||
	github.com/gin-contrib/pprof v1.5.0
 | 
			
		||||
	github.com/gin-gonic/gin v1.10.0
 | 
			
		||||
	github.com/go-playground/validator/v10 v10.22.0
 | 
			
		||||
	github.com/golang-jwt/jwt/v4 v4.5.0
 | 
			
		||||
	github.com/google/uuid v1.6.0
 | 
			
		||||
	github.com/gorilla/websocket v1.5.3
 | 
			
		||||
	github.com/jinzhu/now v1.1.5
 | 
			
		||||
	github.com/json-iterator/go v1.1.12
 | 
			
		||||
	github.com/mojocn/base64Captcha v1.3.6
 | 
			
		||||
	github.com/mritd/chinaid v1.0.4
 | 
			
		||||
	github.com/panjf2000/ants/v2 v2.10.0
 | 
			
		||||
	github.com/prometheus/client_golang v1.19.1
 | 
			
		||||
	github.com/qiniu/go-sdk/v7 v7.21.1
 | 
			
		||||
	github.com/redis/go-redis/v9 v9.5.3
 | 
			
		||||
	github.com/robfig/cron/v3 v3.0.1
 | 
			
		||||
	github.com/rs/cors v1.11.0
 | 
			
		||||
	github.com/rs/cors/wrapper/gin v0.0.0-20240515105523-1562b1715b35
 | 
			
		||||
	github.com/speps/go-hashids v2.0.0+incompatible
 | 
			
		||||
	github.com/spf13/cast v1.6.0
 | 
			
		||||
	github.com/spf13/viper v1.19.0
 | 
			
		||||
	github.com/stretchr/testify v1.9.0
 | 
			
		||||
	github.com/tidwall/buntdb v1.3.1
 | 
			
		||||
	github.com/tidwall/gjson v1.17.1
 | 
			
		||||
	github.com/tus/tusd v1.13.0
 | 
			
		||||
	github.com/xuri/excelize/v2 v2.8.1
 | 
			
		||||
	go.mongodb.org/mongo-driver v1.15.1
 | 
			
		||||
	go.uber.org/atomic v1.11.0
 | 
			
		||||
	go.uber.org/multierr v1.11.0
 | 
			
		||||
	go.uber.org/zap v1.27.0
 | 
			
		||||
	golang.org/x/crypto v0.24.0
 | 
			
		||||
	golang.org/x/net v0.26.0
 | 
			
		||||
	golang.org/x/sync v0.7.0
 | 
			
		||||
	golang.org/x/time v0.5.0
 | 
			
		||||
	google.golang.org/grpc v1.64.0
 | 
			
		||||
	google.golang.org/protobuf v1.34.2
 | 
			
		||||
	gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df
 | 
			
		||||
	gopkg.in/natefinch/lumberjack.v2 v2.2.1
 | 
			
		||||
	gorm.io/driver/mysql v1.5.7
 | 
			
		||||
	gorm.io/gorm v1.25.10
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
require (
 | 
			
		||||
	github.com/BurntSushi/toml v1.3.2 // indirect
 | 
			
		||||
	github.com/alex-ant/gomath v0.0.0-20160516115720-89013a210a82 // indirect
 | 
			
		||||
	github.com/beorn7/perks v1.0.1 // indirect
 | 
			
		||||
	github.com/bmizerany/pat v0.0.0-20170815010413-6226ea591a40 // indirect
 | 
			
		||||
	github.com/bytedance/sonic v1.11.6 // indirect
 | 
			
		||||
	github.com/bytedance/sonic/loader v0.1.1 // indirect
 | 
			
		||||
	github.com/cespare/xxhash/v2 v2.2.0 // indirect
 | 
			
		||||
	github.com/cloudwego/base64x v0.1.4 // indirect
 | 
			
		||||
	github.com/cloudwego/iasm v0.2.0 // indirect
 | 
			
		||||
	github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
 | 
			
		||||
	github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
 | 
			
		||||
	github.com/fsnotify/fsnotify v1.7.0 // indirect
 | 
			
		||||
	github.com/gabriel-vasile/mimetype v1.4.3 // indirect
 | 
			
		||||
	github.com/gin-contrib/sse v0.1.0 // indirect
 | 
			
		||||
	github.com/go-playground/locales v0.14.1 // indirect
 | 
			
		||||
	github.com/go-playground/universal-translator v0.18.1 // indirect
 | 
			
		||||
	github.com/go-sql-driver/mysql v1.7.0 // indirect
 | 
			
		||||
	github.com/goccy/go-json v0.10.2 // indirect
 | 
			
		||||
	github.com/gofrs/flock v0.8.1 // indirect
 | 
			
		||||
	github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
 | 
			
		||||
	github.com/golang/snappy v0.0.4 // indirect
 | 
			
		||||
	github.com/hashicorp/hcl v1.0.0 // indirect
 | 
			
		||||
	github.com/jinzhu/inflection v1.0.0 // indirect
 | 
			
		||||
	github.com/klauspost/compress v1.17.2 // indirect
 | 
			
		||||
	github.com/klauspost/cpuid/v2 v2.2.7 // indirect
 | 
			
		||||
	github.com/leodido/go-urn v1.4.0 // indirect
 | 
			
		||||
	github.com/magiconair/properties v1.8.7 // indirect
 | 
			
		||||
	github.com/matishsiao/goInfo v0.0.0-20210923090445-da2e3fa8d45f // indirect
 | 
			
		||||
	github.com/mattn/go-isatty v0.0.20 // indirect
 | 
			
		||||
	github.com/mitchellh/mapstructure v1.5.0 // indirect
 | 
			
		||||
	github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
 | 
			
		||||
	github.com/modern-go/reflect2 v1.0.2 // indirect
 | 
			
		||||
	github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect
 | 
			
		||||
	github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe // indirect
 | 
			
		||||
	github.com/pelletier/go-toml/v2 v2.2.2 // indirect
 | 
			
		||||
	github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
 | 
			
		||||
	github.com/prometheus/client_model v0.5.0 // indirect
 | 
			
		||||
	github.com/prometheus/common v0.48.0 // indirect
 | 
			
		||||
	github.com/prometheus/procfs v0.12.0 // indirect
 | 
			
		||||
	github.com/richardlehane/mscfb v1.0.4 // indirect
 | 
			
		||||
	github.com/richardlehane/msoleps v1.0.3 // indirect
 | 
			
		||||
	github.com/sagikazarmark/locafero v0.4.0 // indirect
 | 
			
		||||
	github.com/sagikazarmark/slog-shim v0.1.0 // indirect
 | 
			
		||||
	github.com/sourcegraph/conc v0.3.0 // indirect
 | 
			
		||||
	github.com/spf13/afero v1.11.0 // indirect
 | 
			
		||||
	github.com/spf13/pflag v1.0.5 // indirect
 | 
			
		||||
	github.com/subosito/gotenv v1.6.0 // indirect
 | 
			
		||||
	github.com/tidwall/btree v1.4.2 // indirect
 | 
			
		||||
	github.com/tidwall/grect v0.1.4 // indirect
 | 
			
		||||
	github.com/tidwall/match v1.1.1 // indirect
 | 
			
		||||
	github.com/tidwall/pretty v1.2.0 // indirect
 | 
			
		||||
	github.com/tidwall/rtred v0.1.2 // indirect
 | 
			
		||||
	github.com/tidwall/tinyqueue v0.1.1 // indirect
 | 
			
		||||
	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
 | 
			
		||||
	github.com/ugorji/go/codec v1.2.12 // indirect
 | 
			
		||||
	github.com/xdg-go/pbkdf2 v1.0.0 // indirect
 | 
			
		||||
	github.com/xdg-go/scram v1.1.2 // indirect
 | 
			
		||||
	github.com/xdg-go/stringprep v1.0.4 // indirect
 | 
			
		||||
	github.com/xuri/efp v0.0.0-20231025114914-d1ff6096ae53 // indirect
 | 
			
		||||
	github.com/xuri/nfp v0.0.0-20230919160717-d98342af3f05 // indirect
 | 
			
		||||
	github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d // indirect
 | 
			
		||||
	golang.org/x/arch v0.8.0 // indirect
 | 
			
		||||
	golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect
 | 
			
		||||
	golang.org/x/image v0.14.0 // indirect
 | 
			
		||||
	golang.org/x/sys v0.21.0 // indirect
 | 
			
		||||
	golang.org/x/text v0.16.0 // indirect
 | 
			
		||||
	google.golang.org/genproto/googleapis/rpc v0.0.0-20240318140521-94a12d6c2237 // indirect
 | 
			
		||||
	gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect
 | 
			
		||||
	gopkg.in/ini.v1 v1.67.0 // indirect
 | 
			
		||||
	gopkg.in/yaml.v3 v3.0.1 // indirect
 | 
			
		||||
)
 | 
			
		||||
							
								
								
									
										120
									
								
								pkg/aes/aes.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										120
									
								
								pkg/aes/aes.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,120 @@
 | 
			
		||||
package aes
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	cryptoAes "crypto/aes"
 | 
			
		||||
	"crypto/cipher"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var _ Aes = (*aes)(nil)
 | 
			
		||||
 | 
			
		||||
type Aes interface {
 | 
			
		||||
	i()
 | 
			
		||||
 | 
			
		||||
	EncryptCBC(encryptStr string, urlEncode bool) (string, error)
 | 
			
		||||
	DecryptCBC(decryptStr string, urlEncode bool) (string, error)
 | 
			
		||||
 | 
			
		||||
	EncryptCFB(plain string, urlEncode bool) (string, error)
 | 
			
		||||
	DecryptCFB(encrypted string, urlEncode bool) (string, error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type aes struct {
 | 
			
		||||
	key string
 | 
			
		||||
	iv  string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func New(key, iv string) Aes {
 | 
			
		||||
	return &aes{
 | 
			
		||||
		key: key,
 | 
			
		||||
		iv:  iv,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *aes) i() {}
 | 
			
		||||
 | 
			
		||||
func (a *aes) EncryptCBC(encryptStr string, urlEncode bool) (string, error) {
 | 
			
		||||
	encoder := base64.StdEncoding
 | 
			
		||||
	if urlEncode {
 | 
			
		||||
		encoder = base64.URLEncoding
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	encryptBytes := []byte(encryptStr)
 | 
			
		||||
	block, err := cryptoAes.NewCipher([]byte(a.key))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	blockSize := block.BlockSize()
 | 
			
		||||
	encryptBytes = pkcsPadding(encryptBytes, blockSize)
 | 
			
		||||
	blockMode := cipher.NewCBCEncrypter(block, []byte(a.iv))
 | 
			
		||||
	encrypted := make([]byte, len(encryptBytes))
 | 
			
		||||
	blockMode.CryptBlocks(encrypted, encryptBytes)
 | 
			
		||||
 | 
			
		||||
	return encoder.EncodeToString(encrypted), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *aes) DecryptCBC(decryptStr string, urlEncode bool) (string, error) {
 | 
			
		||||
	encoder := base64.StdEncoding
 | 
			
		||||
	if urlEncode {
 | 
			
		||||
		encoder = base64.URLEncoding
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	decryptBytes, err := encoder.DecodeString(decryptStr)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	block, err := cryptoAes.NewCipher([]byte(a.key))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	blockMode := cipher.NewCBCDecrypter(block, []byte(a.iv))
 | 
			
		||||
	decrypted := make([]byte, len(decryptBytes))
 | 
			
		||||
	blockMode.CryptBlocks(decrypted, decryptBytes)
 | 
			
		||||
	decrypted = pkcsUnPadding(decrypted)
 | 
			
		||||
 | 
			
		||||
	return string(decrypted), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *aes) EncryptCFB(plain string, urlEncode bool) (string, error) {
 | 
			
		||||
	encoder := base64.StdEncoding
 | 
			
		||||
	if urlEncode {
 | 
			
		||||
		encoder = base64.URLEncoding
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	block, err := cryptoAes.NewCipher([]byte(a.key))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	encrypted := make([]byte, len(plain))
 | 
			
		||||
	stream := cipher.NewCFBEncrypter(block, []byte(a.iv))
 | 
			
		||||
	stream.XORKeyStream(encrypted, []byte(plain))
 | 
			
		||||
 | 
			
		||||
	return encoder.EncodeToString(encrypted), nil
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *aes) DecryptCFB(encrypted string, urlEncode bool) (string, error) {
 | 
			
		||||
	encoder := base64.StdEncoding
 | 
			
		||||
	if urlEncode {
 | 
			
		||||
		encoder = base64.URLEncoding
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	decryptBytes, err := encoder.DecodeString(encrypted)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	block, err := cryptoAes.NewCipher([]byte(a.key))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	plain := make([]byte, len(decryptBytes))
 | 
			
		||||
	stream := cipher.NewCFBDecrypter(block, []byte(a.iv))
 | 
			
		||||
	stream.XORKeyStream(plain, decryptBytes)
 | 
			
		||||
 | 
			
		||||
	return string(plain), nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										25
									
								
								pkg/aes/padding.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										25
									
								
								pkg/aes/padding.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,25 @@
 | 
			
		||||
package aes
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * 填充有六种:NoPadding, PKCS#5, PKCS#7, ISO 10126, ANSI X9.23和ZerosPadding
 | 
			
		||||
 * 对于AES来说PKCS5Padding和PKCS7Padding是完全一样的,不同在于PKCS5限定了块大小为8bytes而PKCS7没有限定
 | 
			
		||||
 * 因此对于AES来说两者完全相同,但是对于Rijndael就不一样了
 | 
			
		||||
 * AES是Rijndael在块大小为8bytes时的特例,对于使用其他信息块大小的Rijndael算法只能使用PKCS7
 | 
			
		||||
 * 在AES加密当中严格来说是不能使用pkcs5的,因为AES的块大小是16bytes而pkcs5只能用于8bytes,通常我们在AES加密中所说的pkcs5指的就是pkcs7
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
func pkcsPadding(cipherText []byte, blockSize int) []byte {
 | 
			
		||||
	p := blockSize - len(cipherText)%blockSize
 | 
			
		||||
	padText := bytes.Repeat([]byte{byte(p)}, p)
 | 
			
		||||
	return append(cipherText, padText...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func pkcsUnPadding(decrypted []byte) []byte {
 | 
			
		||||
	length := len(decrypted)
 | 
			
		||||
	u := int(decrypted[length-1])
 | 
			
		||||
	return decrypted[:(length - u)]
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										196
									
								
								pkg/android_binary/apk/apk.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										196
									
								
								pkg/android_binary/apk/apk.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,196 @@
 | 
			
		||||
package apk
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"archive/zip"
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"git.bvbej.com/bvbej/base-golang/pkg/android_binary"
 | 
			
		||||
	"image"
 | 
			
		||||
	"io"
 | 
			
		||||
	"os"
 | 
			
		||||
	"strconv"
 | 
			
		||||
 | 
			
		||||
	_ "image/jpeg" // handle jpeg format
 | 
			
		||||
	_ "image/png"  // handle png format
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Apk is an application package file for android.
 | 
			
		||||
type Apk struct {
 | 
			
		||||
	f         *os.File
 | 
			
		||||
	zipreader *zip.Reader
 | 
			
		||||
	manifest  Manifest
 | 
			
		||||
	table     *android_binary.TableFile
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// OpenFile will open the file specified by filename and return Apk
 | 
			
		||||
func OpenFile(filename string) (apk *Apk, err error) {
 | 
			
		||||
	f, err := os.Open(filename)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			f.Close()
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
	fi, err := f.Stat()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	apk, err = OpenZipReader(f, fi.Size())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	apk.f = f
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// OpenZipReader has same arguments like zip.NewReader
 | 
			
		||||
func OpenZipReader(r io.ReaderAt, size int64) (*Apk, error) {
 | 
			
		||||
	zipreader, err := zip.NewReader(r, size)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	apk := &Apk{
 | 
			
		||||
		zipreader: zipreader,
 | 
			
		||||
	}
 | 
			
		||||
	if err = apk.parseResources(); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	if err = apk.parseManifest(); err != nil {
 | 
			
		||||
		return nil, errorf("parse-manifest: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	return apk, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Close is avaliable only if apk is created with OpenFile
 | 
			
		||||
func (k *Apk) Close() error {
 | 
			
		||||
	if k.f == nil {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	return k.f.Close()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Icon returns the icon image of the APK.
 | 
			
		||||
func (k *Apk) Icon(resConfig *android_binary.ResTableConfig) (image.Image, error) {
 | 
			
		||||
	iconPath, err := k.manifest.App.Icon.WithResTableConfig(resConfig).String()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	if android_binary.IsResID(iconPath) {
 | 
			
		||||
		return nil, newError("unable to convert icon-id to icon path")
 | 
			
		||||
	}
 | 
			
		||||
	imgData, err := k.readZipFile(iconPath)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	m, _, err := image.Decode(bytes.NewReader(imgData))
 | 
			
		||||
	return m, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Label returns the label of the APK.
 | 
			
		||||
func (k *Apk) Label(resConfig *android_binary.ResTableConfig) (s string, err error) {
 | 
			
		||||
	s, err = k.manifest.App.Label.WithResTableConfig(resConfig).String()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if android_binary.IsResID(s) {
 | 
			
		||||
		err = newError("unable to convert label-id to string")
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Manifest returns the manifest of the APK.
 | 
			
		||||
func (k *Apk) Manifest() Manifest {
 | 
			
		||||
	return k.manifest
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// PackageName returns the package name of the APK.
 | 
			
		||||
func (k *Apk) PackageName() string {
 | 
			
		||||
	return k.manifest.Package.MustString()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func isMainIntentFilter(intent ActivityIntentFilter) bool {
 | 
			
		||||
	ok := false
 | 
			
		||||
	for _, action := range intent.Actions {
 | 
			
		||||
		s, err := action.Name.String()
 | 
			
		||||
		if err == nil && s == "android.intent.action.MAIN" {
 | 
			
		||||
			ok = true
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	ok = false
 | 
			
		||||
	for _, category := range intent.Categories {
 | 
			
		||||
		s, err := category.Name.String()
 | 
			
		||||
		if err == nil && s == "android.intent.category.LAUNCHER" {
 | 
			
		||||
			ok = true
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return ok
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// MainActivity returns the name of the main activity.
 | 
			
		||||
func (k *Apk) MainActivity() (activity string, err error) {
 | 
			
		||||
	for _, act := range k.manifest.App.Activities {
 | 
			
		||||
		for _, intent := range act.IntentFilters {
 | 
			
		||||
			if isMainIntentFilter(intent) {
 | 
			
		||||
				return act.Name.String()
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	for _, act := range k.manifest.App.ActivityAliases {
 | 
			
		||||
		for _, intent := range act.IntentFilters {
 | 
			
		||||
			if isMainIntentFilter(intent) {
 | 
			
		||||
				return act.TargetActivity.String()
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return "", newError("No main activity found")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (k *Apk) parseManifest() error {
 | 
			
		||||
	xmlData, err := k.readZipFile("AndroidManifest.xml")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorf("failed to read AndroidManifest.xml: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	xmlfile, err := android_binary.NewXMLFile(bytes.NewReader(xmlData))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorf("failed to parse AndroidManifest.xml: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	return xmlfile.Decode(&k.manifest, k.table, nil)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (k *Apk) parseResources() (err error) {
 | 
			
		||||
	resData, err := k.readZipFile("resources.arsc")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	k.table, err = android_binary.NewTableFile(bytes.NewReader(resData))
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (k *Apk) readZipFile(name string) (data []byte, err error) {
 | 
			
		||||
	buf := bytes.NewBuffer(nil)
 | 
			
		||||
	for _, file := range k.zipreader.File {
 | 
			
		||||
		if file.Name != name {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		rc, er := file.Open()
 | 
			
		||||
		if er != nil {
 | 
			
		||||
			err = er
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		defer rc.Close()
 | 
			
		||||
		_, err = io.Copy(buf, rc)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		return buf.Bytes(), nil
 | 
			
		||||
	}
 | 
			
		||||
	return nil, fmt.Errorf("File %s not found", strconv.Quote(name))
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										108
									
								
								pkg/android_binary/apk/apkxml.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										108
									
								
								pkg/android_binary/apk/apkxml.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,108 @@
 | 
			
		||||
package apk
 | 
			
		||||
 | 
			
		||||
import "git.bvbej.com/bvbej/base-golang/pkg/android_binary"
 | 
			
		||||
 | 
			
		||||
// Instrumentation is an application instrumentation code.
 | 
			
		||||
type Instrumentation struct {
 | 
			
		||||
	Name            android_binary.String `xml:"http://schemas.android.com/apk/res/android name,attr"`
 | 
			
		||||
	Target          android_binary.String `xml:"http://schemas.android.com/apk/res/android targetPackage,attr"`
 | 
			
		||||
	HandleProfiling android_binary.Bool   `xml:"http://schemas.android.com/apk/res/android handleProfiling,attr"`
 | 
			
		||||
	FunctionalTest  android_binary.Bool   `xml:"http://schemas.android.com/apk/res/android functionalTest,attr"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ActivityAction is an action of an activity.
 | 
			
		||||
type ActivityAction struct {
 | 
			
		||||
	Name android_binary.String `xml:"http://schemas.android.com/apk/res/android name,attr"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ActivityCategory is a category of an activity.
 | 
			
		||||
type ActivityCategory struct {
 | 
			
		||||
	Name android_binary.String `xml:"http://schemas.android.com/apk/res/android name,attr"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ActivityIntentFilter is an androidbinary.Int32ent filter of an activity.
 | 
			
		||||
type ActivityIntentFilter struct {
 | 
			
		||||
	Actions    []ActivityAction   `xml:"action"`
 | 
			
		||||
	Categories []ActivityCategory `xml:"category"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// AppActivity is an activity in an application.
 | 
			
		||||
type AppActivity struct {
 | 
			
		||||
	Theme             android_binary.String  `xml:"http://schemas.android.com/apk/res/android theme,attr"`
 | 
			
		||||
	Name              android_binary.String  `xml:"http://schemas.android.com/apk/res/android name,attr"`
 | 
			
		||||
	Label             android_binary.String  `xml:"http://schemas.android.com/apk/res/android label,attr"`
 | 
			
		||||
	ScreenOrientation android_binary.String  `xml:"http://schemas.android.com/apk/res/android screenOrientation,attr"`
 | 
			
		||||
	IntentFilters     []ActivityIntentFilter `xml:"intent-filter"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// AppActivityAlias https://developer.android.com/guide/topics/manifest/activity-alias-element
 | 
			
		||||
type AppActivityAlias struct {
 | 
			
		||||
	Name           android_binary.String  `xml:"http://schemas.android.com/apk/res/android name,attr"`
 | 
			
		||||
	Label          android_binary.String  `xml:"http://schemas.android.com/apk/res/android label,attr"`
 | 
			
		||||
	TargetActivity android_binary.String  `xml:"http://schemas.android.com/apk/res/android targetActivity,attr"`
 | 
			
		||||
	IntentFilters  []ActivityIntentFilter `xml:"intent-filter"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// MetaData is a metadata in an application.
 | 
			
		||||
type MetaData struct {
 | 
			
		||||
	Name  android_binary.String `xml:"http://schemas.android.com/apk/res/android name,attr"`
 | 
			
		||||
	Value android_binary.String `xml:"http://schemas.android.com/apk/res/android value,attr"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Application is an application in an APK.
 | 
			
		||||
type Application struct {
 | 
			
		||||
	AllowTaskReparenting  android_binary.Bool   `xml:"http://schemas.android.com/apk/res/android allowTaskReparenting,attr"`
 | 
			
		||||
	AllowBackup           android_binary.Bool   `xml:"http://schemas.android.com/apk/res/android allowBackup,attr"`
 | 
			
		||||
	BackupAgent           android_binary.String `xml:"http://schemas.android.com/apk/res/android backupAgent,attr"`
 | 
			
		||||
	Debuggable            android_binary.Bool   `xml:"http://schemas.android.com/apk/res/android debuggable,attr"`
 | 
			
		||||
	Description           android_binary.String `xml:"http://schemas.android.com/apk/res/android description,attr"`
 | 
			
		||||
	Enabled               android_binary.Bool   `xml:"http://schemas.android.com/apk/res/android enabled,attr"`
 | 
			
		||||
	HasCode               android_binary.Bool   `xml:"http://schemas.android.com/apk/res/android hasCode,attr"`
 | 
			
		||||
	HardwareAccelerated   android_binary.Bool   `xml:"http://schemas.android.com/apk/res/android hardwareAccelerated,attr"`
 | 
			
		||||
	Icon                  android_binary.String `xml:"http://schemas.android.com/apk/res/android icon,attr"`
 | 
			
		||||
	KillAfterRestore      android_binary.Bool   `xml:"http://schemas.android.com/apk/res/android killAfterRestore,attr"`
 | 
			
		||||
	LargeHeap             android_binary.Bool   `xml:"http://schemas.android.com/apk/res/android largeHeap,attr"`
 | 
			
		||||
	Label                 android_binary.String `xml:"http://schemas.android.com/apk/res/android label,attr"`
 | 
			
		||||
	Logo                  android_binary.String `xml:"http://schemas.android.com/apk/res/android logo,attr"`
 | 
			
		||||
	ManageSpaceActivity   android_binary.String `xml:"http://schemas.android.com/apk/res/android manageSpaceActivity,attr"`
 | 
			
		||||
	Name                  android_binary.String `xml:"http://schemas.android.com/apk/res/android name,attr"`
 | 
			
		||||
	Permission            android_binary.String `xml:"http://schemas.android.com/apk/res/android permission,attr"`
 | 
			
		||||
	Persistent            android_binary.Bool   `xml:"http://schemas.android.com/apk/res/android persistent,attr"`
 | 
			
		||||
	Process               android_binary.String `xml:"http://schemas.android.com/apk/res/android process,attr"`
 | 
			
		||||
	RestoreAnyVersion     android_binary.Bool   `xml:"http://schemas.android.com/apk/res/android restoreAnyVersion,attr"`
 | 
			
		||||
	RequiredAccountType   android_binary.String `xml:"http://schemas.android.com/apk/res/android requiredAccountType,attr"`
 | 
			
		||||
	RestrictedAccountType android_binary.String `xml:"http://schemas.android.com/apk/res/android restrictedAccountType,attr"`
 | 
			
		||||
	SupportsRtl           android_binary.Bool   `xml:"http://schemas.android.com/apk/res/android supportsRtl,attr"`
 | 
			
		||||
	TaskAffinity          android_binary.String `xml:"http://schemas.android.com/apk/res/android taskAffinity,attr"`
 | 
			
		||||
	TestOnly              android_binary.Bool   `xml:"http://schemas.android.com/apk/res/android testOnly,attr"`
 | 
			
		||||
	Theme                 android_binary.String `xml:"http://schemas.android.com/apk/res/android theme,attr"`
 | 
			
		||||
	UIOptions             android_binary.String `xml:"http://schemas.android.com/apk/res/android uiOptions,attr"`
 | 
			
		||||
	VMSafeMode            android_binary.Bool   `xml:"http://schemas.android.com/apk/res/android vmSafeMode,attr"`
 | 
			
		||||
	Activities            []AppActivity         `xml:"activity"`
 | 
			
		||||
	ActivityAliases       []AppActivityAlias    `xml:"activity-alias"`
 | 
			
		||||
	MetaData              []MetaData            `xml:"meta-data"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UsesSDK is target SDK version.
 | 
			
		||||
type UsesSDK struct {
 | 
			
		||||
	Min    android_binary.Int32 `xml:"http://schemas.android.com/apk/res/android minSdkVersion,attr"`
 | 
			
		||||
	Target android_binary.Int32 `xml:"http://schemas.android.com/apk/res/android targetSdkVersion,attr"`
 | 
			
		||||
	Max    android_binary.Int32 `xml:"http://schemas.android.com/apk/res/android maxSdkVersion,attr"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UsesPermission is user grant the system permission.
 | 
			
		||||
type UsesPermission struct {
 | 
			
		||||
	Name android_binary.String `xml:"http://schemas.android.com/apk/res/android name,attr"`
 | 
			
		||||
	Max  android_binary.Int32  `xml:"http://schemas.android.com/apk/res/android maxSdkVersion,attr"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Manifest is a manifest of an APK.
 | 
			
		||||
type Manifest struct {
 | 
			
		||||
	Package         android_binary.String `xml:"package,attr"`
 | 
			
		||||
	VersionCode     android_binary.Int32  `xml:"http://schemas.android.com/apk/res/android versionCode,attr"`
 | 
			
		||||
	VersionName     android_binary.String `xml:"http://schemas.android.com/apk/res/android versionName,attr"`
 | 
			
		||||
	App             Application           `xml:"application"`
 | 
			
		||||
	Instrument      Instrumentation       `xml:"instrumentation"`
 | 
			
		||||
	SDK             UsesSDK               `xml:"uses-sdk"`
 | 
			
		||||
	UsesPermissions []UsesPermission      `xml:"uses-permission"`
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										9
									
								
								pkg/android_binary/apk/errors.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								pkg/android_binary/apk/errors.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,9 @@
 | 
			
		||||
package apk
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var newError = errors.New
 | 
			
		||||
var errorf = fmt.Errorf
 | 
			
		||||
							
								
								
									
										258
									
								
								pkg/android_binary/common.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										258
									
								
								pkg/android_binary/common.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,258 @@
 | 
			
		||||
package android_binary
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/binary"
 | 
			
		||||
	"io"
 | 
			
		||||
	"unicode/utf16"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// ChunkType is a type of a resource chunk.
 | 
			
		||||
type ChunkType uint16
 | 
			
		||||
 | 
			
		||||
// Chunk types.
 | 
			
		||||
const (
 | 
			
		||||
	ResNullChunkType       ChunkType = 0x0000
 | 
			
		||||
	ResStringPoolChunkType ChunkType = 0x0001
 | 
			
		||||
	ResTableChunkType      ChunkType = 0x0002
 | 
			
		||||
	ResXMLChunkType        ChunkType = 0x0003
 | 
			
		||||
 | 
			
		||||
	// Chunk types in RES_XML_TYPE
 | 
			
		||||
	ResXMLFirstChunkType     ChunkType = 0x0100
 | 
			
		||||
	ResXMLStartNamespaceType ChunkType = 0x0100
 | 
			
		||||
	ResXMLEndNamespaceType   ChunkType = 0x0101
 | 
			
		||||
	ResXMLStartElementType   ChunkType = 0x0102
 | 
			
		||||
	ResXMLEndElementType     ChunkType = 0x0103
 | 
			
		||||
	ResXMLCDataType          ChunkType = 0x0104
 | 
			
		||||
	ResXMLLastChunkType      ChunkType = 0x017f
 | 
			
		||||
 | 
			
		||||
	// This contains a uint32_t array mapping strings in the string
 | 
			
		||||
	// pool back to resource identifiers.  It is optional.
 | 
			
		||||
	ResXMLResourceMapType ChunkType = 0x0180
 | 
			
		||||
 | 
			
		||||
	// Chunk types in RES_TABLE_TYPE
 | 
			
		||||
	ResTablePackageType  ChunkType = 0x0200
 | 
			
		||||
	ResTableTypeType     ChunkType = 0x0201
 | 
			
		||||
	ResTableTypeSpecType ChunkType = 0x0202
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// ResChunkHeader is a header of a resource chunk.
 | 
			
		||||
type ResChunkHeader struct {
 | 
			
		||||
	Type       ChunkType
 | 
			
		||||
	HeaderSize uint16
 | 
			
		||||
	Size       uint32
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Flags are flags for string pool header.
 | 
			
		||||
type Flags uint32
 | 
			
		||||
 | 
			
		||||
// the values of Flags.
 | 
			
		||||
const (
 | 
			
		||||
	SortedFlag Flags = 1 << 0
 | 
			
		||||
	UTF8Flag   Flags = 1 << 8
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// ResStringPoolHeader is a chunk header of string pool.
 | 
			
		||||
type ResStringPoolHeader struct {
 | 
			
		||||
	Header      ResChunkHeader
 | 
			
		||||
	StringCount uint32
 | 
			
		||||
	StyleCount  uint32
 | 
			
		||||
	Flags       Flags
 | 
			
		||||
	StringStart uint32
 | 
			
		||||
	StylesStart uint32
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ResStringPoolSpan is a span of style information associated with
 | 
			
		||||
// a string in the pool.
 | 
			
		||||
type ResStringPoolSpan struct {
 | 
			
		||||
	FirstChar, LastChar uint32
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ResStringPool is a string pool resource.
 | 
			
		||||
type ResStringPool struct {
 | 
			
		||||
	Header  ResStringPoolHeader
 | 
			
		||||
	Strings []string
 | 
			
		||||
	Styles  []ResStringPoolSpan
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NilResStringPoolRef is nil reference for string pool.
 | 
			
		||||
const NilResStringPoolRef = ResStringPoolRef(0xFFFFFFFF)
 | 
			
		||||
 | 
			
		||||
// ResStringPoolRef is a type representing a reference to a string.
 | 
			
		||||
type ResStringPoolRef uint32
 | 
			
		||||
 | 
			
		||||
// DataType is a type of the data value.
 | 
			
		||||
type DataType uint8
 | 
			
		||||
 | 
			
		||||
// The constants for DataType
 | 
			
		||||
const (
 | 
			
		||||
	TypeNull          DataType = 0x00
 | 
			
		||||
	TypeReference     DataType = 0x01
 | 
			
		||||
	TypeAttribute     DataType = 0x02
 | 
			
		||||
	TypeString        DataType = 0x03
 | 
			
		||||
	TypeFloat         DataType = 0x04
 | 
			
		||||
	TypeDemention     DataType = 0x05
 | 
			
		||||
	TypeFraction      DataType = 0x06
 | 
			
		||||
	TypeFirstInt      DataType = 0x10
 | 
			
		||||
	TypeIntDec        DataType = 0x10
 | 
			
		||||
	TypeIntHex        DataType = 0x11
 | 
			
		||||
	TypeIntBoolean    DataType = 0x12
 | 
			
		||||
	TypeFirstColorInt DataType = 0x1c
 | 
			
		||||
	TypeIntColorARGB8 DataType = 0x1c
 | 
			
		||||
	TypeIntColorRGB8  DataType = 0x1d
 | 
			
		||||
	TypeIntColorARGB4 DataType = 0x1e
 | 
			
		||||
	TypeIntColorRGB4  DataType = 0x1f
 | 
			
		||||
	TypeLastColorInt  DataType = 0x1f
 | 
			
		||||
	TypeLastInt       DataType = 0x1f
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// ResValue is a representation of a value in a resource
 | 
			
		||||
type ResValue struct {
 | 
			
		||||
	Size     uint16
 | 
			
		||||
	Res0     uint8
 | 
			
		||||
	DataType DataType
 | 
			
		||||
	Data     uint32
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetString returns a string referenced by ref.
 | 
			
		||||
func (pool *ResStringPool) GetString(ref ResStringPoolRef) string {
 | 
			
		||||
	return pool.Strings[int(ref)]
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func readStringPool(sr *io.SectionReader) (*ResStringPool, error) {
 | 
			
		||||
	sp := new(ResStringPool)
 | 
			
		||||
	if err := binary.Read(sr, binary.LittleEndian, &sp.Header); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	stringStarts := make([]uint32, sp.Header.StringCount)
 | 
			
		||||
	if err := binary.Read(sr, binary.LittleEndian, stringStarts); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	styleStarts := make([]uint32, sp.Header.StyleCount)
 | 
			
		||||
	if err := binary.Read(sr, binary.LittleEndian, styleStarts); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	sp.Strings = make([]string, sp.Header.StringCount)
 | 
			
		||||
	for i, start := range stringStarts {
 | 
			
		||||
		var str string
 | 
			
		||||
		var err error
 | 
			
		||||
		if _, err := sr.Seek(int64(sp.Header.StringStart+start), io.SeekStart); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		if (sp.Header.Flags & UTF8Flag) == 0 {
 | 
			
		||||
			str, err = readUTF16(sr)
 | 
			
		||||
		} else {
 | 
			
		||||
			str, err = readUTF8(sr)
 | 
			
		||||
		}
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		sp.Strings[i] = str
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	sp.Styles = make([]ResStringPoolSpan, sp.Header.StyleCount)
 | 
			
		||||
	for i, start := range styleStarts {
 | 
			
		||||
		if _, err := sr.Seek(int64(sp.Header.StylesStart+start), io.SeekStart); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		if err := binary.Read(sr, binary.LittleEndian, &sp.Styles[i]); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return sp, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func readUTF16(sr *io.SectionReader) (string, error) {
 | 
			
		||||
	// read length of string
 | 
			
		||||
	size, err := readUTF16length(sr)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// read string value
 | 
			
		||||
	buf := make([]uint16, size)
 | 
			
		||||
	if err := binary.Read(sr, binary.LittleEndian, buf); err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
	return string(utf16.Decode(buf)), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func readUTF16length(sr *io.SectionReader) (int, error) {
 | 
			
		||||
	var size int
 | 
			
		||||
	var first, second uint16
 | 
			
		||||
	if err := binary.Read(sr, binary.LittleEndian, &first); err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	if (first & 0x8000) != 0 {
 | 
			
		||||
		if err := binary.Read(sr, binary.LittleEndian, &second); err != nil {
 | 
			
		||||
			return 0, err
 | 
			
		||||
		}
 | 
			
		||||
		size = (int(first&0x7FFF) << 16) + int(second)
 | 
			
		||||
	} else {
 | 
			
		||||
		size = int(first)
 | 
			
		||||
	}
 | 
			
		||||
	return size, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func readUTF8(sr *io.SectionReader) (string, error) {
 | 
			
		||||
	// skip utf16 length
 | 
			
		||||
	_, err := readUTF8length(sr)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// read lenth of string
 | 
			
		||||
	size, err := readUTF8length(sr)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	buf := make([]uint8, size)
 | 
			
		||||
	if err := binary.Read(sr, binary.LittleEndian, buf); err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
	return string(buf), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func readUTF8length(sr *io.SectionReader) (int, error) {
 | 
			
		||||
	var size int
 | 
			
		||||
	var first, second uint8
 | 
			
		||||
	if err := binary.Read(sr, binary.LittleEndian, &first); err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	if (first & 0x80) != 0 {
 | 
			
		||||
		if err := binary.Read(sr, binary.LittleEndian, &second); err != nil {
 | 
			
		||||
			return 0, err
 | 
			
		||||
		}
 | 
			
		||||
		size = (int(first&0x7F) << 8) + int(second)
 | 
			
		||||
	} else {
 | 
			
		||||
		size = int(first)
 | 
			
		||||
	}
 | 
			
		||||
	return size, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newZeroFilledReader(r io.Reader, actual int64, expected int64) (io.Reader, error) {
 | 
			
		||||
	if actual >= expected {
 | 
			
		||||
		// no need to fill
 | 
			
		||||
		return r, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// read `actual' bytes from r, and
 | 
			
		||||
	buf := new(bytes.Buffer)
 | 
			
		||||
	if _, err := io.CopyN(buf, r, actual); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// fill zero until `expected' bytes
 | 
			
		||||
	for i := actual; i < expected; i++ {
 | 
			
		||||
		if err := buf.WriteByte(0x00); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return buf, nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										1093
									
								
								pkg/android_binary/table.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1093
									
								
								pkg/android_binary/table.go
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										333
									
								
								pkg/android_binary/type.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										333
									
								
								pkg/android_binary/type.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,333 @@
 | 
			
		||||
package android_binary
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/xml"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"strconv"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type injector interface {
 | 
			
		||||
	inject(table *TableFile, config *ResTableConfig)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var injectorType = reflect.TypeOf((*injector)(nil)).Elem()
 | 
			
		||||
 | 
			
		||||
func inject(val reflect.Value, table *TableFile, config *ResTableConfig) {
 | 
			
		||||
	if val.Kind() == reflect.Ptr {
 | 
			
		||||
		if val.IsNil() {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		val = val.Elem()
 | 
			
		||||
	}
 | 
			
		||||
	if val.CanInterface() && val.Type().Implements(injectorType) {
 | 
			
		||||
		val.Interface().(injector).inject(table, config)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if val.CanAddr() {
 | 
			
		||||
		pv := val.Addr()
 | 
			
		||||
		if pv.CanInterface() && pv.Type().Implements(injectorType) {
 | 
			
		||||
			pv.Interface().(injector).inject(table, config)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	switch val.Kind() {
 | 
			
		||||
	default:
 | 
			
		||||
		// ignore other types
 | 
			
		||||
		return
 | 
			
		||||
	case reflect.Slice, reflect.Array:
 | 
			
		||||
		l := val.Len()
 | 
			
		||||
		for i := 0; i < l; i++ {
 | 
			
		||||
			inject(val.Index(i), table, config)
 | 
			
		||||
		}
 | 
			
		||||
		return
 | 
			
		||||
	case reflect.Struct:
 | 
			
		||||
		l := val.NumField()
 | 
			
		||||
		for i := 0; i < l; i++ {
 | 
			
		||||
			inject(val.Field(i), table, config)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Bool is a boolean value in XML file.
 | 
			
		||||
// It may be an immediate value or a reference.
 | 
			
		||||
type Bool struct {
 | 
			
		||||
	value  string
 | 
			
		||||
	table  *TableFile
 | 
			
		||||
	config *ResTableConfig
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithTableFile ties TableFile to the Bool.
 | 
			
		||||
func (v Bool) WithTableFile(table *TableFile) Bool {
 | 
			
		||||
	return Bool{
 | 
			
		||||
		value:  v.value,
 | 
			
		||||
		table:  table,
 | 
			
		||||
		config: v.config,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithResTableConfig ties ResTableConfig to the Bool.
 | 
			
		||||
func (v Bool) WithResTableConfig(config *ResTableConfig) Bool {
 | 
			
		||||
	return Bool{
 | 
			
		||||
		value:  v.value,
 | 
			
		||||
		table:  v.table,
 | 
			
		||||
		config: config,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (v *Bool) inject(table *TableFile, config *ResTableConfig) {
 | 
			
		||||
	v.table = table
 | 
			
		||||
	v.config = config
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SetBool sets a boolean value.
 | 
			
		||||
func (v *Bool) SetBool(value bool) {
 | 
			
		||||
	v.value = strconv.FormatBool(value)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SetResID sets a boolean value with the resource id.
 | 
			
		||||
func (v *Bool) SetResID(resID ResID) {
 | 
			
		||||
	v.value = resID.String()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UnmarshalXMLAttr implements xml.UnmarshalerAttr.
 | 
			
		||||
func (v *Bool) UnmarshalXMLAttr(attr xml.Attr) error {
 | 
			
		||||
	v.value = attr.Value
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// MarshalXMLAttr implements xml.MarshalerAttr.
 | 
			
		||||
func (v Bool) MarshalXMLAttr(name xml.Name) (xml.Attr, error) {
 | 
			
		||||
	if v.value == "" {
 | 
			
		||||
		// return the zero value of bool
 | 
			
		||||
		return xml.Attr{
 | 
			
		||||
			Name:  name,
 | 
			
		||||
			Value: "false",
 | 
			
		||||
		}, nil
 | 
			
		||||
	}
 | 
			
		||||
	return xml.Attr{
 | 
			
		||||
		Name:  name,
 | 
			
		||||
		Value: v.value,
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Bool returns the boolean value.
 | 
			
		||||
// It resolves the reference if needed.
 | 
			
		||||
func (v Bool) Bool() (bool, error) {
 | 
			
		||||
	if v.value == "" {
 | 
			
		||||
		return false, nil
 | 
			
		||||
	}
 | 
			
		||||
	if !IsResID(v.value) {
 | 
			
		||||
		return strconv.ParseBool(v.value)
 | 
			
		||||
	}
 | 
			
		||||
	id, err := ParseResID(v.value)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return false, err
 | 
			
		||||
	}
 | 
			
		||||
	value, err := v.table.GetResource(id, v.config)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return false, err
 | 
			
		||||
	}
 | 
			
		||||
	ret, ok := value.(bool)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return false, fmt.Errorf("invalid type: %T", value)
 | 
			
		||||
	}
 | 
			
		||||
	return ret, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// MustBool is same as Bool, but it panics if it fails to parse the value.
 | 
			
		||||
func (v Bool) MustBool() bool {
 | 
			
		||||
	ret, err := v.Bool()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		panic(err)
 | 
			
		||||
	}
 | 
			
		||||
	return ret
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Int32 is an integer value in XML file.
 | 
			
		||||
// It may be an immediate value or a reference.
 | 
			
		||||
type Int32 struct {
 | 
			
		||||
	value  string
 | 
			
		||||
	table  *TableFile
 | 
			
		||||
	config *ResTableConfig
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithTableFile ties TableFile to the Bool.
 | 
			
		||||
func (v Int32) WithTableFile(table *TableFile) Int32 {
 | 
			
		||||
	return Int32{
 | 
			
		||||
		value:  v.value,
 | 
			
		||||
		table:  table,
 | 
			
		||||
		config: v.config,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithResTableConfig ties ResTableConfig to the Bool.
 | 
			
		||||
func (v Int32) WithResTableConfig(config *ResTableConfig) Bool {
 | 
			
		||||
	return Bool{
 | 
			
		||||
		value:  v.value,
 | 
			
		||||
		table:  v.table,
 | 
			
		||||
		config: config,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (v *Int32) inject(table *TableFile, config *ResTableConfig) {
 | 
			
		||||
	v.table = table
 | 
			
		||||
	v.config = config
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SetInt32 sets an integer value.
 | 
			
		||||
func (v *Int32) SetInt32(value int32) {
 | 
			
		||||
	v.value = strconv.FormatInt(int64(value), 10)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SetResID sets a boolean value with the resource id.
 | 
			
		||||
func (v *Int32) SetResID(resID ResID) {
 | 
			
		||||
	v.value = resID.String()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UnmarshalXMLAttr implements xml.UnmarshalerAttr.
 | 
			
		||||
func (v *Int32) UnmarshalXMLAttr(attr xml.Attr) error {
 | 
			
		||||
	v.value = attr.Value
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// MarshalXMLAttr implements xml.MarshalerAttr.
 | 
			
		||||
func (v Int32) MarshalXMLAttr(name xml.Name) (xml.Attr, error) {
 | 
			
		||||
	if v.value == "" {
 | 
			
		||||
		// return the zero value of int32
 | 
			
		||||
		return xml.Attr{
 | 
			
		||||
			Name:  name,
 | 
			
		||||
			Value: "0",
 | 
			
		||||
		}, nil
 | 
			
		||||
	}
 | 
			
		||||
	return xml.Attr{
 | 
			
		||||
		Name:  name,
 | 
			
		||||
		Value: v.value,
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Int32 returns the integer value.
 | 
			
		||||
// It resolves the reference if needed.
 | 
			
		||||
func (v Int32) Int32() (int32, error) {
 | 
			
		||||
	if v.value == "" {
 | 
			
		||||
		return 0, nil
 | 
			
		||||
	}
 | 
			
		||||
	if !IsResID(v.value) {
 | 
			
		||||
		v, err := strconv.ParseInt(v.value, 10, 32)
 | 
			
		||||
		return int32(v), err
 | 
			
		||||
	}
 | 
			
		||||
	id, err := ParseResID(v.value)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	value, err := v.table.GetResource(id, v.config)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	ret, ok := value.(uint32)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return 0, fmt.Errorf("invalid type: %T", value)
 | 
			
		||||
	}
 | 
			
		||||
	return int32(ret), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// MustInt32 is same as Int32, but it panics if it fails to parse the value.
 | 
			
		||||
func (v Int32) MustInt32() int32 {
 | 
			
		||||
	ret, err := v.Int32()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		panic(err)
 | 
			
		||||
	}
 | 
			
		||||
	return ret
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// String is a boolean value in XML file.
 | 
			
		||||
// It may be an immediate value or a reference.
 | 
			
		||||
type String struct {
 | 
			
		||||
	value  string
 | 
			
		||||
	table  *TableFile
 | 
			
		||||
	config *ResTableConfig
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithTableFile ties TableFile to the Bool.
 | 
			
		||||
func (v String) WithTableFile(table *TableFile) String {
 | 
			
		||||
	return String{
 | 
			
		||||
		value:  v.value,
 | 
			
		||||
		table:  table,
 | 
			
		||||
		config: v.config,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithResTableConfig ties ResTableConfig to the Bool.
 | 
			
		||||
func (v String) WithResTableConfig(config *ResTableConfig) String {
 | 
			
		||||
	return String{
 | 
			
		||||
		value:  v.value,
 | 
			
		||||
		table:  v.table,
 | 
			
		||||
		config: config,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (v *String) inject(table *TableFile, config *ResTableConfig) {
 | 
			
		||||
	v.table = table
 | 
			
		||||
	v.config = config
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SetString sets a string value.
 | 
			
		||||
func (v *String) SetString(value string) {
 | 
			
		||||
	v.value = value
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SetResID sets a boolean value with the resource id.
 | 
			
		||||
func (v *String) SetResID(resID ResID) {
 | 
			
		||||
	v.value = resID.String()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UnmarshalXMLAttr implements xml.UnmarshalerAttr.
 | 
			
		||||
func (v *String) UnmarshalXMLAttr(attr xml.Attr) error {
 | 
			
		||||
	v.value = attr.Value
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// MarshalXMLAttr implements xml.MarshalerAttr.
 | 
			
		||||
func (v String) MarshalXMLAttr(name xml.Name) (xml.Attr, error) {
 | 
			
		||||
	return xml.Attr{
 | 
			
		||||
		Name:  name,
 | 
			
		||||
		Value: v.value,
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// String returns the string value.
 | 
			
		||||
// It resolves the reference if needed.
 | 
			
		||||
func (v String) String() (string, error) {
 | 
			
		||||
	if !IsResID(v.value) {
 | 
			
		||||
		return v.value, nil
 | 
			
		||||
	}
 | 
			
		||||
	id, err := ParseResID(v.value)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	value, err := v.table.GetResource(id, v.config)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	//todo 读取套娃
 | 
			
		||||
	switch value.(type) {
 | 
			
		||||
	case string:
 | 
			
		||||
		return value.(string), nil
 | 
			
		||||
	case uint32:
 | 
			
		||||
		return fmt.Sprintf("%d", value.(uint32)), nil
 | 
			
		||||
	default:
 | 
			
		||||
		return "", nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// MustString is same as String, but it panics if it fails to parse the value.
 | 
			
		||||
func (v String) MustString() string {
 | 
			
		||||
	ret, err := v.String()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		panic(err)
 | 
			
		||||
	}
 | 
			
		||||
	return ret
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										271
									
								
								pkg/android_binary/xml.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										271
									
								
								pkg/android_binary/xml.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,271 @@
 | 
			
		||||
package android_binary
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/binary"
 | 
			
		||||
	"encoding/xml"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"reflect"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// XMLFile is an XML file expressed in binary format.
 | 
			
		||||
type XMLFile struct {
 | 
			
		||||
	stringPool     *ResStringPool
 | 
			
		||||
	resourceMap    []uint32
 | 
			
		||||
	notPrecessedNS map[ResStringPoolRef]ResStringPoolRef
 | 
			
		||||
	namespaces     map[ResStringPoolRef]ResStringPoolRef
 | 
			
		||||
	xmlBuffer      bytes.Buffer
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ResXMLTreeNode is basic XML tree node.
 | 
			
		||||
type ResXMLTreeNode struct {
 | 
			
		||||
	Header     ResChunkHeader
 | 
			
		||||
	LineNumber uint32
 | 
			
		||||
	Comment    ResStringPoolRef
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ResXMLTreeNamespaceExt is extended XML tree node for namespace start/end nodes.
 | 
			
		||||
type ResXMLTreeNamespaceExt struct {
 | 
			
		||||
	Prefix ResStringPoolRef
 | 
			
		||||
	URI    ResStringPoolRef
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ResXMLTreeAttrExt is extended XML tree node for start tags -- includes attribute.
 | 
			
		||||
type ResXMLTreeAttrExt struct {
 | 
			
		||||
	NS             ResStringPoolRef
 | 
			
		||||
	Name           ResStringPoolRef
 | 
			
		||||
	AttributeStart uint16
 | 
			
		||||
	AttributeSize  uint16
 | 
			
		||||
	AttributeCount uint16
 | 
			
		||||
	IDIndex        uint16
 | 
			
		||||
	ClassIndex     uint16
 | 
			
		||||
	StyleIndex     uint16
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ResXMLTreeAttribute is an attribute of start tags.
 | 
			
		||||
type ResXMLTreeAttribute struct {
 | 
			
		||||
	NS         ResStringPoolRef
 | 
			
		||||
	Name       ResStringPoolRef
 | 
			
		||||
	RawValue   ResStringPoolRef
 | 
			
		||||
	TypedValue ResValue
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ResXMLTreeEndElementExt is extended XML tree node for element start/end nodes.
 | 
			
		||||
type ResXMLTreeEndElementExt struct {
 | 
			
		||||
	NS   ResStringPoolRef
 | 
			
		||||
	Name ResStringPoolRef
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewXMLFile returns a new XMLFile.
 | 
			
		||||
func NewXMLFile(r io.ReaderAt) (*XMLFile, error) {
 | 
			
		||||
	f := new(XMLFile)
 | 
			
		||||
	sr := io.NewSectionReader(r, 0, 1<<63-1)
 | 
			
		||||
 | 
			
		||||
	fmt.Fprintf(&f.xmlBuffer, xml.Header)
 | 
			
		||||
 | 
			
		||||
	header := new(ResChunkHeader)
 | 
			
		||||
	if err := binary.Read(sr, binary.LittleEndian, header); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	offset := int64(header.HeaderSize)
 | 
			
		||||
	for offset < int64(header.Size) {
 | 
			
		||||
		chunkHeader, err := f.readChunk(r, offset)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		offset += int64(chunkHeader.Size)
 | 
			
		||||
	}
 | 
			
		||||
	return f, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Reader returns a reader of XML file expressed in text format.
 | 
			
		||||
func (f *XMLFile) Reader() *bytes.Reader {
 | 
			
		||||
	return bytes.NewReader(f.xmlBuffer.Bytes())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Decode decodes XML file and stores the result in the value pointed to by v.
 | 
			
		||||
// To resolve the resource references, Decode also stores default TableFile and ResTableConfig in the value pointed to by v.
 | 
			
		||||
func (f *XMLFile) Decode(v any, table *TableFile, config *ResTableConfig) error {
 | 
			
		||||
	decoder := xml.NewDecoder(f.Reader())
 | 
			
		||||
	if err := decoder.Decode(v); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	inject(reflect.ValueOf(v), table, config)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f *XMLFile) readChunk(r io.ReaderAt, offset int64) (*ResChunkHeader, error) {
 | 
			
		||||
	sr := io.NewSectionReader(r, offset, 1<<63-1-offset)
 | 
			
		||||
	chunkHeader := &ResChunkHeader{}
 | 
			
		||||
	if _, err := sr.Seek(0, io.SeekStart); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	if err := binary.Read(sr, binary.LittleEndian, chunkHeader); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var err error
 | 
			
		||||
	if _, err := sr.Seek(0, io.SeekStart); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	switch chunkHeader.Type {
 | 
			
		||||
	case ResStringPoolChunkType:
 | 
			
		||||
		f.stringPool, err = readStringPool(sr)
 | 
			
		||||
	case ResXMLStartNamespaceType:
 | 
			
		||||
		err = f.readStartNamespace(sr)
 | 
			
		||||
	case ResXMLEndNamespaceType:
 | 
			
		||||
		err = f.readEndNamespace(sr)
 | 
			
		||||
	case ResXMLStartElementType:
 | 
			
		||||
		err = f.readStartElement(sr)
 | 
			
		||||
	case ResXMLEndElementType:
 | 
			
		||||
		err = f.readEndElement(sr)
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return chunkHeader, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetString returns a string referenced by ref.
 | 
			
		||||
func (f *XMLFile) GetString(ref ResStringPoolRef) string {
 | 
			
		||||
	return f.stringPool.GetString(ref)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f *XMLFile) readStartNamespace(sr *io.SectionReader) error {
 | 
			
		||||
	header := new(ResXMLTreeNode)
 | 
			
		||||
	if err := binary.Read(sr, binary.LittleEndian, header); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if _, err := sr.Seek(int64(header.Header.HeaderSize), io.SeekStart); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	namespace := new(ResXMLTreeNamespaceExt)
 | 
			
		||||
	if err := binary.Read(sr, binary.LittleEndian, namespace); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if f.notPrecessedNS == nil {
 | 
			
		||||
		f.notPrecessedNS = make(map[ResStringPoolRef]ResStringPoolRef)
 | 
			
		||||
	}
 | 
			
		||||
	f.notPrecessedNS[namespace.URI] = namespace.Prefix
 | 
			
		||||
 | 
			
		||||
	if f.namespaces == nil {
 | 
			
		||||
		f.namespaces = make(map[ResStringPoolRef]ResStringPoolRef)
 | 
			
		||||
	}
 | 
			
		||||
	f.namespaces[namespace.URI] = namespace.Prefix
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f *XMLFile) readEndNamespace(sr *io.SectionReader) error {
 | 
			
		||||
	header := new(ResXMLTreeNode)
 | 
			
		||||
	if err := binary.Read(sr, binary.LittleEndian, header); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if _, err := sr.Seek(int64(header.Header.HeaderSize), io.SeekStart); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	namespace := new(ResXMLTreeNamespaceExt)
 | 
			
		||||
	if err := binary.Read(sr, binary.LittleEndian, namespace); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	delete(f.namespaces, namespace.URI)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f *XMLFile) addNamespacePrefix(ns, name ResStringPoolRef) string {
 | 
			
		||||
	if ns != NilResStringPoolRef {
 | 
			
		||||
		prefix := f.GetString(f.namespaces[ns])
 | 
			
		||||
		return fmt.Sprintf("%s:%s", prefix, f.GetString(name))
 | 
			
		||||
	}
 | 
			
		||||
	return f.GetString(name)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f *XMLFile) readStartElement(sr *io.SectionReader) error {
 | 
			
		||||
	header := new(ResXMLTreeNode)
 | 
			
		||||
	if err := binary.Read(sr, binary.LittleEndian, header); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if _, err := sr.Seek(int64(header.Header.HeaderSize), io.SeekStart); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	ext := new(ResXMLTreeAttrExt)
 | 
			
		||||
	if err := binary.Read(sr, binary.LittleEndian, ext); err != nil {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	fmt.Fprintf(&f.xmlBuffer, "<%s", f.addNamespacePrefix(ext.NS, ext.Name))
 | 
			
		||||
 | 
			
		||||
	// output XML namespaces
 | 
			
		||||
	if f.notPrecessedNS != nil {
 | 
			
		||||
		for uri, prefix := range f.notPrecessedNS {
 | 
			
		||||
			fmt.Fprintf(&f.xmlBuffer, " xmlns:%s=\"", f.GetString(prefix))
 | 
			
		||||
			xml.Escape(&f.xmlBuffer, []byte(f.GetString(uri)))
 | 
			
		||||
			fmt.Fprint(&f.xmlBuffer, "\"")
 | 
			
		||||
		}
 | 
			
		||||
		f.notPrecessedNS = nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// process attributes
 | 
			
		||||
	offset := int64(ext.AttributeStart + header.Header.HeaderSize)
 | 
			
		||||
	for i := 0; i < int(ext.AttributeCount); i++ {
 | 
			
		||||
		if _, err := sr.Seek(offset, io.SeekStart); err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		attr := new(ResXMLTreeAttribute)
 | 
			
		||||
		binary.Read(sr, binary.LittleEndian, attr)
 | 
			
		||||
 | 
			
		||||
		var value string
 | 
			
		||||
		if attr.RawValue != NilResStringPoolRef {
 | 
			
		||||
			value = f.GetString(attr.RawValue)
 | 
			
		||||
		} else {
 | 
			
		||||
			data := attr.TypedValue.Data
 | 
			
		||||
			switch attr.TypedValue.DataType {
 | 
			
		||||
			case TypeNull:
 | 
			
		||||
				value = ""
 | 
			
		||||
			case TypeReference:
 | 
			
		||||
				value = fmt.Sprintf("@0x%08X", data)
 | 
			
		||||
			case TypeIntDec:
 | 
			
		||||
				value = fmt.Sprintf("%d", data)
 | 
			
		||||
			case TypeIntHex:
 | 
			
		||||
				value = fmt.Sprintf("0x%08X", data)
 | 
			
		||||
			case TypeIntBoolean:
 | 
			
		||||
				if data != 0 {
 | 
			
		||||
					value = "true"
 | 
			
		||||
				} else {
 | 
			
		||||
					value = "false"
 | 
			
		||||
				}
 | 
			
		||||
			default:
 | 
			
		||||
				value = fmt.Sprintf("@0x%08X", data)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		fmt.Fprintf(&f.xmlBuffer, " %s=\"", f.addNamespacePrefix(attr.NS, attr.Name))
 | 
			
		||||
		xml.Escape(&f.xmlBuffer, []byte(value))
 | 
			
		||||
		fmt.Fprint(&f.xmlBuffer, "\"")
 | 
			
		||||
		offset += int64(ext.AttributeSize)
 | 
			
		||||
	}
 | 
			
		||||
	fmt.Fprint(&f.xmlBuffer, ">")
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f *XMLFile) readEndElement(sr *io.SectionReader) error {
 | 
			
		||||
	header := new(ResXMLTreeNode)
 | 
			
		||||
	if err := binary.Read(sr, binary.LittleEndian, header); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if _, err := sr.Seek(int64(header.Header.HeaderSize), io.SeekStart); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	ext := new(ResXMLTreeEndElementExt)
 | 
			
		||||
	if err := binary.Read(sr, binary.LittleEndian, ext); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	fmt.Fprintf(&f.xmlBuffer, "</%s>", f.addNamespacePrefix(ext.NS, ext.Name))
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										108
									
								
								pkg/ants/pool.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										108
									
								
								pkg/ants/pool.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,108 @@
 | 
			
		||||
package ants
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"git.bvbej.com/bvbej/base-golang/pkg/ticker"
 | 
			
		||||
	"github.com/panjf2000/ants/v2"
 | 
			
		||||
	"go.uber.org/zap"
 | 
			
		||||
	"runtime/debug"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var _ GoroutinePool = (*goroutinePool)(nil)
 | 
			
		||||
 | 
			
		||||
type GoroutinePool interface {
 | 
			
		||||
	run()
 | 
			
		||||
 | 
			
		||||
	Submit(task func())
 | 
			
		||||
	Stop()
 | 
			
		||||
 | 
			
		||||
	Size() int
 | 
			
		||||
	Running() int
 | 
			
		||||
	Free() int
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type goroutinePool struct {
 | 
			
		||||
	pool   *ants.Pool
 | 
			
		||||
	logger *zap.Logger
 | 
			
		||||
	ticker ticker.Ticker
 | 
			
		||||
	step   int
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type poolLogger struct {
 | 
			
		||||
	zap *zap.Logger
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (l *poolLogger) Printf(format string, args ...any) {
 | 
			
		||||
	l.zap.Sugar().Infof(format, args)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewPool(zapLogger *zap.Logger, step int) (GoroutinePool, error) {
 | 
			
		||||
	ttl := time.Minute * 5
 | 
			
		||||
 | 
			
		||||
	options := ants.Options{
 | 
			
		||||
		Nonblocking:    true,
 | 
			
		||||
		ExpiryDuration: ttl,
 | 
			
		||||
		PanicHandler: func(err any) {
 | 
			
		||||
			zapLogger.Sugar().Error(
 | 
			
		||||
				"GoroutinePool panic",
 | 
			
		||||
				zap.String("error", fmt.Sprintf("%+v", err)),
 | 
			
		||||
				zap.String("stack", string(debug.Stack())),
 | 
			
		||||
			)
 | 
			
		||||
		},
 | 
			
		||||
		Logger: &poolLogger{zap: zapLogger},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	antsPool, err := ants.NewPool(step, ants.WithOptions(options))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	pool := &goroutinePool{
 | 
			
		||||
		pool:   antsPool,
 | 
			
		||||
		logger: zapLogger,
 | 
			
		||||
		ticker: ticker.New(ttl),
 | 
			
		||||
		step:   step,
 | 
			
		||||
	}
 | 
			
		||||
	pool.run()
 | 
			
		||||
 | 
			
		||||
	return pool, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *goroutinePool) run() {
 | 
			
		||||
	p.ticker.Process(func() {
 | 
			
		||||
		if p.Free() > p.step {
 | 
			
		||||
			mul := p.Free() / p.step
 | 
			
		||||
			p.pool.Tune(p.Size() - p.step*mul)
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *goroutinePool) Submit(task func()) {
 | 
			
		||||
	if p.pool.IsClosed() {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	err := p.pool.Submit(task)
 | 
			
		||||
	if errors.Is(err, ants.ErrPoolOverload) {
 | 
			
		||||
		p.pool.Tune(p.Size() + p.step)
 | 
			
		||||
		p.Submit(task)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *goroutinePool) Size() int {
 | 
			
		||||
	return p.pool.Cap()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *goroutinePool) Running() int {
 | 
			
		||||
	return p.pool.Running()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *goroutinePool) Free() int {
 | 
			
		||||
	return p.pool.Free()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *goroutinePool) Stop() {
 | 
			
		||||
	p.ticker.Stop()
 | 
			
		||||
	p.pool.Release()
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										97
									
								
								pkg/apollo/config.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										97
									
								
								pkg/apollo/config.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,97 @@
 | 
			
		||||
package apollo
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
 | 
			
		||||
	"git.bvbej.com/bvbej/base-golang/pkg/env"
 | 
			
		||||
	"github.com/apolloconfig/agollo/v4"
 | 
			
		||||
	"github.com/apolloconfig/agollo/v4/component/log"
 | 
			
		||||
	apolloConfig "github.com/apolloconfig/agollo/v4/env/config"
 | 
			
		||||
	"github.com/apolloconfig/agollo/v4/storage"
 | 
			
		||||
	"github.com/spf13/viper"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type clientConfig struct {
 | 
			
		||||
	client agollo.Client
 | 
			
		||||
	ac     *apolloConfig.AppConfig
 | 
			
		||||
	conf   any
 | 
			
		||||
 | 
			
		||||
	onChange       func(event *storage.ChangeEvent)
 | 
			
		||||
	onNewestChange func(*storage.FullChangeEvent)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Option func(*clientConfig)
 | 
			
		||||
 | 
			
		||||
func WithOnChangeEvent(event func(event *storage.ChangeEvent)) Option {
 | 
			
		||||
	return func(conf *clientConfig) {
 | 
			
		||||
		conf.onChange = event
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func WithOnNewestChangeEvent(event func(event *storage.FullChangeEvent)) Option {
 | 
			
		||||
	return func(conf *clientConfig) {
 | 
			
		||||
		conf.onNewestChange = event
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetApolloConfig(appId, secret string, config any, opt ...Option) error {
 | 
			
		||||
	var err error
 | 
			
		||||
	namespace := env.Active().Value() + ".yaml"
 | 
			
		||||
 | 
			
		||||
	c := new(clientConfig)
 | 
			
		||||
	c.conf = config
 | 
			
		||||
	c.ac = &apolloConfig.AppConfig{
 | 
			
		||||
		AppID:          appId,
 | 
			
		||||
		Cluster:        "dev",
 | 
			
		||||
		IP:             "https://config.bvbej.com",
 | 
			
		||||
		NamespaceName:  namespace,
 | 
			
		||||
		IsBackupConfig: false,
 | 
			
		||||
		Secret:         secret,
 | 
			
		||||
		MustStart:      true,
 | 
			
		||||
	}
 | 
			
		||||
	for _, option := range opt {
 | 
			
		||||
		option(c)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	agollo.SetLogger(&log.DefaultLogger{})
 | 
			
		||||
 | 
			
		||||
	c.client, err = agollo.StartWithConfig(func() (*apolloConfig.AppConfig, error) {
 | 
			
		||||
		return c.ac, nil
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("get config error:[%s]", err)
 | 
			
		||||
	}
 | 
			
		||||
	c.client.AddChangeListener(c)
 | 
			
		||||
 | 
			
		||||
	err = c.serialization()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("unmarshal config error:[%s]", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *clientConfig) serialization() error {
 | 
			
		||||
	parser := viper.New()
 | 
			
		||||
 | 
			
		||||
	parser.SetConfigType("yaml")
 | 
			
		||||
	c.client.GetConfigCache(c.ac.NamespaceName).Range(func(key, value any) bool {
 | 
			
		||||
		parser.Set(key.(string), value)
 | 
			
		||||
		return true
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	return parser.Unmarshal(c.conf)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *clientConfig) OnChange(event *storage.ChangeEvent) {
 | 
			
		||||
	_ = c.serialization()
 | 
			
		||||
	if c.onChange != nil {
 | 
			
		||||
		c.onChange(event)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *clientConfig) OnNewestChange(event *storage.FullChangeEvent) {
 | 
			
		||||
	if c.onNewestChange != nil {
 | 
			
		||||
		c.onNewestChange(event)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										29
									
								
								pkg/auth/config.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								pkg/auth/config.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,29 @@
 | 
			
		||||
package auth
 | 
			
		||||
 | 
			
		||||
import "time"
 | 
			
		||||
 | 
			
		||||
// Config authorization configuration parameters
 | 
			
		||||
type Config struct {
 | 
			
		||||
	// access token expiration time, 0 means it doesn't expire
 | 
			
		||||
	AccessTokenExp time.Duration
 | 
			
		||||
	// refresh token expiration time, 0 means it doesn't expire
 | 
			
		||||
	RefreshTokenExp time.Duration
 | 
			
		||||
	// whether to generate the refreshing token
 | 
			
		||||
	IsGenerateRefresh bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RefreshConfig refreshing token config
 | 
			
		||||
type RefreshConfig struct {
 | 
			
		||||
	// whether to reset the refreshing creation time
 | 
			
		||||
	IsResetRefreshTime bool
 | 
			
		||||
	// whether to remove access token
 | 
			
		||||
	IsRemoveAccess bool
 | 
			
		||||
	// whether to remove refreshing token
 | 
			
		||||
	IsRemoveRefreshing bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// default configs
 | 
			
		||||
var (
 | 
			
		||||
	DefaultAccessTokenCfg  = &Config{AccessTokenExp: time.Hour * 24, RefreshTokenExp: time.Hour * 24 * 7, IsGenerateRefresh: true}
 | 
			
		||||
	DefaultRefreshTokenCfg = &RefreshConfig{IsResetRefreshTime: true, IsRemoveAccess: true, IsRemoveRefreshing: true}
 | 
			
		||||
)
 | 
			
		||||
							
								
								
									
										12
									
								
								pkg/auth/error.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										12
									
								
								pkg/auth/error.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,12 @@
 | 
			
		||||
package auth
 | 
			
		||||
 | 
			
		||||
import "errors"
 | 
			
		||||
 | 
			
		||||
var New = errors.New
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	ErrInvalidAccessToken  = errors.New("invalid access token")
 | 
			
		||||
	ErrInvalidRefreshToken = errors.New("invalid refresh token")
 | 
			
		||||
	ErrExpiredAccessToken  = errors.New("expired access token")
 | 
			
		||||
	ErrExpiredRefreshToken = errors.New("expired refresh token")
 | 
			
		||||
)
 | 
			
		||||
							
								
								
									
										17
									
								
								pkg/auth/generate.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								pkg/auth/generate.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,17 @@
 | 
			
		||||
package auth
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type (
 | 
			
		||||
	GenerateBasic struct {
 | 
			
		||||
		UserID    string
 | 
			
		||||
		CreateAt  time.Time
 | 
			
		||||
		TokenInfo TokenInfo
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	AccessGenerate interface {
 | 
			
		||||
		Token(data *GenerateBasic, isGenRefresh bool) (access, refresh string, err error)
 | 
			
		||||
	}
 | 
			
		||||
)
 | 
			
		||||
							
								
								
									
										97
									
								
								pkg/auth/jwt_access.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										97
									
								
								pkg/auth/jwt_access.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,97 @@
 | 
			
		||||
package auth
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/golang-jwt/jwt/v4"
 | 
			
		||||
	"github.com/google/uuid"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// JWTAccessClaims jwt claims
 | 
			
		||||
type JWTAccessClaims struct {
 | 
			
		||||
	jwt.RegisteredClaims
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Valid claims verification
 | 
			
		||||
func (a *JWTAccessClaims) Valid() error {
 | 
			
		||||
	if a.ExpiresAt.Before(time.Now()) {
 | 
			
		||||
		return ErrInvalidAccessToken
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewJWTAccessGenerate create to generate the jwt access token instance
 | 
			
		||||
func NewJWTAccessGenerate(key []byte, method jwt.SigningMethod) *JWTAccessGenerate {
 | 
			
		||||
	return &JWTAccessGenerate{
 | 
			
		||||
		SignedKey:    key,
 | 
			
		||||
		SignedMethod: method,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// JWTAccessGenerate generate the jwt access token
 | 
			
		||||
type JWTAccessGenerate struct {
 | 
			
		||||
	SignedKey    []byte
 | 
			
		||||
	SignedMethod jwt.SigningMethod
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Token based on the UUID generated token
 | 
			
		||||
func (a *JWTAccessGenerate) Token(data *GenerateBasic, isGenRefresh bool) (string, string, error) {
 | 
			
		||||
	claims := &JWTAccessClaims{
 | 
			
		||||
		RegisteredClaims: jwt.RegisteredClaims{
 | 
			
		||||
			Issuer:    "BvBeJ",
 | 
			
		||||
			Subject:   data.UserID,
 | 
			
		||||
			ExpiresAt: jwt.NewNumericDate(data.TokenInfo.GetAccessCreateAt().Add(data.TokenInfo.GetAccessExpiresIn())),
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	token := jwt.NewWithClaims(a.SignedMethod, claims)
 | 
			
		||||
	var key any
 | 
			
		||||
	if a.isEs() {
 | 
			
		||||
		v, err := jwt.ParseECPrivateKeyFromPEM(a.SignedKey)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return "", "", err
 | 
			
		||||
		}
 | 
			
		||||
		key = v
 | 
			
		||||
	} else if a.isRsOrPS() {
 | 
			
		||||
		v, err := jwt.ParseRSAPrivateKeyFromPEM(a.SignedKey)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return "", "", err
 | 
			
		||||
		}
 | 
			
		||||
		key = v
 | 
			
		||||
	} else if a.isHs() {
 | 
			
		||||
		key = a.SignedKey
 | 
			
		||||
	} else {
 | 
			
		||||
		return "", "", errors.New("unsupported sign method")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	access, err := token.SignedString(key)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", "", err
 | 
			
		||||
	}
 | 
			
		||||
	refresh := ""
 | 
			
		||||
 | 
			
		||||
	if isGenRefresh {
 | 
			
		||||
		t := uuid.NewSHA1(uuid.Must(uuid.NewRandom()), []byte(access)).String()
 | 
			
		||||
		refresh = base64.URLEncoding.EncodeToString([]byte(t))
 | 
			
		||||
		refresh = strings.ToUpper(strings.TrimRight(refresh, "="))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return access, refresh, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *JWTAccessGenerate) isEs() bool {
 | 
			
		||||
	return strings.HasPrefix(a.SignedMethod.Alg(), "ES")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *JWTAccessGenerate) isRsOrPS() bool {
 | 
			
		||||
	isRs := strings.HasPrefix(a.SignedMethod.Alg(), "RS")
 | 
			
		||||
	isPs := strings.HasPrefix(a.SignedMethod.Alg(), "PS")
 | 
			
		||||
	return isRs || isPs
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *JWTAccessGenerate) isHs() bool {
 | 
			
		||||
	return strings.HasPrefix(a.SignedMethod.Alg(), "HS")
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										194
									
								
								pkg/auth/manager.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										194
									
								
								pkg/auth/manager.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,194 @@
 | 
			
		||||
package auth
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// NewManager create to authorization management instance
 | 
			
		||||
func NewManager(ag AccessGenerate, ts TokenStore) *Manager {
 | 
			
		||||
	return &Manager{
 | 
			
		||||
		cfg:            DefaultAccessTokenCfg,
 | 
			
		||||
		rCfg:           DefaultRefreshTokenCfg,
 | 
			
		||||
		accessGenerate: ag,
 | 
			
		||||
		tokenStore:     ts,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SetConfig mapping the access token generate config
 | 
			
		||||
func (m *Manager) SetConfig(cfg *Config) {
 | 
			
		||||
	m.cfg = cfg
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SetRefreshTokenConfig  mapping the token refresh config
 | 
			
		||||
func (m *Manager) SetRefreshTokenConfig(store *RefreshConfig) {
 | 
			
		||||
	m.rCfg = store
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Manager provide authorization management
 | 
			
		||||
type Manager struct {
 | 
			
		||||
	cfg            *Config
 | 
			
		||||
	rCfg           *RefreshConfig
 | 
			
		||||
	accessGenerate AccessGenerate
 | 
			
		||||
	tokenStore     TokenStore
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GenerateAccessToken generate the access token
 | 
			
		||||
func (m *Manager) GenerateAccessToken(userID string) (TokenInfo, error) {
 | 
			
		||||
	ti := NewToken()
 | 
			
		||||
	ti.SetUserID(userID)
 | 
			
		||||
 | 
			
		||||
	createAt := time.Now()
 | 
			
		||||
	ti.SetAccessCreateAt(createAt)
 | 
			
		||||
 | 
			
		||||
	// set access token expires
 | 
			
		||||
	ti.SetAccessExpiresIn(m.cfg.AccessTokenExp)
 | 
			
		||||
	if m.cfg.IsGenerateRefresh {
 | 
			
		||||
		ti.SetRefreshCreateAt(createAt)
 | 
			
		||||
		ti.SetRefreshExpiresIn(m.cfg.RefreshTokenExp)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	td := &GenerateBasic{
 | 
			
		||||
		UserID:    userID,
 | 
			
		||||
		CreateAt:  createAt,
 | 
			
		||||
		TokenInfo: ti,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	av, rv, err := m.accessGenerate.Token(td, m.cfg.IsGenerateRefresh)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	ti.SetAccess(av)
 | 
			
		||||
 | 
			
		||||
	if rv != "" {
 | 
			
		||||
		ti.SetRefresh(rv)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = m.tokenStore.Create(ti)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return ti, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RefreshAccessToken refreshing an access token
 | 
			
		||||
func (m *Manager) RefreshAccessToken(refresh string) (TokenInfo, error) {
 | 
			
		||||
	ti, err := m.LoadRefreshToken(refresh)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	oldAccess, oldRefresh := ti.GetAccess(), ti.GetRefresh()
 | 
			
		||||
 | 
			
		||||
	td := &GenerateBasic{
 | 
			
		||||
		UserID:    ti.GetUserID(),
 | 
			
		||||
		CreateAt:  time.Now(),
 | 
			
		||||
		TokenInfo: ti,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ti.SetAccessCreateAt(td.CreateAt)
 | 
			
		||||
	if v := m.cfg.AccessTokenExp; v > 0 {
 | 
			
		||||
		ti.SetAccessExpiresIn(v)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if v := m.cfg.RefreshTokenExp; v > 0 {
 | 
			
		||||
		ti.SetRefreshExpiresIn(v)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if m.rCfg.IsResetRefreshTime {
 | 
			
		||||
		ti.SetRefreshCreateAt(td.CreateAt)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tv, rv, err := m.accessGenerate.Token(td, m.cfg.IsGenerateRefresh)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ti.SetAccess(tv)
 | 
			
		||||
	if rv != "" {
 | 
			
		||||
		ti.SetRefresh(rv)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err = m.tokenStore.Create(ti); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if m.rCfg.IsRemoveAccess {
 | 
			
		||||
		// remove the old access token
 | 
			
		||||
		if err = m.tokenStore.RemoveByAccess(oldAccess); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if m.rCfg.IsRemoveRefreshing && rv != "" {
 | 
			
		||||
		// remove the old refresh token
 | 
			
		||||
		if err = m.tokenStore.RemoveByRefresh(oldRefresh); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if rv == "" {
 | 
			
		||||
		ti.SetRefresh("")
 | 
			
		||||
		ti.SetRefreshCreateAt(time.Now())
 | 
			
		||||
		ti.SetRefreshExpiresIn(0)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return ti, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RemoveAccessToken use the access token to delete the token information
 | 
			
		||||
func (m *Manager) RemoveAccessToken(access string) error {
 | 
			
		||||
	if access == "" {
 | 
			
		||||
		return ErrInvalidAccessToken
 | 
			
		||||
	}
 | 
			
		||||
	return m.tokenStore.RemoveByAccess(access)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RemoveRefreshToken use the refresh token to delete the token information
 | 
			
		||||
func (m *Manager) RemoveRefreshToken(refresh string) error {
 | 
			
		||||
	if refresh == "" {
 | 
			
		||||
		return ErrInvalidAccessToken
 | 
			
		||||
	}
 | 
			
		||||
	return m.tokenStore.RemoveByRefresh(refresh)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// LoadAccessToken according to the access token for corresponding token information
 | 
			
		||||
func (m *Manager) LoadAccessToken(access string) (TokenInfo, error) {
 | 
			
		||||
	if access == "" {
 | 
			
		||||
		return nil, ErrInvalidAccessToken
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ct := time.Now()
 | 
			
		||||
	ti, err := m.tokenStore.GetByAccess(access)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	} else if ti == nil || ti.GetAccess() != access {
 | 
			
		||||
		return nil, ErrInvalidAccessToken
 | 
			
		||||
	} else if ti.GetRefresh() != "" && ti.GetRefreshExpiresIn() != 0 &&
 | 
			
		||||
		ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(ct) {
 | 
			
		||||
		return nil, ErrExpiredRefreshToken
 | 
			
		||||
	} else if ti.GetAccessExpiresIn() != 0 &&
 | 
			
		||||
		ti.GetAccessCreateAt().Add(ti.GetAccessExpiresIn()).Before(ct) {
 | 
			
		||||
		return nil, ErrExpiredAccessToken
 | 
			
		||||
	}
 | 
			
		||||
	return ti, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// LoadRefreshToken according to the refresh token for corresponding token information
 | 
			
		||||
func (m *Manager) LoadRefreshToken(refresh string) (TokenInfo, error) {
 | 
			
		||||
	if refresh == "" {
 | 
			
		||||
		return nil, ErrInvalidRefreshToken
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ti, err := m.tokenStore.GetByRefresh(refresh)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	} else if ti == nil || ti.GetRefresh() != refresh {
 | 
			
		||||
		return nil, ErrInvalidRefreshToken
 | 
			
		||||
	} else if ti.GetRefreshExpiresIn() != 0 && // refresh token set to not expire
 | 
			
		||||
		ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(time.Now()) {
 | 
			
		||||
		return nil, ErrExpiredRefreshToken
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return ti, nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										334
									
								
								pkg/auth/store.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										334
									
								
								pkg/auth/store.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,334 @@
 | 
			
		||||
package auth
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/google/uuid"
 | 
			
		||||
	jsonIterator "github.com/json-iterator/go"
 | 
			
		||||
	"github.com/redis/go-redis/v9"
 | 
			
		||||
	"github.com/tidwall/buntdb"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	jsonMarshal   = jsonIterator.Marshal
 | 
			
		||||
	jsonUnmarshal = jsonIterator.Unmarshal
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type TokenStore interface {
 | 
			
		||||
	Create(info TokenInfo) error
 | 
			
		||||
	RemoveByAccess(access string) error
 | 
			
		||||
	RemoveByRefresh(refresh string) error
 | 
			
		||||
	GetByAccess(access string) (TokenInfo, error)
 | 
			
		||||
	GetByRefresh(refresh string) (TokenInfo, error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewMemoryTokenStore create a token buntStore instance based on memory
 | 
			
		||||
func NewMemoryTokenStore() (TokenStore, error) {
 | 
			
		||||
	return NewFileTokenStore(":memory:")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewFileTokenStore create a token buntStore instance based on file
 | 
			
		||||
func NewFileTokenStore(filename string) (TokenStore, error) {
 | 
			
		||||
	db, err := buntdb.Open(filename)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return &buntStore{db: db}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// buntStore token storage based on buntdb(https://github.com/tidwall/buntdb)
 | 
			
		||||
type buntStore struct {
 | 
			
		||||
	db *buntdb.DB
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ts *buntStore) remove(key string) error {
 | 
			
		||||
	err := ts.db.Update(func(tx *buntdb.Tx) error {
 | 
			
		||||
		_, err := tx.Delete(key)
 | 
			
		||||
		return err
 | 
			
		||||
	})
 | 
			
		||||
	if errors.Is(err, buntdb.ErrNotFound) {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ts *buntStore) getData(key string) (TokenInfo, error) {
 | 
			
		||||
	var ti TokenInfo
 | 
			
		||||
	err := ts.db.View(func(tx *buntdb.Tx) error {
 | 
			
		||||
		jv, err := tx.Get(key)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var tm Token
 | 
			
		||||
		err = jsonUnmarshal([]byte(jv), &tm)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		ti = &tm
 | 
			
		||||
		return nil
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if err == buntdb.ErrNotFound {
 | 
			
		||||
			return nil, nil
 | 
			
		||||
		}
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return ti, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ts *buntStore) getBasicID(key string) (string, error) {
 | 
			
		||||
	var basicID string
 | 
			
		||||
	err := ts.db.View(func(tx *buntdb.Tx) error {
 | 
			
		||||
		v, err := tx.Get(key)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		basicID = v
 | 
			
		||||
		return nil
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if err == buntdb.ErrNotFound {
 | 
			
		||||
			return "", nil
 | 
			
		||||
		}
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
	return basicID, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Create and buntStore the new token information
 | 
			
		||||
func (ts *buntStore) Create(info TokenInfo) error {
 | 
			
		||||
	ct := time.Now()
 | 
			
		||||
	jv, err := jsonMarshal(info)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return ts.db.Update(func(tx *buntdb.Tx) error {
 | 
			
		||||
		basicID := uuid.Must(uuid.NewRandom()).String()
 | 
			
		||||
		aexp := info.GetAccessExpiresIn()
 | 
			
		||||
		rexp := aexp
 | 
			
		||||
		expires := true
 | 
			
		||||
		if refresh := info.GetRefresh(); refresh != "" {
 | 
			
		||||
			rexp = info.GetRefreshCreateAt().Add(info.GetRefreshExpiresIn()).Sub(ct)
 | 
			
		||||
			if aexp.Seconds() > rexp.Seconds() {
 | 
			
		||||
				aexp = rexp
 | 
			
		||||
			}
 | 
			
		||||
			expires = info.GetRefreshExpiresIn() != 0
 | 
			
		||||
			_, _, err = tx.Set(refresh, basicID, &buntdb.SetOptions{Expires: expires, TTL: rexp})
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		_, _, err = tx.Set(basicID, string(jv), &buntdb.SetOptions{Expires: expires, TTL: rexp})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		_, _, err = tx.Set(info.GetAccess(), basicID, &buntdb.SetOptions{Expires: expires, TTL: aexp})
 | 
			
		||||
		return err
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RemoveByAccess use the access token to delete the token information
 | 
			
		||||
func (ts *buntStore) RemoveByAccess(access string) error {
 | 
			
		||||
	return ts.remove(access)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RemoveByRefresh use the refresh token to delete the token information
 | 
			
		||||
func (ts *buntStore) RemoveByRefresh(refresh string) error {
 | 
			
		||||
	return ts.remove(refresh)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetByAccess use the access token for token information data
 | 
			
		||||
func (ts *buntStore) GetByAccess(access string) (TokenInfo, error) {
 | 
			
		||||
	basicID, err := ts.getBasicID(access)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return ts.getData(basicID)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetByRefresh use the refresh token for token information data
 | 
			
		||||
func (ts *buntStore) GetByRefresh(refresh string) (TokenInfo, error) {
 | 
			
		||||
	basicID, err := ts.getBasicID(refresh)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return ts.getData(basicID)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/*------------------------------------------------------------------------------------*/
 | 
			
		||||
 | 
			
		||||
// NewRedisStoreWithCli create an instance of a redis store
 | 
			
		||||
func NewRedisStoreWithCli(cli *redis.Client, keyNamespace string) TokenStore {
 | 
			
		||||
	store := &redisStore{
 | 
			
		||||
		cli: cli,
 | 
			
		||||
		ctx: context.TODO(),
 | 
			
		||||
		ns:  keyNamespace,
 | 
			
		||||
	}
 | 
			
		||||
	return store
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TokenStore redis token store
 | 
			
		||||
type redisStore struct {
 | 
			
		||||
	cli *redis.Client
 | 
			
		||||
	ctx context.Context
 | 
			
		||||
	ns  string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *redisStore) wrapperKey(key string) string {
 | 
			
		||||
	return fmt.Sprintf("%s%s", s.ns, key)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *redisStore) checkError(result redis.Cmder) (bool, error) {
 | 
			
		||||
	if err := result.Err(); err != nil {
 | 
			
		||||
		if err == redis.Nil {
 | 
			
		||||
			return true, nil
 | 
			
		||||
		}
 | 
			
		||||
		return false, err
 | 
			
		||||
	}
 | 
			
		||||
	return false, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *redisStore) remove(key string) error {
 | 
			
		||||
	result := s.cli.Del(s.ctx, s.wrapperKey(key))
 | 
			
		||||
	_, err := s.checkError(result)
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *redisStore) removeToken(tokenString string, isRefresh bool) error {
 | 
			
		||||
	basicID, err := s.getBasicID(tokenString)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	} else if basicID == "" {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = s.remove(tokenString)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	token, err := s.getToken(basicID)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	} else if token == nil {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	checkToken := token.GetRefresh()
 | 
			
		||||
	if isRefresh {
 | 
			
		||||
		checkToken = token.GetAccess()
 | 
			
		||||
	}
 | 
			
		||||
	result := s.cli.Exists(s.ctx, s.wrapperKey(checkToken))
 | 
			
		||||
	if err = result.Err(); err != nil && err != redis.Nil {
 | 
			
		||||
		return err
 | 
			
		||||
	} else if result.Val() == 0 {
 | 
			
		||||
		return s.remove(basicID)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *redisStore) parseToken(result *redis.StringCmd) (TokenInfo, error) {
 | 
			
		||||
	if ok, err := s.checkError(result); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	} else if ok {
 | 
			
		||||
		return nil, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	buf, err := result.Bytes()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if err == redis.Nil {
 | 
			
		||||
			return nil, nil
 | 
			
		||||
		}
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var token Token
 | 
			
		||||
	if err = jsonUnmarshal(buf, &token); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return &token, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *redisStore) getToken(key string) (TokenInfo, error) {
 | 
			
		||||
	result := s.cli.Get(s.ctx, s.wrapperKey(key))
 | 
			
		||||
	return s.parseToken(result)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *redisStore) parseBasicID(result *redis.StringCmd) (string, error) {
 | 
			
		||||
	if ok, err := s.checkError(result); err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	} else if ok {
 | 
			
		||||
		return "", nil
 | 
			
		||||
	}
 | 
			
		||||
	return result.Val(), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *redisStore) getBasicID(token string) (string, error) {
 | 
			
		||||
	result := s.cli.Get(s.ctx, s.wrapperKey(token))
 | 
			
		||||
	return s.parseBasicID(result)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Create and store the new token information
 | 
			
		||||
func (s *redisStore) Create(info TokenInfo) error {
 | 
			
		||||
	ct := time.Now()
 | 
			
		||||
	jv, err := jsonMarshal(info)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	pipe := s.cli.TxPipeline()
 | 
			
		||||
	basicID := uuid.Must(uuid.NewRandom()).String()
 | 
			
		||||
	aexp := info.GetAccessExpiresIn()
 | 
			
		||||
	rexp := aexp
 | 
			
		||||
 | 
			
		||||
	if refresh := info.GetRefresh(); refresh != "" {
 | 
			
		||||
		rexp = info.GetRefreshCreateAt().Add(info.GetRefreshExpiresIn()).Sub(ct)
 | 
			
		||||
		if aexp.Seconds() > rexp.Seconds() {
 | 
			
		||||
			aexp = rexp
 | 
			
		||||
		}
 | 
			
		||||
		pipe.Set(s.ctx, s.wrapperKey(refresh), basicID, rexp)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	pipe.Set(s.ctx, s.wrapperKey(info.GetAccess()), basicID, aexp)
 | 
			
		||||
	pipe.Set(s.ctx, s.wrapperKey(basicID), jv, rexp)
 | 
			
		||||
 | 
			
		||||
	if _, err = pipe.Exec(s.ctx); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RemoveByAccess Use the access token to delete the token information
 | 
			
		||||
func (s *redisStore) RemoveByAccess(access string) error {
 | 
			
		||||
	return s.removeToken(access, false)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RemoveByRefresh Use the refresh token to delete the token information
 | 
			
		||||
func (s *redisStore) RemoveByRefresh(refresh string) error {
 | 
			
		||||
	return s.removeToken(refresh, true)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetByAccess Use the access token for token information data
 | 
			
		||||
func (s *redisStore) GetByAccess(access string) (TokenInfo, error) {
 | 
			
		||||
	basicID, err := s.getBasicID(access)
 | 
			
		||||
	if err != nil || basicID == "" {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return s.getToken(basicID)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetByRefresh Use the refresh token for token information data
 | 
			
		||||
func (s *redisStore) GetByRefresh(refresh string) (TokenInfo, error) {
 | 
			
		||||
	basicID, err := s.getBasicID(refresh)
 | 
			
		||||
	if err != nil || basicID == "" {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return s.getToken(basicID)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										118
									
								
								pkg/auth/token.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										118
									
								
								pkg/auth/token.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,118 @@
 | 
			
		||||
package auth
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// TokenInfo the token information model interface
 | 
			
		||||
type TokenInfo interface {
 | 
			
		||||
	New() TokenInfo
 | 
			
		||||
 | 
			
		||||
	GetUserID() string
 | 
			
		||||
	SetUserID(string)
 | 
			
		||||
 | 
			
		||||
	GetAccess() string
 | 
			
		||||
	SetAccess(string)
 | 
			
		||||
	GetAccessCreateAt() time.Time
 | 
			
		||||
	SetAccessCreateAt(time.Time)
 | 
			
		||||
	GetAccessExpiresIn() time.Duration
 | 
			
		||||
	SetAccessExpiresIn(time.Duration)
 | 
			
		||||
 | 
			
		||||
	GetRefresh() string
 | 
			
		||||
	SetRefresh(string)
 | 
			
		||||
	GetRefreshCreateAt() time.Time
 | 
			
		||||
	SetRefreshCreateAt(time.Time)
 | 
			
		||||
	GetRefreshExpiresIn() time.Duration
 | 
			
		||||
	SetRefreshExpiresIn(time.Duration)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewToken create to token model instance
 | 
			
		||||
func NewToken() *Token {
 | 
			
		||||
	return &Token{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Token token model
 | 
			
		||||
type Token struct {
 | 
			
		||||
	UserID           string        `bson:"UserID"`
 | 
			
		||||
	Access           string        `bson:"Access"`
 | 
			
		||||
	AccessCreateAt   time.Time     `bson:"AccessCreateAt"`
 | 
			
		||||
	AccessExpiresIn  time.Duration `bson:"AccessExpiresIn"`
 | 
			
		||||
	Refresh          string        `bson:"Refresh"`
 | 
			
		||||
	RefreshCreateAt  time.Time     `bson:"RefreshCreateAt"`
 | 
			
		||||
	RefreshExpiresIn time.Duration `bson:"RefreshExpiresIn"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// New create to token model instance
 | 
			
		||||
func (t *Token) New() TokenInfo {
 | 
			
		||||
	return NewToken()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetUserID the user id
 | 
			
		||||
func (t *Token) GetUserID() string {
 | 
			
		||||
	return t.UserID
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SetUserID the user id
 | 
			
		||||
func (t *Token) SetUserID(userID string) {
 | 
			
		||||
	t.UserID = userID
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetAccess access Token
 | 
			
		||||
func (t *Token) GetAccess() string {
 | 
			
		||||
	return t.Access
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SetAccess access Token
 | 
			
		||||
func (t *Token) SetAccess(access string) {
 | 
			
		||||
	t.Access = access
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetAccessCreateAt create Time
 | 
			
		||||
func (t *Token) GetAccessCreateAt() time.Time {
 | 
			
		||||
	return t.AccessCreateAt
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SetAccessCreateAt create Time
 | 
			
		||||
func (t *Token) SetAccessCreateAt(createAt time.Time) {
 | 
			
		||||
	t.AccessCreateAt = createAt
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetAccessExpiresIn the lifetime in seconds of the access token
 | 
			
		||||
func (t *Token) GetAccessExpiresIn() time.Duration {
 | 
			
		||||
	return t.AccessExpiresIn
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SetAccessExpiresIn the lifetime in seconds of the access token
 | 
			
		||||
func (t *Token) SetAccessExpiresIn(exp time.Duration) {
 | 
			
		||||
	t.AccessExpiresIn = exp
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetRefresh refresh Token
 | 
			
		||||
func (t *Token) GetRefresh() string {
 | 
			
		||||
	return t.Refresh
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SetRefresh refresh Token
 | 
			
		||||
func (t *Token) SetRefresh(refresh string) {
 | 
			
		||||
	t.Refresh = refresh
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetRefreshCreateAt create Time
 | 
			
		||||
func (t *Token) GetRefreshCreateAt() time.Time {
 | 
			
		||||
	return t.RefreshCreateAt
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SetRefreshCreateAt create Time
 | 
			
		||||
func (t *Token) SetRefreshCreateAt(createAt time.Time) {
 | 
			
		||||
	t.RefreshCreateAt = createAt
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetRefreshExpiresIn the lifetime in seconds of the refresh token
 | 
			
		||||
func (t *Token) GetRefreshExpiresIn() time.Duration {
 | 
			
		||||
	return t.RefreshExpiresIn
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SetRefreshExpiresIn the lifetime in seconds of the refresh token
 | 
			
		||||
func (t *Token) SetRefreshExpiresIn(exp time.Duration) {
 | 
			
		||||
	t.RefreshExpiresIn = exp
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										39
									
								
								pkg/bcrypt/password.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										39
									
								
								pkg/bcrypt/password.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,39 @@
 | 
			
		||||
package bcrypt
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"golang.org/x/crypto/bcrypt"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var _ Password = (*password)(nil)
 | 
			
		||||
 | 
			
		||||
type Password interface {
 | 
			
		||||
	i()
 | 
			
		||||
	Generate(pwd string) string
 | 
			
		||||
	Validate(pwd, hash string) bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type password struct {
 | 
			
		||||
	cost int
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *password) i() {}
 | 
			
		||||
 | 
			
		||||
func NewPassword(cost int) (Password, error) {
 | 
			
		||||
	if cost < bcrypt.MinCost || cost > bcrypt.MaxCost {
 | 
			
		||||
		return nil, fmt.Errorf("cost out of range")
 | 
			
		||||
	}
 | 
			
		||||
	return &password{
 | 
			
		||||
		cost: cost,
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *password) Generate(pwd string) string {
 | 
			
		||||
	hash, _ := bcrypt.GenerateFromPassword([]byte(pwd), p.cost)
 | 
			
		||||
	return string(hash)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *password) Validate(pwd, hash string) bool {
 | 
			
		||||
	err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(pwd))
 | 
			
		||||
	return err == nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										23
									
								
								pkg/browser/browser.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										23
									
								
								pkg/browser/browser.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,23 @@
 | 
			
		||||
package browser
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"os/exec"
 | 
			
		||||
	"runtime"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var commands = map[string]string{
 | 
			
		||||
	"windows": "start",
 | 
			
		||||
	"darwin":  "open",
 | 
			
		||||
	"linux":   "xdg-open",
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Open(uri string) error {
 | 
			
		||||
	run, ok := commands[runtime.GOOS]
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return fmt.Errorf("don't know how to open things on %s platform", runtime.GOOS)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	cmd := exec.Command(run, uri)
 | 
			
		||||
	return cmd.Start()
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										482
									
								
								pkg/cache/redis.go
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										482
									
								
								pkg/cache/redis.go
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@@ -0,0 +1,482 @@
 | 
			
		||||
package cache
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.bvbej.com/bvbej/base-golang/pkg/time_parse"
 | 
			
		||||
	"git.bvbej.com/bvbej/base-golang/pkg/trace"
 | 
			
		||||
	"github.com/redis/go-redis/v9"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Option func(*option)
 | 
			
		||||
 | 
			
		||||
type Trace = trace.T
 | 
			
		||||
 | 
			
		||||
type option struct {
 | 
			
		||||
	Trace *trace.Trace
 | 
			
		||||
	Redis *trace.Redis
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type RedisConfig struct {
 | 
			
		||||
	Addr        string `yaml:"addr"`
 | 
			
		||||
	Pass        string `yaml:"pass"`
 | 
			
		||||
	DB          int    `yaml:"db"`
 | 
			
		||||
	MaxRetries  int    `yaml:"maxRetries"`  // 最大重试次数
 | 
			
		||||
	PoolSize    int    `yaml:"poolSize"`    // Redis连接池大小
 | 
			
		||||
	MinIdleConn int    `yaml:"minIdleConn"` // 最小空闲连接数
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newOption() *option {
 | 
			
		||||
	return &option{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var _ Repo = (*cacheRepo)(nil)
 | 
			
		||||
 | 
			
		||||
type Repo interface {
 | 
			
		||||
	i()
 | 
			
		||||
	Client() *redis.Client
 | 
			
		||||
	Set(key, value string, ttl time.Duration, options ...Option) error
 | 
			
		||||
	Get(key string, options ...Option) (string, error)
 | 
			
		||||
	TTL(key string) (time.Duration, error)
 | 
			
		||||
	Expire(key string, ttl time.Duration) bool
 | 
			
		||||
	ExpireAt(key string, ttl time.Time) bool
 | 
			
		||||
	Del(key string, options ...Option) bool
 | 
			
		||||
	Exists(keys ...string) bool
 | 
			
		||||
	Incr(key string, options ...Option) (int64, error)
 | 
			
		||||
	Decr(key string, options ...Option) (int64, error)
 | 
			
		||||
	HGet(key, field string, options ...Option) (string, error)
 | 
			
		||||
	HSet(key, field, value string, options ...Option) error
 | 
			
		||||
	HDel(key, field string, options ...Option) error
 | 
			
		||||
	HGetAll(key string, options ...Option) (map[string]string, error)
 | 
			
		||||
	HIncrBy(key, field string, incr int64, options ...Option) (int64, error)
 | 
			
		||||
	HIncrByFloat(key, field string, incr float64, options ...Option) (float64, error)
 | 
			
		||||
	LPush(key, value string, options ...Option) error
 | 
			
		||||
	LLen(key string, options ...Option) (int64, error)
 | 
			
		||||
	BRPop(key string, timeout time.Duration, options ...Option) (string, error)
 | 
			
		||||
	Close() error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type cacheRepo struct {
 | 
			
		||||
	client *redis.Client
 | 
			
		||||
	ctx    context.Context
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func New(cfg RedisConfig) (Repo, error) {
 | 
			
		||||
	client := redis.NewClient(&redis.Options{
 | 
			
		||||
		Addr:         cfg.Addr,
 | 
			
		||||
		Password:     cfg.Pass,
 | 
			
		||||
		DB:           cfg.DB,
 | 
			
		||||
		MaxRetries:   cfg.MaxRetries,
 | 
			
		||||
		PoolSize:     cfg.PoolSize,
 | 
			
		||||
		MinIdleConns: cfg.MinIdleConn,
 | 
			
		||||
	})
 | 
			
		||||
	ctx := context.TODO()
 | 
			
		||||
	if err := client.Ping(ctx).Err(); err != nil {
 | 
			
		||||
		return nil, errors.Join(err, errors.New("ping redis err"))
 | 
			
		||||
	}
 | 
			
		||||
	return &cacheRepo{
 | 
			
		||||
		client: client,
 | 
			
		||||
		ctx:    ctx,
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func WithTrace(t Trace) Option {
 | 
			
		||||
	return func(opt *option) {
 | 
			
		||||
		if t != nil {
 | 
			
		||||
			opt.Trace = t.(*trace.Trace)
 | 
			
		||||
			opt.Redis = new(trace.Redis)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *cacheRepo) i() {}
 | 
			
		||||
 | 
			
		||||
func (c *cacheRepo) Client() *redis.Client {
 | 
			
		||||
	return c.client
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *cacheRepo) Set(key, value string, ttl time.Duration, options ...Option) error {
 | 
			
		||||
	ts := time.Now()
 | 
			
		||||
	opt := newOption()
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if opt.Trace != nil {
 | 
			
		||||
			opt.Redis.Timestamp = time_parse.CSTLayoutString()
 | 
			
		||||
			opt.Redis.Handle = "set"
 | 
			
		||||
			opt.Redis.Key = key
 | 
			
		||||
			opt.Redis.Value = value
 | 
			
		||||
			opt.Redis.TTL = ttl.Minutes()
 | 
			
		||||
			opt.Redis.CostSeconds = time.Since(ts).Seconds()
 | 
			
		||||
			opt.Trace.AppendRedis(opt.Redis)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	for _, f := range options {
 | 
			
		||||
		f(opt)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := c.client.Set(c.ctx, key, value, ttl).Err(); err != nil {
 | 
			
		||||
		return errors.Join(err, fmt.Errorf("redis set key: %s err", key))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *cacheRepo) Get(key string, options ...Option) (string, error) {
 | 
			
		||||
	ts := time.Now()
 | 
			
		||||
	opt := newOption()
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if opt.Trace != nil {
 | 
			
		||||
			opt.Redis.Timestamp = time_parse.CSTLayoutString()
 | 
			
		||||
			opt.Redis.Handle = "get"
 | 
			
		||||
			opt.Redis.Key = key
 | 
			
		||||
			opt.Redis.CostSeconds = time.Since(ts).Seconds()
 | 
			
		||||
			opt.Trace.AppendRedis(opt.Redis)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	for _, f := range options {
 | 
			
		||||
		f(opt)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	value, err := c.client.Get(c.ctx, key).Result()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", errors.Join(err, fmt.Errorf("redis get key: %s err", key))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return value, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *cacheRepo) TTL(key string) (time.Duration, error) {
 | 
			
		||||
	ttl, err := c.client.TTL(c.ctx, key).Result()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return -1, errors.Join(err, fmt.Errorf("redis get key: %s err", key))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return ttl, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *cacheRepo) Expire(key string, ttl time.Duration) bool {
 | 
			
		||||
	ok, _ := c.client.Expire(c.ctx, key, ttl).Result()
 | 
			
		||||
	return ok
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *cacheRepo) ExpireAt(key string, ttl time.Time) bool {
 | 
			
		||||
	ok, _ := c.client.ExpireAt(c.ctx, key, ttl).Result()
 | 
			
		||||
	return ok
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *cacheRepo) Exists(keys ...string) bool {
 | 
			
		||||
	if len(keys) == 0 {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	value, _ := c.client.Exists(c.ctx, keys...).Result()
 | 
			
		||||
	return value > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *cacheRepo) Del(key string, options ...Option) bool {
 | 
			
		||||
	ts := time.Now()
 | 
			
		||||
	opt := newOption()
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if opt.Trace != nil {
 | 
			
		||||
			opt.Redis.Timestamp = time_parse.CSTLayoutString()
 | 
			
		||||
			opt.Redis.Handle = "del"
 | 
			
		||||
			opt.Redis.Key = key
 | 
			
		||||
			opt.Redis.CostSeconds = time.Since(ts).Seconds()
 | 
			
		||||
			opt.Trace.AppendRedis(opt.Redis)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	for _, f := range options {
 | 
			
		||||
		f(opt)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if key == "" {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	value, _ := c.client.Del(c.ctx, key).Result()
 | 
			
		||||
	return value > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *cacheRepo) Incr(key string, options ...Option) (int64, error) {
 | 
			
		||||
	ts := time.Now()
 | 
			
		||||
	opt := newOption()
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if opt.Trace != nil {
 | 
			
		||||
			opt.Redis.Timestamp = time_parse.CSTLayoutString()
 | 
			
		||||
			opt.Redis.Handle = "incr"
 | 
			
		||||
			opt.Redis.Key = key
 | 
			
		||||
			opt.Redis.CostSeconds = time.Since(ts).Seconds()
 | 
			
		||||
			opt.Trace.AppendRedis(opt.Redis)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	for _, f := range options {
 | 
			
		||||
		f(opt)
 | 
			
		||||
	}
 | 
			
		||||
	value, err := c.client.Incr(c.ctx, key).Result()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, errors.Join(err, fmt.Errorf("redis incr key: %s err", key))
 | 
			
		||||
	}
 | 
			
		||||
	return value, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *cacheRepo) Decr(key string, options ...Option) (int64, error) {
 | 
			
		||||
	ts := time.Now()
 | 
			
		||||
	opt := newOption()
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if opt.Trace != nil {
 | 
			
		||||
			opt.Redis.Timestamp = time_parse.CSTLayoutString()
 | 
			
		||||
			opt.Redis.Handle = "decr"
 | 
			
		||||
			opt.Redis.Key = key
 | 
			
		||||
			opt.Redis.CostSeconds = time.Since(ts).Seconds()
 | 
			
		||||
			opt.Trace.AppendRedis(opt.Redis)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	for _, f := range options {
 | 
			
		||||
		f(opt)
 | 
			
		||||
	}
 | 
			
		||||
	value, err := c.client.Decr(c.ctx, key).Result()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, errors.Join(err, fmt.Errorf("redis decr key: %s err", key))
 | 
			
		||||
	}
 | 
			
		||||
	return value, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *cacheRepo) HGet(key, field string, options ...Option) (string, error) {
 | 
			
		||||
	ts := time.Now()
 | 
			
		||||
	opt := newOption()
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if opt.Trace != nil {
 | 
			
		||||
			opt.Redis.Timestamp = time_parse.CSTLayoutString()
 | 
			
		||||
			opt.Redis.Handle = "hash get"
 | 
			
		||||
			opt.Redis.Key = key
 | 
			
		||||
			opt.Redis.Value = field
 | 
			
		||||
			opt.Redis.CostSeconds = time.Since(ts).Seconds()
 | 
			
		||||
			opt.Trace.AppendRedis(opt.Redis)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	for _, f := range options {
 | 
			
		||||
		f(opt)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	value, err := c.client.HGet(c.ctx, key, field).Result()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", errors.Join(err, fmt.Errorf("redis hget key: %s field: %s err", key, field))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return value, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *cacheRepo) HSet(key, field, value string, options ...Option) error {
 | 
			
		||||
	ts := time.Now()
 | 
			
		||||
	opt := newOption()
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if opt.Trace != nil {
 | 
			
		||||
			opt.Redis.Timestamp = time_parse.CSTLayoutString()
 | 
			
		||||
			opt.Redis.Handle = "hash set"
 | 
			
		||||
			opt.Redis.Key = key
 | 
			
		||||
			opt.Redis.Value = field + "/" + value
 | 
			
		||||
			opt.Redis.CostSeconds = time.Since(ts).Seconds()
 | 
			
		||||
			opt.Trace.AppendRedis(opt.Redis)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	for _, f := range options {
 | 
			
		||||
		f(opt)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := c.client.HSet(c.ctx, key, field, value).Err(); err != nil {
 | 
			
		||||
		return errors.Join(err, fmt.Errorf("redis hset key: %s field: %s err", key, field))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *cacheRepo) HDel(key, field string, options ...Option) error {
 | 
			
		||||
	ts := time.Now()
 | 
			
		||||
	opt := newOption()
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if opt.Trace != nil {
 | 
			
		||||
			opt.Redis.Timestamp = time_parse.CSTLayoutString()
 | 
			
		||||
			opt.Redis.Handle = "hash del"
 | 
			
		||||
			opt.Redis.Key = key
 | 
			
		||||
			opt.Redis.Value = field
 | 
			
		||||
			opt.Redis.CostSeconds = time.Since(ts).Seconds()
 | 
			
		||||
			opt.Trace.AppendRedis(opt.Redis)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	for _, f := range options {
 | 
			
		||||
		f(opt)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := c.client.HDel(c.ctx, key, field).Err(); err != nil {
 | 
			
		||||
		return errors.Join(err, fmt.Errorf("redis hdel key: %s field: %s err", key, field))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *cacheRepo) HGetAll(key string, options ...Option) (map[string]string, error) {
 | 
			
		||||
	ts := time.Now()
 | 
			
		||||
	opt := newOption()
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if opt.Trace != nil {
 | 
			
		||||
			opt.Redis.Timestamp = time_parse.CSTLayoutString()
 | 
			
		||||
			opt.Redis.Handle = "hash get all"
 | 
			
		||||
			opt.Redis.Key = key
 | 
			
		||||
			opt.Redis.CostSeconds = time.Since(ts).Seconds()
 | 
			
		||||
			opt.Trace.AppendRedis(opt.Redis)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	for _, f := range options {
 | 
			
		||||
		f(opt)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	value, err := c.client.HGetAll(c.ctx, key).Result()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.Join(err, fmt.Errorf("redis hget all key: %s err", key))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return value, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *cacheRepo) HIncrBy(key, field string, incr int64, options ...Option) (int64, error) {
 | 
			
		||||
	ts := time.Now()
 | 
			
		||||
	opt := newOption()
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if opt.Trace != nil {
 | 
			
		||||
			opt.Redis.Timestamp = time_parse.CSTLayoutString()
 | 
			
		||||
			opt.Redis.Handle = "hash incr int64"
 | 
			
		||||
			opt.Redis.Key = key
 | 
			
		||||
			opt.Redis.Value = fmt.Sprintf("field:%s incr:%d", field, incr)
 | 
			
		||||
			opt.Redis.CostSeconds = time.Since(ts).Seconds()
 | 
			
		||||
			opt.Trace.AppendRedis(opt.Redis)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	for _, f := range options {
 | 
			
		||||
		f(opt)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	value, err := c.client.HIncrBy(c.ctx, key, field, incr).Result()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, errors.Join(err, fmt.Errorf("redis hash incr int64 key: %s err", key))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return value, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *cacheRepo) HIncrByFloat(key, field string, incr float64, options ...Option) (float64, error) {
 | 
			
		||||
	ts := time.Now()
 | 
			
		||||
	opt := newOption()
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if opt.Trace != nil {
 | 
			
		||||
			opt.Redis.Timestamp = time_parse.CSTLayoutString()
 | 
			
		||||
			opt.Redis.Handle = "hash incr float64"
 | 
			
		||||
			opt.Redis.Key = key
 | 
			
		||||
			opt.Redis.Value = fmt.Sprintf("field:%s incr:%d", field, incr)
 | 
			
		||||
			opt.Redis.CostSeconds = time.Since(ts).Seconds()
 | 
			
		||||
			opt.Trace.AppendRedis(opt.Redis)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	for _, f := range options {
 | 
			
		||||
		f(opt)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	value, err := c.client.HIncrByFloat(c.ctx, key, field, incr).Result()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, errors.Join(err, fmt.Errorf("redis hash incr float64 key: %s err", key))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return value, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *cacheRepo) LPush(key, value string, options ...Option) error {
 | 
			
		||||
	ts := time.Now()
 | 
			
		||||
	opt := newOption()
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if opt.Trace != nil {
 | 
			
		||||
			opt.Redis.Timestamp = time_parse.CSTLayoutString()
 | 
			
		||||
			opt.Redis.Handle = "list push"
 | 
			
		||||
			opt.Redis.Key = key
 | 
			
		||||
			opt.Redis.Value = value
 | 
			
		||||
			opt.Redis.CostSeconds = time.Since(ts).Seconds()
 | 
			
		||||
			opt.Trace.AppendRedis(opt.Redis)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	for _, f := range options {
 | 
			
		||||
		f(opt)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	_, err := c.client.LPush(c.ctx, key, value).Result()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errors.Join(err, fmt.Errorf("redis list push key: %s value: %s err", key, value))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *cacheRepo) LLen(key string, options ...Option) (int64, error) {
 | 
			
		||||
	ts := time.Now()
 | 
			
		||||
	opt := newOption()
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if opt.Trace != nil {
 | 
			
		||||
			opt.Redis.Timestamp = time_parse.CSTLayoutString()
 | 
			
		||||
			opt.Redis.Handle = "list len"
 | 
			
		||||
			opt.Redis.Key = key
 | 
			
		||||
			opt.Redis.CostSeconds = time.Since(ts).Seconds()
 | 
			
		||||
			opt.Trace.AppendRedis(opt.Redis)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	for _, f := range options {
 | 
			
		||||
		f(opt)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	value, err := c.client.LLen(c.ctx, key).Result()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, errors.Join(err, fmt.Errorf("redis list len key: %s err", key))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return value, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *cacheRepo) BRPop(key string, timeout time.Duration, options ...Option) (string, error) {
 | 
			
		||||
	ts := time.Now()
 | 
			
		||||
	opt := newOption()
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if opt.Trace != nil {
 | 
			
		||||
			opt.Redis.Timestamp = time_parse.CSTLayoutString()
 | 
			
		||||
			opt.Redis.Handle = "list brpop"
 | 
			
		||||
			opt.Redis.Key = key
 | 
			
		||||
			opt.Redis.TTL = timeout.Seconds()
 | 
			
		||||
			opt.Redis.CostSeconds = time.Since(ts).Seconds()
 | 
			
		||||
			opt.Trace.AppendRedis(opt.Redis)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	for _, f := range options {
 | 
			
		||||
		f(opt)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	value, err := c.client.BRPop(c.ctx, timeout, key).Result()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", errors.Join(err, fmt.Errorf("redis list len key: %s err", key))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return value[1], nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *cacheRepo) Close() error {
 | 
			
		||||
	return c.client.Close()
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										97
									
								
								pkg/captcha/base64.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										97
									
								
								pkg/captcha/base64.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,97 @@
 | 
			
		||||
package captcha
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"git.bvbej.com/bvbej/base-golang/pkg/cache"
 | 
			
		||||
	"github.com/mojocn/base64Captcha"
 | 
			
		||||
	"go.uber.org/zap"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var _ base64Captcha.Store = (*store)(nil)
 | 
			
		||||
 | 
			
		||||
type store struct {
 | 
			
		||||
	cache  cache.Repo
 | 
			
		||||
	ttl    time.Duration
 | 
			
		||||
	logger *zap.Logger
 | 
			
		||||
	ns     string
 | 
			
		||||
	prefix string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *store) Set(id string, value string) error {
 | 
			
		||||
	err := s.cache.Set(fmt.Sprintf("%s%s%s", s.ns, s.prefix, id), value, s.ttl)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *store) Get(id string, clear bool) string {
 | 
			
		||||
	value, err := s.cache.Get(fmt.Sprintf("%s%s%s", s.ns, s.prefix, id))
 | 
			
		||||
	if err == nil && clear {
 | 
			
		||||
		s.cache.Del(fmt.Sprintf("%s%s%s", s.ns, s.prefix, id))
 | 
			
		||||
	}
 | 
			
		||||
	return value
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *store) Verify(id, answer string, clear bool) bool {
 | 
			
		||||
	value := s.Get(id, clear)
 | 
			
		||||
	if value == "" || answer == "" {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	return strings.ToLower(value) == strings.ToLower(answer)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewStore(cache cache.Repo, ttl time.Duration, namespace string) base64Captcha.Store {
 | 
			
		||||
	return &store{
 | 
			
		||||
		cache:  cache,
 | 
			
		||||
		ttl:    ttl,
 | 
			
		||||
		ns:     namespace,
 | 
			
		||||
		prefix: "captcha:base64:",
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var _ Captcha = (*captcha)(nil)
 | 
			
		||||
 | 
			
		||||
type Captcha interface {
 | 
			
		||||
	Generate() (id, b64s, answer string, err error)
 | 
			
		||||
	Verify(id, value string) bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type captcha struct {
 | 
			
		||||
	driver base64Captcha.Driver
 | 
			
		||||
	store  base64Captcha.Store
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewStringCaptcha(store base64Captcha.Store, height, width, length int) Captcha {
 | 
			
		||||
	conf := &base64Captcha.DriverString{
 | 
			
		||||
		Height:          height,
 | 
			
		||||
		Width:           width,
 | 
			
		||||
		NoiseCount:      length,
 | 
			
		||||
		ShowLineOptions: base64Captcha.OptionShowHollowLine,
 | 
			
		||||
		Length:          length,
 | 
			
		||||
		Source:          "ABCDEFGHIJKMNPQRSTUVWXYZabcdefghijkmnpqrstuvwxyz0123456789",
 | 
			
		||||
	}
 | 
			
		||||
	return &captcha{
 | 
			
		||||
		driver: conf.ConvertFonts(),
 | 
			
		||||
		store:  store,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewDigitCaptcha(store base64Captcha.Store, height, width, length int) Captcha {
 | 
			
		||||
	conf := base64Captcha.NewDriverDigit(height, width, length, 0.7, height)
 | 
			
		||||
	return &captcha{
 | 
			
		||||
		driver: conf,
 | 
			
		||||
		store:  store,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *captcha) Generate() (id, b64s, answer string, err error) {
 | 
			
		||||
	newCaptcha := base64Captcha.NewCaptcha(c.driver, c.store)
 | 
			
		||||
	return newCaptcha.Generate()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *captcha) Verify(id, value string) bool {
 | 
			
		||||
	return c.store.Verify(id, value, true)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										119
									
								
								pkg/cidr/calc.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										119
									
								
								pkg/cidr/calc.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,119 @@
 | 
			
		||||
package cidr
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"math"
 | 
			
		||||
	"net"
 | 
			
		||||
	"sort"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// SuperNetting 合并网段
 | 
			
		||||
func SuperNetting(ns []string) (*cidr, error) {
 | 
			
		||||
	num := len(ns)
 | 
			
		||||
	if num < 1 || (num&(num-1)) != 0 {
 | 
			
		||||
		return nil, fmt.Errorf("子网数量必须是2的次方")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	mask := ""
 | 
			
		||||
	var cidrs []*cidr
 | 
			
		||||
	for _, n := range ns {
 | 
			
		||||
		// 检查子网CIDR有效性
 | 
			
		||||
		c, err := ParseCIDR(n)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, fmt.Errorf("网段%v格式错误", n)
 | 
			
		||||
		}
 | 
			
		||||
		cidrs = append(cidrs, c)
 | 
			
		||||
 | 
			
		||||
		// TODO 暂只考虑相同子网掩码的网段合并
 | 
			
		||||
		if len(mask) == 0 {
 | 
			
		||||
			mask = c.Mask()
 | 
			
		||||
		} else if c.Mask() != mask {
 | 
			
		||||
			return nil, fmt.Errorf("子网掩码不一致")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	AscSortCIDRs(cidrs)
 | 
			
		||||
 | 
			
		||||
	// 检查网段是否连续
 | 
			
		||||
	var network net.IP
 | 
			
		||||
	for _, c := range cidrs {
 | 
			
		||||
		if len(network) > 0 {
 | 
			
		||||
			if !network.Equal(c.ipNet.IP) {
 | 
			
		||||
				return nil, fmt.Errorf("必须是连续的网段")
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		network = net.ParseIP(c.Broadcast())
 | 
			
		||||
		IncrIP(network)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 子网掩码左移,得到共同的父网段
 | 
			
		||||
	c := cidrs[0]
 | 
			
		||||
	ones, bits := c.MaskSize()
 | 
			
		||||
	ones = ones - int(math.Log2(float64(num)))
 | 
			
		||||
	c.ipNet.Mask = net.CIDRMask(ones, bits)
 | 
			
		||||
	c.ipNet.IP.Mask(c.ipNet.Mask)
 | 
			
		||||
 | 
			
		||||
	return c, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IncrIP IP地址自增
 | 
			
		||||
func IncrIP(ip net.IP) {
 | 
			
		||||
	for i := len(ip) - 1; i >= 0; i-- {
 | 
			
		||||
		ip[i]++
 | 
			
		||||
		if ip[i] > 0 {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// DecrIP IP地址自减
 | 
			
		||||
func DecrIP(ip net.IP) {
 | 
			
		||||
	length := len(ip)
 | 
			
		||||
	for i := length - 1; i >= 0; i-- {
 | 
			
		||||
		ip[length-1]--
 | 
			
		||||
		if ip[length-1] < 0xFF {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
		for j := 1; j < length; j++ {
 | 
			
		||||
			ip[length-j-1]--
 | 
			
		||||
			if ip[length-j-1] < 0xFF {
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Compare 比较IP大小 a等于b,返回0; a大于b,返回+1; a小于b,返回-1
 | 
			
		||||
func Compare(a, b net.IP) int {
 | 
			
		||||
	return bytes.Compare(a, b)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// AscSortCIDRs 升序
 | 
			
		||||
func AscSortCIDRs(cs []*cidr) {
 | 
			
		||||
	sort.Slice(cs, func(i, j int) bool {
 | 
			
		||||
		if n := bytes.Compare(cs[i].ipNet.IP, cs[j].ipNet.IP); n != 0 {
 | 
			
		||||
			return n < 0
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if n := bytes.Compare(cs[i].ipNet.Mask, cs[j].ipNet.Mask); n != 0 {
 | 
			
		||||
			return n < 0
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return false
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// DescSortCIDRs 降序
 | 
			
		||||
func DescSortCIDRs(cs []*cidr) {
 | 
			
		||||
	sort.Slice(cs, func(i, j int) bool {
 | 
			
		||||
		if n := bytes.Compare(cs[i].ipNet.IP, cs[j].ipNet.IP); n != 0 {
 | 
			
		||||
			return n >= 0
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if n := bytes.Compare(cs[i].ipNet.Mask, cs[j].ipNet.Mask); n != 0 {
 | 
			
		||||
			return n >= 0
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return false
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										194
									
								
								pkg/cidr/ip.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										194
									
								
								pkg/cidr/ip.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,194 @@
 | 
			
		||||
package cidr
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/hex"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"math"
 | 
			
		||||
	"math/big"
 | 
			
		||||
	"net"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// 裂解子网的方式
 | 
			
		||||
const (
 | 
			
		||||
	MethodSubnetNum = 0 // 基于子网数量
 | 
			
		||||
	MethodHostNum   = 1 // 基于主机数量
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var _ CIDR = (*cidr)(nil)
 | 
			
		||||
 | 
			
		||||
type CIDR interface {
 | 
			
		||||
	CIDR() string
 | 
			
		||||
	IP() string
 | 
			
		||||
	Network() string
 | 
			
		||||
	Broadcast() string
 | 
			
		||||
	Mask() string
 | 
			
		||||
	MaskSize() (int, int)
 | 
			
		||||
	IPRange() (string, string)
 | 
			
		||||
	IPCount() *big.Int
 | 
			
		||||
 | 
			
		||||
	IsIPv4() bool
 | 
			
		||||
	IsIPv6() bool
 | 
			
		||||
 | 
			
		||||
	Equal(string) bool
 | 
			
		||||
	Contains(string) bool
 | 
			
		||||
	ForEachIP(func(string) error) error
 | 
			
		||||
	ForEachIPBeginWith(string, func(string) error) error
 | 
			
		||||
	SubNetting(method, num int) ([]*cidr, error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type cidr struct {
 | 
			
		||||
	ip    net.IP
 | 
			
		||||
	ipNet *net.IPNet
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ParseCIDR 解析CIDR网段
 | 
			
		||||
func ParseCIDR(s string) (*cidr, error) {
 | 
			
		||||
	i, n, err := net.ParseCIDR(s)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return &cidr{ip: i, ipNet: n}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Equal 判断网段是否相等
 | 
			
		||||
func (c *cidr) Equal(ns string) bool {
 | 
			
		||||
	c2, err := ParseCIDR(ns)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	return c.ipNet.IP.Equal(c2.ipNet.IP) /* && c.ipNet.IP.Equal(c2.ip) */
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IsIPv4 判断是否IPv4
 | 
			
		||||
func (c *cidr) IsIPv4() bool {
 | 
			
		||||
	_, bits := c.ipNet.Mask.Size()
 | 
			
		||||
	return bits/8 == net.IPv4len
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IsIPv6 判断是否IPv6
 | 
			
		||||
func (c *cidr) IsIPv6() bool {
 | 
			
		||||
	_, bits := c.ipNet.Mask.Size()
 | 
			
		||||
	return bits/8 == net.IPv6len
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Contains 判断IP是否包含在网段中
 | 
			
		||||
func (c *cidr) Contains(ip string) bool {
 | 
			
		||||
	return c.ipNet.Contains(net.ParseIP(ip))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CIDR 根据子网掩码长度校准后的CIDR
 | 
			
		||||
func (c *cidr) CIDR() string {
 | 
			
		||||
	return c.ipNet.String()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IP CIDR字符串中的IP部分
 | 
			
		||||
func (c *cidr) IP() string {
 | 
			
		||||
	return c.ip.String()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Network 网络号
 | 
			
		||||
func (c *cidr) Network() string {
 | 
			
		||||
	return c.ipNet.IP.String()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// MaskSize 子网掩码位数
 | 
			
		||||
func (c *cidr) MaskSize() (ones, bits int) {
 | 
			
		||||
	ones, bits = c.ipNet.Mask.Size()
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Mask 子网掩码
 | 
			
		||||
func (c *cidr) Mask() string {
 | 
			
		||||
	mask, _ := hex.DecodeString(c.ipNet.Mask.String())
 | 
			
		||||
	return net.IP([]byte(mask)).String()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Broadcast 广播地址(网段最后一个IP)
 | 
			
		||||
func (c *cidr) Broadcast() string {
 | 
			
		||||
	mask := c.ipNet.Mask
 | 
			
		||||
	bcst := make(net.IP, len(c.ipNet.IP))
 | 
			
		||||
	copy(bcst, c.ipNet.IP)
 | 
			
		||||
	for i := 0; i < len(mask); i++ {
 | 
			
		||||
		ipIdx := len(bcst) - i - 1
 | 
			
		||||
		bcst[ipIdx] = c.ipNet.IP[ipIdx] | ^mask[len(mask)-i-1]
 | 
			
		||||
	}
 | 
			
		||||
	return bcst.String()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IPRange 起始IP、结束IP
 | 
			
		||||
func (c *cidr) IPRange() (start, end string) {
 | 
			
		||||
	return c.Network(), c.Broadcast()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IPCount IP数量
 | 
			
		||||
func (c *cidr) IPCount() *big.Int {
 | 
			
		||||
	ones, bits := c.ipNet.Mask.Size()
 | 
			
		||||
	return big.NewInt(0).Lsh(big.NewInt(1), uint(bits-ones))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ForEachIP 遍历网段下所有IP
 | 
			
		||||
func (c *cidr) ForEachIP(iterator func(ip string) error) error {
 | 
			
		||||
	next := make(net.IP, len(c.ipNet.IP))
 | 
			
		||||
	copy(next, c.ipNet.IP)
 | 
			
		||||
	for c.ipNet.Contains(next) {
 | 
			
		||||
		if err := iterator(next.String()); err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		IncrIP(next)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ForEachIPBeginWith 从指定IP开始遍历网段下后续的IP
 | 
			
		||||
func (c *cidr) ForEachIPBeginWith(beginIP string, iterator func(ip string) error) error {
 | 
			
		||||
	next := net.ParseIP(beginIP)
 | 
			
		||||
	for c.ipNet.Contains(next) {
 | 
			
		||||
		if err := iterator(next.String()); err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		IncrIP(next)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SubNetting 裂解网段
 | 
			
		||||
func (c *cidr) SubNetting(method, num int) ([]*cidr, error) {
 | 
			
		||||
	if num < 1 || (num&(num-1)) != 0 {
 | 
			
		||||
		return nil, fmt.Errorf("裂解数量必须是2的次方")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	newOnes := int(math.Log2(float64(num)))
 | 
			
		||||
	ones, bits := c.MaskSize()
 | 
			
		||||
	switch method {
 | 
			
		||||
	default:
 | 
			
		||||
		return nil, fmt.Errorf("不支持的裂解方式")
 | 
			
		||||
	case MethodSubnetNum:
 | 
			
		||||
		newOnes = ones + newOnes
 | 
			
		||||
		// 如果子网的掩码长度大于父网段的长度,则无法裂解
 | 
			
		||||
		if newOnes > bits {
 | 
			
		||||
			return nil, nil
 | 
			
		||||
		}
 | 
			
		||||
	case MethodHostNum:
 | 
			
		||||
		newOnes = bits - newOnes
 | 
			
		||||
		// 如果子网的掩码长度小于等于父网段的掩码长度,则无法裂解
 | 
			
		||||
		if newOnes <= ones {
 | 
			
		||||
			return nil, nil
 | 
			
		||||
		}
 | 
			
		||||
		// 主机数量转换为子网数量
 | 
			
		||||
		num = int(math.Pow(float64(2), float64(newOnes-ones)))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var cidrs []*cidr
 | 
			
		||||
	network := make(net.IP, len(c.ipNet.IP))
 | 
			
		||||
	copy(network, c.ipNet.IP)
 | 
			
		||||
	for i := 0; i < num; i++ {
 | 
			
		||||
		cidr, _ := ParseCIDR(fmt.Sprintf("%v/%v", network.String(), newOnes))
 | 
			
		||||
		cidrs = append(cidrs, cidr)
 | 
			
		||||
 | 
			
		||||
		// 广播地址的下一个IP即为下一段的网络号
 | 
			
		||||
		network = net.ParseIP(cidr.Broadcast())
 | 
			
		||||
		IncrIP(network)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return cidrs, nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										360
									
								
								pkg/cmap/cmap.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										360
									
								
								pkg/cmap/cmap.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,360 @@
 | 
			
		||||
package cmap
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"sync"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var ShardCount = 32
 | 
			
		||||
 | 
			
		||||
type Stringer interface {
 | 
			
		||||
	fmt.Stringer
 | 
			
		||||
	comparable
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ConcurrentMap A "thread" safe map of type string:Anything.
 | 
			
		||||
// To avoid lock bottlenecks this map is dived to several (ShardCount) map shards.
 | 
			
		||||
type ConcurrentMap[K comparable, V any] struct {
 | 
			
		||||
	shards   []*ConcurrentMapShared[K, V]
 | 
			
		||||
	sharding func(key K) uint32
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ConcurrentMapShared A "thread" safe string to anything map.
 | 
			
		||||
type ConcurrentMapShared[K comparable, V any] struct {
 | 
			
		||||
	items        map[K]V
 | 
			
		||||
	sync.RWMutex // Read Write mutex, guards access to internal map.
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func create[K comparable, V any](sharding func(key K) uint32) ConcurrentMap[K, V] {
 | 
			
		||||
	m := ConcurrentMap[K, V]{
 | 
			
		||||
		sharding: sharding,
 | 
			
		||||
		shards:   make([]*ConcurrentMapShared[K, V], ShardCount),
 | 
			
		||||
	}
 | 
			
		||||
	for i := 0; i < ShardCount; i++ {
 | 
			
		||||
		m.shards[i] = &ConcurrentMapShared[K, V]{items: make(map[K]V)}
 | 
			
		||||
	}
 | 
			
		||||
	return m
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// New Creates a new concurrent map.
 | 
			
		||||
func New[V any]() ConcurrentMap[string, V] {
 | 
			
		||||
	return create[string, V](fnv32)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewStringer Creates a new concurrent map.
 | 
			
		||||
func NewStringer[K Stringer, V any]() ConcurrentMap[K, V] {
 | 
			
		||||
	return create[K, V](strfnv32[K])
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewWithCustomShardingFunction Creates a new concurrent map.
 | 
			
		||||
func NewWithCustomShardingFunction[K comparable, V any](sharding func(key K) uint32) ConcurrentMap[K, V] {
 | 
			
		||||
	return create[K, V](sharding)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetShard returns shard under given key
 | 
			
		||||
func (m ConcurrentMap[K, V]) GetShard(key K) *ConcurrentMapShared[K, V] {
 | 
			
		||||
	return m.shards[uint(m.sharding(key))%uint(ShardCount)]
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m ConcurrentMap[K, V]) MSet(data map[K]V) {
 | 
			
		||||
	for key, value := range data {
 | 
			
		||||
		shard := m.GetShard(key)
 | 
			
		||||
		shard.Lock()
 | 
			
		||||
		shard.items[key] = value
 | 
			
		||||
		shard.Unlock()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Set Sets the given value under the specified key.
 | 
			
		||||
func (m ConcurrentMap[K, V]) Set(key K, value V) {
 | 
			
		||||
	// Get map shard.
 | 
			
		||||
	shard := m.GetShard(key)
 | 
			
		||||
	shard.Lock()
 | 
			
		||||
	shard.items[key] = value
 | 
			
		||||
	shard.Unlock()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UpsertCb Callback to return new element to be inserted into the map
 | 
			
		||||
// It is called while lock is held, therefore it MUST NOT
 | 
			
		||||
// try to access other keys in same map, as it can lead to deadlock since
 | 
			
		||||
// Go sync.RWLock is not reentrant
 | 
			
		||||
type UpsertCb[V any] func(exist bool, valueInMap V, newValue V) V
 | 
			
		||||
 | 
			
		||||
// Upsert Insert or Update - updates existing element or inserts a new one using UpsertCb
 | 
			
		||||
func (m ConcurrentMap[K, V]) Upsert(key K, value V, cb UpsertCb[V]) (res V) {
 | 
			
		||||
	shard := m.GetShard(key)
 | 
			
		||||
	shard.Lock()
 | 
			
		||||
	v, ok := shard.items[key]
 | 
			
		||||
	res = cb(ok, v, value)
 | 
			
		||||
	shard.items[key] = res
 | 
			
		||||
	shard.Unlock()
 | 
			
		||||
	return res
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SetIfAbsent Sets the given value under the specified key if no value was associated with it.
 | 
			
		||||
func (m ConcurrentMap[K, V]) SetIfAbsent(key K, value V) bool {
 | 
			
		||||
	// Get map shard.
 | 
			
		||||
	shard := m.GetShard(key)
 | 
			
		||||
	shard.Lock()
 | 
			
		||||
	_, ok := shard.items[key]
 | 
			
		||||
	if !ok {
 | 
			
		||||
		shard.items[key] = value
 | 
			
		||||
	}
 | 
			
		||||
	shard.Unlock()
 | 
			
		||||
	return !ok
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Get retrieves an element from map under given key.
 | 
			
		||||
func (m ConcurrentMap[K, V]) Get(key K) (V, bool) {
 | 
			
		||||
	// Get shard
 | 
			
		||||
	shard := m.GetShard(key)
 | 
			
		||||
	shard.RLock()
 | 
			
		||||
	// Get item from shard.
 | 
			
		||||
	val, ok := shard.items[key]
 | 
			
		||||
	shard.RUnlock()
 | 
			
		||||
	return val, ok
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Count returns the number of elements within the map.
 | 
			
		||||
func (m ConcurrentMap[K, V]) Count() int {
 | 
			
		||||
	count := 0
 | 
			
		||||
	for i := 0; i < ShardCount; i++ {
 | 
			
		||||
		shard := m.shards[i]
 | 
			
		||||
		shard.RLock()
 | 
			
		||||
		count += len(shard.items)
 | 
			
		||||
		shard.RUnlock()
 | 
			
		||||
	}
 | 
			
		||||
	return count
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Has Looks up an item under specified key
 | 
			
		||||
func (m ConcurrentMap[K, V]) Has(key K) bool {
 | 
			
		||||
	// Get shard
 | 
			
		||||
	shard := m.GetShard(key)
 | 
			
		||||
	shard.RLock()
 | 
			
		||||
	// See if element is within shard.
 | 
			
		||||
	_, ok := shard.items[key]
 | 
			
		||||
	shard.RUnlock()
 | 
			
		||||
	return ok
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Remove removes an element from the map.
 | 
			
		||||
func (m ConcurrentMap[K, V]) Remove(key K) {
 | 
			
		||||
	// Try to get shard.
 | 
			
		||||
	shard := m.GetShard(key)
 | 
			
		||||
	shard.Lock()
 | 
			
		||||
	delete(shard.items, key)
 | 
			
		||||
	shard.Unlock()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RemoveCb is a callback executed in a map.RemoveCb() call, while Lock is held
 | 
			
		||||
// If returns true, the element will be removed from the map
 | 
			
		||||
type RemoveCb[K any, V any] func(key K, v V, exists bool) bool
 | 
			
		||||
 | 
			
		||||
// RemoveCb locks the shard containing the key, retrieves its current value and calls the callback with those params
 | 
			
		||||
// If callback returns true and element exists, it will remove it from the map
 | 
			
		||||
// Returns the value returned by the callback (even if element was not present in the map)
 | 
			
		||||
func (m ConcurrentMap[K, V]) RemoveCb(key K, cb RemoveCb[K, V]) bool {
 | 
			
		||||
	// Try to get shard.
 | 
			
		||||
	shard := m.GetShard(key)
 | 
			
		||||
	shard.Lock()
 | 
			
		||||
	v, ok := shard.items[key]
 | 
			
		||||
	remove := cb(key, v, ok)
 | 
			
		||||
	if remove && ok {
 | 
			
		||||
		delete(shard.items, key)
 | 
			
		||||
	}
 | 
			
		||||
	shard.Unlock()
 | 
			
		||||
	return remove
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Pop removes an element from the map and returns it
 | 
			
		||||
func (m ConcurrentMap[K, V]) Pop(key K) (v V, exists bool) {
 | 
			
		||||
	// Try to get shard.
 | 
			
		||||
	shard := m.GetShard(key)
 | 
			
		||||
	shard.Lock()
 | 
			
		||||
	v, exists = shard.items[key]
 | 
			
		||||
	delete(shard.items, key)
 | 
			
		||||
	shard.Unlock()
 | 
			
		||||
	return v, exists
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IsEmpty checks if map is empty.
 | 
			
		||||
func (m ConcurrentMap[K, V]) IsEmpty() bool {
 | 
			
		||||
	return m.Count() == 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Tuple Used by the Iter & IterBuffered functions to wrap two variables together over a channel,
 | 
			
		||||
type Tuple[K comparable, V any] struct {
 | 
			
		||||
	Key K
 | 
			
		||||
	Val V
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IterBuffered returns a buffered iterator which could be used in a for range loop.
 | 
			
		||||
func (m ConcurrentMap[K, V]) IterBuffered() <-chan Tuple[K, V] {
 | 
			
		||||
	chans := snapshot(m)
 | 
			
		||||
	total := 0
 | 
			
		||||
	for _, c := range chans {
 | 
			
		||||
		total += cap(c)
 | 
			
		||||
	}
 | 
			
		||||
	ch := make(chan Tuple[K, V], total)
 | 
			
		||||
	go fanIn(chans, ch)
 | 
			
		||||
	return ch
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Clear removes all items from map.
 | 
			
		||||
func (m ConcurrentMap[K, V]) Clear() {
 | 
			
		||||
	for item := range m.IterBuffered() {
 | 
			
		||||
		m.Remove(item.Key)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Returns an array of channels that contains elements in each shard,
 | 
			
		||||
// which likely takes a snapshot of `m`.
 | 
			
		||||
// It returns once the size of each buffered channel is determined,
 | 
			
		||||
// before all the channels are populated using goroutines.
 | 
			
		||||
func snapshot[K comparable, V any](m ConcurrentMap[K, V]) (chans []chan Tuple[K, V]) {
 | 
			
		||||
	//When you access map items before initializing.
 | 
			
		||||
	if len(m.shards) == 0 {
 | 
			
		||||
		panic(`cmap.ConcurrentMap is not initialized. Should run New() before usage.`)
 | 
			
		||||
	}
 | 
			
		||||
	chans = make([]chan Tuple[K, V], ShardCount)
 | 
			
		||||
	wg := sync.WaitGroup{}
 | 
			
		||||
	wg.Add(ShardCount)
 | 
			
		||||
	// Foreach shard.
 | 
			
		||||
	for index, shard := range m.shards {
 | 
			
		||||
		go func(index int, shard *ConcurrentMapShared[K, V]) {
 | 
			
		||||
			// Foreach key, value pair.
 | 
			
		||||
			shard.RLock()
 | 
			
		||||
			chans[index] = make(chan Tuple[K, V], len(shard.items))
 | 
			
		||||
			wg.Done()
 | 
			
		||||
			for key, val := range shard.items {
 | 
			
		||||
				chans[index] <- Tuple[K, V]{key, val}
 | 
			
		||||
			}
 | 
			
		||||
			shard.RUnlock()
 | 
			
		||||
			close(chans[index])
 | 
			
		||||
		}(index, shard)
 | 
			
		||||
	}
 | 
			
		||||
	wg.Wait()
 | 
			
		||||
	return chans
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// fanIn reads elements from channels `chans` into channel `out`
 | 
			
		||||
func fanIn[K comparable, V any](chans []chan Tuple[K, V], out chan Tuple[K, V]) {
 | 
			
		||||
	wg := sync.WaitGroup{}
 | 
			
		||||
	wg.Add(len(chans))
 | 
			
		||||
	for _, ch := range chans {
 | 
			
		||||
		go func(ch chan Tuple[K, V]) {
 | 
			
		||||
			for t := range ch {
 | 
			
		||||
				out <- t
 | 
			
		||||
			}
 | 
			
		||||
			wg.Done()
 | 
			
		||||
		}(ch)
 | 
			
		||||
	}
 | 
			
		||||
	wg.Wait()
 | 
			
		||||
	close(out)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Items returns all items as map[string]V
 | 
			
		||||
func (m ConcurrentMap[K, V]) Items() map[K]V {
 | 
			
		||||
	tmp := make(map[K]V)
 | 
			
		||||
 | 
			
		||||
	// Insert items to temporary map.
 | 
			
		||||
	for item := range m.IterBuffered() {
 | 
			
		||||
		tmp[item.Key] = item.Val
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return tmp
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IterCb Iterator callbacalled for every key,value found in
 | 
			
		||||
// maps. RLock is held for all calls for a given shard
 | 
			
		||||
// therefore callback sess consistent view of a shard,
 | 
			
		||||
// but not across the shards
 | 
			
		||||
type IterCb[K comparable, V any] func(key K, v V)
 | 
			
		||||
 | 
			
		||||
// IterCb Callback based iterator, cheapest way to read
 | 
			
		||||
// all elements in a map.
 | 
			
		||||
func (m ConcurrentMap[K, V]) IterCb(fn IterCb[K, V]) {
 | 
			
		||||
	for idx := range m.shards {
 | 
			
		||||
		shard := (m.shards)[idx]
 | 
			
		||||
		shard.RLock()
 | 
			
		||||
		for key, value := range shard.items {
 | 
			
		||||
			fn(key, value)
 | 
			
		||||
		}
 | 
			
		||||
		shard.RUnlock()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Keys returns all keys as []string
 | 
			
		||||
func (m ConcurrentMap[K, V]) Keys() []K {
 | 
			
		||||
	count := m.Count()
 | 
			
		||||
	ch := make(chan K, count)
 | 
			
		||||
	go func() {
 | 
			
		||||
		// Foreach shard.
 | 
			
		||||
		wg := sync.WaitGroup{}
 | 
			
		||||
		wg.Add(ShardCount)
 | 
			
		||||
		for _, shard := range m.shards {
 | 
			
		||||
			go func(shard *ConcurrentMapShared[K, V]) {
 | 
			
		||||
				// Foreach key, value pair.
 | 
			
		||||
				shard.RLock()
 | 
			
		||||
				for key := range shard.items {
 | 
			
		||||
					ch <- key
 | 
			
		||||
				}
 | 
			
		||||
				shard.RUnlock()
 | 
			
		||||
				wg.Done()
 | 
			
		||||
			}(shard)
 | 
			
		||||
		}
 | 
			
		||||
		wg.Wait()
 | 
			
		||||
		close(ch)
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	// Generate keys
 | 
			
		||||
	keys := make([]K, 0, count)
 | 
			
		||||
	for k := range ch {
 | 
			
		||||
		keys = append(keys, k)
 | 
			
		||||
	}
 | 
			
		||||
	return keys
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// MarshalJSON Reviles ConcurrentMap "private" variables to json marshal.
 | 
			
		||||
func (m ConcurrentMap[K, V]) MarshalJSON() ([]byte, error) {
 | 
			
		||||
	// Create a temporary map, which will hold all item spread across shards.
 | 
			
		||||
	tmp := make(map[K]V)
 | 
			
		||||
 | 
			
		||||
	// Insert items to temporary map.
 | 
			
		||||
	for item := range m.IterBuffered() {
 | 
			
		||||
		tmp[item.Key] = item.Val
 | 
			
		||||
	}
 | 
			
		||||
	return json.Marshal(tmp)
 | 
			
		||||
}
 | 
			
		||||
func strfnv32[K fmt.Stringer](key K) uint32 {
 | 
			
		||||
	return fnv32(key.String())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func fnv32(key string) uint32 {
 | 
			
		||||
	hash := uint32(2166136261)
 | 
			
		||||
	const prime32 = uint32(16777619)
 | 
			
		||||
	keyLength := len(key)
 | 
			
		||||
	for i := 0; i < keyLength; i++ {
 | 
			
		||||
		hash *= prime32
 | 
			
		||||
		hash ^= uint32(key[i])
 | 
			
		||||
	}
 | 
			
		||||
	return hash
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UnmarshalJSON Reverse process of Marshal.
 | 
			
		||||
func (m *ConcurrentMap[K, V]) UnmarshalJSON(b []byte) (err error) {
 | 
			
		||||
	tmp := make(map[K]V)
 | 
			
		||||
 | 
			
		||||
	// Unmarshal into a single map.
 | 
			
		||||
	if err := json.Unmarshal(b, &tmp); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// foreach key,value pair in temporary map insert into our concurrent map.
 | 
			
		||||
	for key, val := range tmp {
 | 
			
		||||
		m.Set(key, val)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										47
									
								
								pkg/color/string_darwin.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										47
									
								
								pkg/color/string_darwin.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,47 @@
 | 
			
		||||
//go:build darwin
 | 
			
		||||
// +build darwin
 | 
			
		||||
 | 
			
		||||
package color
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"math/rand"
 | 
			
		||||
	"strconv"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var _ = RandomColor()
 | 
			
		||||
 | 
			
		||||
// RandomColor generates a random color.
 | 
			
		||||
func RandomColor() string {
 | 
			
		||||
	return fmt.Sprintf("#%s", strconv.FormatInt(int64(rand.Intn(16777216)), 16))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Yellow ...
 | 
			
		||||
func Yellow(msg string) string {
 | 
			
		||||
	return fmt.Sprintf("\x1b[33m%s\x1b[0m", msg)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Red ...
 | 
			
		||||
func Red(msg string) string {
 | 
			
		||||
	return fmt.Sprintf("\x1b[31m%s\x1b[0m", msg)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Redf ...
 | 
			
		||||
func Redf(msg string, arg any) string {
 | 
			
		||||
	return fmt.Sprintf("\x1b[31m%s\x1b[0m %+v\n", msg, arg)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Blue ...
 | 
			
		||||
func Blue(msg string) string {
 | 
			
		||||
	return fmt.Sprintf("\x1b[34m%s\x1b[0m", msg)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Green ...
 | 
			
		||||
func Green(msg string) string {
 | 
			
		||||
	return fmt.Sprintf("\x1b[32m%s\x1b[0m", msg)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Greenf ...
 | 
			
		||||
func Greenf(msg string, arg any) string {
 | 
			
		||||
	return fmt.Sprintf("\x1b[32m%s\x1b[0m %+v\n", msg, arg)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										47
									
								
								pkg/color/string_linux.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										47
									
								
								pkg/color/string_linux.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,47 @@
 | 
			
		||||
//go:build linux
 | 
			
		||||
// +build linux
 | 
			
		||||
 | 
			
		||||
package color
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"math/rand"
 | 
			
		||||
	"strconv"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var _ = RandomColor()
 | 
			
		||||
 | 
			
		||||
// RandomColor generates a random color.
 | 
			
		||||
func RandomColor() string {
 | 
			
		||||
	return fmt.Sprintf("#%s", strconv.FormatInt(int64(rand.Intn(16777216)), 16))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Yellow ...
 | 
			
		||||
func Yellow(msg string) string {
 | 
			
		||||
	return fmt.Sprintf("\x1b[33m%s\x1b[0m", msg)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Red ...
 | 
			
		||||
func Red(msg string) string {
 | 
			
		||||
	return fmt.Sprintf("\x1b[31m%s\x1b[0m", msg)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Redf ...
 | 
			
		||||
func Redf(msg string, arg any) string {
 | 
			
		||||
	return fmt.Sprintf("\x1b[31m%s\x1b[0m %+v\n", msg, arg)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Blue ...
 | 
			
		||||
func Blue(msg string) string {
 | 
			
		||||
	return fmt.Sprintf("\x1b[34m%s\x1b[0m", msg)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Green ...
 | 
			
		||||
func Green(msg string) string {
 | 
			
		||||
	return fmt.Sprintf("\x1b[32m%s\x1b[0m", msg)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Greenf ...
 | 
			
		||||
func Greenf(msg string, arg any) string {
 | 
			
		||||
	return fmt.Sprintf("\x1b[32m%s\x1b[0m %+v\n", msg, arg)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										47
									
								
								pkg/color/string_windows.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										47
									
								
								pkg/color/string_windows.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,47 @@
 | 
			
		||||
//go:build windows
 | 
			
		||||
// +build windows
 | 
			
		||||
 | 
			
		||||
package color
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"math/rand"
 | 
			
		||||
	"strconv"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var _ = RandomColor()
 | 
			
		||||
 | 
			
		||||
// RandomColor generates a random color.
 | 
			
		||||
func RandomColor() string {
 | 
			
		||||
	return fmt.Sprintf("#%s", strconv.FormatInt(int64(rand.Intn(16777216)), 16))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Yellow ...
 | 
			
		||||
func Yellow(msg string) string {
 | 
			
		||||
	return fmt.Sprintf("\033[33m%s\033[0m", msg)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Red ...
 | 
			
		||||
func Red(msg string) string {
 | 
			
		||||
	return fmt.Sprintf("\033[31m%s\033[0m", msg)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Redf ...
 | 
			
		||||
func Redf(msg string, arg any) string {
 | 
			
		||||
	return fmt.Sprintf("\033[31m%s\033[0m %+v\n", msg, arg)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Blue ...
 | 
			
		||||
func Blue(msg string) string {
 | 
			
		||||
	return fmt.Sprintf("\033[34m%s\033[0m", msg)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Green ...
 | 
			
		||||
func Green(msg string) string {
 | 
			
		||||
	return fmt.Sprintf("\033[32m%s\033[0m", msg)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Greenf ...
 | 
			
		||||
func Greenf(msg string, arg any) string {
 | 
			
		||||
	return fmt.Sprintf("\033[32m%s\033[0m %+v\n", msg, arg)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										38
									
								
								pkg/compress/compress.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										38
									
								
								pkg/compress/compress.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,38 @@
 | 
			
		||||
package compress
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"compress/zlib"
 | 
			
		||||
	"io"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var _ Compress = (*compress)(nil)
 | 
			
		||||
 | 
			
		||||
type Compress interface {
 | 
			
		||||
	DoZlibCompress(src []byte) []byte
 | 
			
		||||
	DoZlibUnCompress(compressSrc []byte) []byte
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type compress struct{}
 | 
			
		||||
 | 
			
		||||
func New() Compress {
 | 
			
		||||
	return &compress{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// DoZlibCompress 进行zlib压缩
 | 
			
		||||
func (c *compress) DoZlibCompress(src []byte) []byte {
 | 
			
		||||
	var in bytes.Buffer
 | 
			
		||||
	w := zlib.NewWriter(&in)
 | 
			
		||||
	_, _ = w.Write(src)
 | 
			
		||||
	_ = w.Close()
 | 
			
		||||
	return in.Bytes()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// DoZlibUnCompress 进行zlib解压缩
 | 
			
		||||
func (c *compress) DoZlibUnCompress(compressSrc []byte) []byte {
 | 
			
		||||
	b := bytes.NewReader(compressSrc)
 | 
			
		||||
	var out bytes.Buffer
 | 
			
		||||
	r, _ := zlib.NewReader(b)
 | 
			
		||||
	_, _ = io.Copy(&out, r)
 | 
			
		||||
	return out.Bytes()
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										40
									
								
								pkg/crontab/crontab.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										40
									
								
								pkg/crontab/crontab.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,40 @@
 | 
			
		||||
package crontab
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/robfig/cron/v3"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var _ Crontab = (*crontab)(nil)
 | 
			
		||||
 | 
			
		||||
type Crontab interface {
 | 
			
		||||
	i()
 | 
			
		||||
	AddFunc(spec string, cmd func()) (entryID cron.EntryID, err error)
 | 
			
		||||
	Entries() []cron.Entry
 | 
			
		||||
	Stop()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type crontab struct {
 | 
			
		||||
	cron *cron.Cron
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func New() Crontab {
 | 
			
		||||
	return &crontab{
 | 
			
		||||
		cron: cron.New(),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *crontab) i() {}
 | 
			
		||||
 | 
			
		||||
func (c *crontab) AddFunc(spec string, cmd func()) (entryID cron.EntryID, err error) {
 | 
			
		||||
	entryID, err = c.cron.AddFunc(spec, cmd)
 | 
			
		||||
	c.cron.Start()
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *crontab) Stop() {
 | 
			
		||||
	c.cron.Stop()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *crontab) Entries() []cron.Entry {
 | 
			
		||||
	return c.cron.Entries()
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										77
									
								
								pkg/database/mongo.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										77
									
								
								pkg/database/mongo.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,77 @@
 | 
			
		||||
package database
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"go.mongodb.org/mongo-driver/mongo"
 | 
			
		||||
	"go.mongodb.org/mongo-driver/mongo/options"
 | 
			
		||||
	"go.mongodb.org/mongo-driver/mongo/readpref"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var _ MongoDB = (*mongoDB)(nil)
 | 
			
		||||
 | 
			
		||||
type MongoDB interface {
 | 
			
		||||
	i()
 | 
			
		||||
	GetDB() *mongo.Database
 | 
			
		||||
	Close() error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type MongoDBConfig struct {
 | 
			
		||||
	Addr    string        `yaml:"addr"`
 | 
			
		||||
	User    string        `yaml:"user"`
 | 
			
		||||
	Pass    string        `yaml:"pass"`
 | 
			
		||||
	Name    string        `yaml:"name"`
 | 
			
		||||
	Timeout time.Duration `yaml:"timeout"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type mongoDB struct {
 | 
			
		||||
	client  *mongo.Client
 | 
			
		||||
	db      *mongo.Database
 | 
			
		||||
	timeout time.Duration
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *mongoDB) i() {}
 | 
			
		||||
 | 
			
		||||
func NewMongoDB(cfg MongoDBConfig) (MongoDB, error) {
 | 
			
		||||
	timeout := cfg.Timeout * time.Second
 | 
			
		||||
	connectCtx, connectCancelFunc := context.WithTimeout(context.Background(), timeout)
 | 
			
		||||
	defer connectCancelFunc()
 | 
			
		||||
	var auth string
 | 
			
		||||
	if len(cfg.User) > 0 && len(cfg.Pass) > 0 {
 | 
			
		||||
		auth = fmt.Sprintf("%s:%s@", cfg.User, cfg.Pass)
 | 
			
		||||
	}
 | 
			
		||||
	client, err := mongo.Connect(connectCtx, options.Client().ApplyURI(
 | 
			
		||||
		fmt.Sprintf("mongodb://%s%s", auth, cfg.Addr),
 | 
			
		||||
	))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	pingCtx, pingCancelFunc := context.WithTimeout(context.Background(), timeout)
 | 
			
		||||
	defer pingCancelFunc()
 | 
			
		||||
	err = client.Ping(pingCtx, readpref.Primary())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &mongoDB{
 | 
			
		||||
		client:  client,
 | 
			
		||||
		db:      client.Database(cfg.Name),
 | 
			
		||||
		timeout: timeout,
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *mongoDB) GetDB() *mongo.Database {
 | 
			
		||||
	return m.db
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *mongoDB) Close() error {
 | 
			
		||||
	disconnectCtx, disconnectCancelFunc := context.WithTimeout(context.Background(), m.timeout)
 | 
			
		||||
	defer disconnectCancelFunc()
 | 
			
		||||
	err := m.client.Disconnect(disconnectCtx)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										247
									
								
								pkg/database/mysql.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										247
									
								
								pkg/database/mysql.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,247 @@
 | 
			
		||||
package database
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"log"
 | 
			
		||||
	"os"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.bvbej.com/bvbej/base-golang/pkg/time_parse"
 | 
			
		||||
	"git.bvbej.com/bvbej/base-golang/pkg/trace"
 | 
			
		||||
	"gorm.io/driver/mysql"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"gorm.io/gorm/logger"
 | 
			
		||||
	"gorm.io/gorm/schema"
 | 
			
		||||
	"gorm.io/gorm/utils"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	callBackBeforeName = "core:before"
 | 
			
		||||
	callBackAfterName  = "core:after"
 | 
			
		||||
	startTime          = "_start_time"
 | 
			
		||||
	traceCtxName       = "_trace_ctx_name"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var _ MysqlRepo = (*mysqlRepo)(nil)
 | 
			
		||||
 | 
			
		||||
type MysqlRepo interface {
 | 
			
		||||
	i()
 | 
			
		||||
	GetRead(options ...Option) *gorm.DB
 | 
			
		||||
	GetWrite(options ...Option) *gorm.DB
 | 
			
		||||
	Close() error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type MySQLConfig struct {
 | 
			
		||||
	Read struct {
 | 
			
		||||
		Addr string `yaml:"addr"`
 | 
			
		||||
		User string `yaml:"user"`
 | 
			
		||||
		Pass string `yaml:"pass"`
 | 
			
		||||
		Name string `yaml:"name"`
 | 
			
		||||
	} `yaml:"read"`
 | 
			
		||||
	Write struct {
 | 
			
		||||
		Addr string `yaml:"addr"`
 | 
			
		||||
		User string `yaml:"user"`
 | 
			
		||||
		Pass string `yaml:"pass"`
 | 
			
		||||
		Name string `yaml:"name"`
 | 
			
		||||
	} `yaml:"write"`
 | 
			
		||||
	Base struct {
 | 
			
		||||
		MaxOpenConn     int           `yaml:"maxOpenConn"`     //最大连接数
 | 
			
		||||
		MaxIdleConn     int           `yaml:"maxIdleConn"`     //最大空闲连接数
 | 
			
		||||
		ConnMaxLifeTime time.Duration `yaml:"connMaxLifeTime"` //最大连接超时(分钟)
 | 
			
		||||
	} `yaml:"base"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type mysqlRepo struct {
 | 
			
		||||
	read  *gorm.DB
 | 
			
		||||
	write *gorm.DB
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewMysql(cfg MySQLConfig) (MysqlRepo, error) {
 | 
			
		||||
	dbr, err := dbConnect(cfg.Read.User, cfg.Read.Pass, cfg.Read.Addr, cfg.Read.Name,
 | 
			
		||||
		cfg.Base.MaxOpenConn, cfg.Base.MaxIdleConn, cfg.Base.ConnMaxLifeTime)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	dbw, err := dbConnect(cfg.Write.User, cfg.Write.Pass, cfg.Write.Addr, cfg.Write.Name,
 | 
			
		||||
		cfg.Base.MaxOpenConn, cfg.Base.MaxIdleConn, cfg.Base.ConnMaxLifeTime)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &mysqlRepo{
 | 
			
		||||
		read:  dbr,
 | 
			
		||||
		write: dbw,
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *mysqlRepo) i() {}
 | 
			
		||||
 | 
			
		||||
func (d *mysqlRepo) GetRead(options ...Option) *gorm.DB {
 | 
			
		||||
	opt := newOption()
 | 
			
		||||
	for _, f := range options {
 | 
			
		||||
		f(opt)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	db := d.read
 | 
			
		||||
	if opt.Trace != nil {
 | 
			
		||||
		db.InstanceSet(traceCtxName, opt.Trace)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return db
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *mysqlRepo) GetWrite(options ...Option) *gorm.DB {
 | 
			
		||||
	opt := newOption()
 | 
			
		||||
	for _, f := range options {
 | 
			
		||||
		f(opt)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	db := d.write
 | 
			
		||||
	if opt.Trace != nil {
 | 
			
		||||
		db.InstanceSet(traceCtxName, opt.Trace)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return db
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *mysqlRepo) Close() (err error) {
 | 
			
		||||
	rdb, err1 := d.read.DB()
 | 
			
		||||
	if err1 != nil {
 | 
			
		||||
		err = errors.Join(err1)
 | 
			
		||||
	}
 | 
			
		||||
	err2 := rdb.Close()
 | 
			
		||||
	if err2 != nil {
 | 
			
		||||
		err = errors.Join(err2)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	wdb, err3 := d.write.DB()
 | 
			
		||||
	if err3 != nil {
 | 
			
		||||
		err = errors.Join(err3)
 | 
			
		||||
	}
 | 
			
		||||
	err4 := wdb.Close()
 | 
			
		||||
	if err4 != nil {
 | 
			
		||||
		err = errors.Join(err4)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func dbConnect(user, pass, addr, dbName string, maxOpenConn, maxIdleConn int, connMaxLifeTime time.Duration) (*gorm.DB, error) {
 | 
			
		||||
	dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=%t&loc=%s",
 | 
			
		||||
		user,
 | 
			
		||||
		pass,
 | 
			
		||||
		addr,
 | 
			
		||||
		dbName,
 | 
			
		||||
		true,
 | 
			
		||||
		"Local")
 | 
			
		||||
 | 
			
		||||
	// 日志配置
 | 
			
		||||
	newLogger := logger.New(
 | 
			
		||||
		log.New(os.Stdout, "\r\n", log.LstdFlags),
 | 
			
		||||
		logger.Config{
 | 
			
		||||
			SlowThreshold:             time.Second,  // 慢SQL阈值
 | 
			
		||||
			Colorful:                  true,         // 彩色打印
 | 
			
		||||
			IgnoreRecordNotFoundError: true,         // 忽略记录未找到错误
 | 
			
		||||
			LogLevel:                  logger.Error, // 日志级别
 | 
			
		||||
		},
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{
 | 
			
		||||
		NamingStrategy: schema.NamingStrategy{
 | 
			
		||||
			SingularTable: true,
 | 
			
		||||
		},
 | 
			
		||||
		Logger: newLogger,
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.Join(err, fmt.Errorf("[db connection failed] Database name: %s", dbName))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	db.Set("gorm:table_options", "CHARSET=utf8mb4")
 | 
			
		||||
 | 
			
		||||
	sqlDB, err := db.DB()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 设置连接池 用于设置最大打开的连接数,默认值为0表示不限制.设置最大的连接数,可以避免并发太高导致连接mysql出现too many connections的错误。
 | 
			
		||||
	sqlDB.SetMaxOpenConns(maxOpenConn)
 | 
			
		||||
 | 
			
		||||
	// 设置最大连接数 用于设置闲置的连接数.设置闲置的连接数则当开启的一个连接使用完成后可以放在池里等候下一次使用。
 | 
			
		||||
	sqlDB.SetMaxIdleConns(maxIdleConn)
 | 
			
		||||
 | 
			
		||||
	// 设置最大连接超时
 | 
			
		||||
	sqlDB.SetConnMaxLifetime(time.Minute * connMaxLifeTime)
 | 
			
		||||
 | 
			
		||||
	// 使用插件
 | 
			
		||||
	err = db.Use(&TracePlugin{})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return db, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/***************************************************************/
 | 
			
		||||
 | 
			
		||||
type TracePlugin struct{}
 | 
			
		||||
 | 
			
		||||
func (op *TracePlugin) Name() string {
 | 
			
		||||
	return "TracePlugin"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (op *TracePlugin) Initialize(db *gorm.DB) (err error) {
 | 
			
		||||
	// 开始前
 | 
			
		||||
	_ = db.Callback().Create().Before("gorm:before_create").Register(callBackBeforeName, before)
 | 
			
		||||
	_ = db.Callback().Query().Before("gorm:query").Register(callBackBeforeName, before)
 | 
			
		||||
	_ = db.Callback().Delete().Before("gorm:before_delete").Register(callBackBeforeName, before)
 | 
			
		||||
	_ = db.Callback().Update().Before("gorm:setup_reflect_value").Register(callBackBeforeName, before)
 | 
			
		||||
	_ = db.Callback().Row().Before("gorm:row").Register(callBackBeforeName, before)
 | 
			
		||||
	_ = db.Callback().Raw().Before("gorm:raw").Register(callBackBeforeName, before)
 | 
			
		||||
 | 
			
		||||
	// 结束后
 | 
			
		||||
	_ = db.Callback().Create().After("gorm:after_create").Register(callBackAfterName, after)
 | 
			
		||||
	_ = db.Callback().Query().After("gorm:after_query").Register(callBackAfterName, after)
 | 
			
		||||
	_ = db.Callback().Delete().After("gorm:after_delete").Register(callBackAfterName, after)
 | 
			
		||||
	_ = db.Callback().Update().After("gorm:after_update").Register(callBackAfterName, after)
 | 
			
		||||
	_ = db.Callback().Row().After("gorm:row").Register(callBackAfterName, after)
 | 
			
		||||
	_ = db.Callback().Raw().After("gorm:raw").Register(callBackAfterName, after)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func before(db *gorm.DB) {
 | 
			
		||||
	db.InstanceSet(startTime, time.Now())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func after(db *gorm.DB) {
 | 
			
		||||
	_traceCtx, isExist := db.InstanceGet(traceCtxName)
 | 
			
		||||
	if !isExist {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	_trace, ok := _traceCtx.(trace.T)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	_ts, isExist := db.InstanceGet(startTime)
 | 
			
		||||
	if !isExist {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ts, ok := _ts.(time.Time)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	sql := db.Dialector.Explain(db.Statement.SQL.String(), db.Statement.Vars...)
 | 
			
		||||
 | 
			
		||||
	sqlInfo := new(trace.SQL)
 | 
			
		||||
	sqlInfo.Timestamp = time_parse.CSTLayoutString()
 | 
			
		||||
	sqlInfo.SQL = sql
 | 
			
		||||
	sqlInfo.Stack = utils.FileWithLineNum()
 | 
			
		||||
	sqlInfo.Rows = db.Statement.RowsAffected
 | 
			
		||||
	sqlInfo.CostSeconds = time.Since(ts).Seconds()
 | 
			
		||||
	_trace.AppendSQL(sqlInfo)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										88
									
								
								pkg/database/tool.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										88
									
								
								pkg/database/tool.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,88 @@
 | 
			
		||||
package database
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"database/sql/driver"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"reflect"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type NullTime sql.NullTime
 | 
			
		||||
 | 
			
		||||
func (n *NullTime) Scan(value any) error {
 | 
			
		||||
	return (*sql.NullTime)(n).Scan(value)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (n NullTime) Value() (driver.Value, error) {
 | 
			
		||||
	if !n.Valid {
 | 
			
		||||
		return nil, nil
 | 
			
		||||
	}
 | 
			
		||||
	return n.Time, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (n NullTime) MarshalJSON() ([]byte, error) {
 | 
			
		||||
	if n.Valid {
 | 
			
		||||
		return json.Marshal(n.Time)
 | 
			
		||||
	}
 | 
			
		||||
	return json.Marshal(nil)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (n *NullTime) UnmarshalJSON(b []byte) error {
 | 
			
		||||
	if string(b) == "null" {
 | 
			
		||||
		n.Valid = false
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	err := json.Unmarshal(b, &n.Time)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		n.Valid = true
 | 
			
		||||
	}
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/*-----------------------------------------------------------*/
 | 
			
		||||
 | 
			
		||||
type PaginateList struct {
 | 
			
		||||
	Page  int64 `json:"page"`
 | 
			
		||||
	Size  int64 `json:"size"`
 | 
			
		||||
	Total int64 `json:"total"`
 | 
			
		||||
	List  any   `json:"list"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Paginate(db *gorm.DB, model any, page, size int64) (*PaginateList, error) {
 | 
			
		||||
	ptr := reflect.ValueOf(model)
 | 
			
		||||
	if ptr.Kind() != reflect.Ptr {
 | 
			
		||||
		return nil, fmt.Errorf("model must be pointer")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var total int64
 | 
			
		||||
	err := db.Model(model).Count(&total).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return &PaginateList{
 | 
			
		||||
			Page:  page,
 | 
			
		||||
			Size:  size,
 | 
			
		||||
			Total: total,
 | 
			
		||||
			List:  make([]any, 0),
 | 
			
		||||
		}, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	offset := size * (page - 1)
 | 
			
		||||
	err = db.Limit(int(size)).Offset(int(offset)).Find(model).Error
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return &PaginateList{
 | 
			
		||||
			Page:  page,
 | 
			
		||||
			Size:  size,
 | 
			
		||||
			Total: total,
 | 
			
		||||
			List:  make([]any, 0),
 | 
			
		||||
		}, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &PaginateList{
 | 
			
		||||
		Page:  page,
 | 
			
		||||
		Size:  size,
 | 
			
		||||
		Total: total,
 | 
			
		||||
		List:  model,
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										23
									
								
								pkg/database/trace.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										23
									
								
								pkg/database/trace.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,23 @@
 | 
			
		||||
package database
 | 
			
		||||
 | 
			
		||||
import "git.bvbej.com/bvbej/base-golang/pkg/trace"
 | 
			
		||||
 | 
			
		||||
type Trace = trace.T
 | 
			
		||||
 | 
			
		||||
type Option func(*option)
 | 
			
		||||
 | 
			
		||||
func WithTrace(t Trace) Option {
 | 
			
		||||
	return func(opt *option) {
 | 
			
		||||
		if t != nil {
 | 
			
		||||
			opt.Trace = t.(*trace.Trace)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newOption() *option {
 | 
			
		||||
	return &option{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type option struct {
 | 
			
		||||
	Trace *trace.Trace
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										26
									
								
								pkg/ddm/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								pkg/ddm/README.md
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,26 @@
 | 
			
		||||
## DDM
 | 
			
		||||
 | 
			
		||||
动态数据掩码(Dynamic Data Masking,简称为DDM)能够防止把敏感数据暴露给未经授权的用户。
 | 
			
		||||
 | 
			
		||||
| 类型 | 要求 | 示例 | 说明
 | 
			
		||||
| ---- | ---- | ---- | ---- 
 | 
			
		||||
| 手机号 | 前 3 后 4 | 132****7986 | 定长 11 位数字
 | 
			
		||||
| 邮箱地址 | 前 1 后 1 | l**w@gmail.com | 仅对 @ 之前的邮箱名称进行掩码
 | 
			
		||||
| 姓名 | 隐姓 | *鸿章 | 将姓氏隐藏
 | 
			
		||||
| 密码 | 不输出 | ****** |
 | 
			
		||||
| 银行卡卡号 | 前 6 后 4 | 622888******5676 | 银行卡卡号最多 19 位数字
 | 
			
		||||
| 身份证号 | 前 1 后 1 | 1******7 | 定长 18 位
 | 
			
		||||
 | 
			
		||||
#### 代码示例
 | 
			
		||||
 | 
			
		||||
```
 | 
			
		||||
// 返回值
 | 
			
		||||
type message struct {
 | 
			
		||||
	Email     ddm.Email    `json:"email"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
msg := new(message)
 | 
			
		||||
msg.Email = ddm.Email("xinliangnote@163.com")
 | 
			
		||||
...
 | 
			
		||||
 | 
			
		||||
```
 | 
			
		||||
							
								
								
									
										35
									
								
								pkg/ddm/benchmark.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										35
									
								
								pkg/ddm/benchmark.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,35 @@
 | 
			
		||||
package ddm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/mritd/chinaid"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type BType uint8
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	BMobile BType = iota
 | 
			
		||||
	BIDNo
 | 
			
		||||
	BName
 | 
			
		||||
	BBankNo
 | 
			
		||||
	BEmail
 | 
			
		||||
	BAddress
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func Benchmark(bType BType) string {
 | 
			
		||||
	switch bType {
 | 
			
		||||
	case BMobile:
 | 
			
		||||
		return chinaid.Mobile()
 | 
			
		||||
	case BIDNo:
 | 
			
		||||
		return chinaid.IDNo()
 | 
			
		||||
	case BEmail:
 | 
			
		||||
		return chinaid.Email()
 | 
			
		||||
	case BAddress:
 | 
			
		||||
		return chinaid.Address()
 | 
			
		||||
	case BName:
 | 
			
		||||
		return chinaid.Name()
 | 
			
		||||
	case BBankNo:
 | 
			
		||||
		return chinaid.BankNo()
 | 
			
		||||
	default:
 | 
			
		||||
		return ""
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										62
									
								
								pkg/ddm/mark.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										62
									
								
								pkg/ddm/mark.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,62 @@
 | 
			
		||||
package ddm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func (m Mobile) MarshalJSON() ([]byte, error) {
 | 
			
		||||
	if len(m) != 11 {
 | 
			
		||||
		return []byte(`"` + m + `"`), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	v := fmt.Sprintf("%s****%s", m[:3], m[len(m)-4:])
 | 
			
		||||
	return []byte(`"` + v + `"`), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (bc BankCard) MarshalJSON() ([]byte, error) {
 | 
			
		||||
	if len(bc) > 19 || len(bc) < 16 {
 | 
			
		||||
		return []byte(`"` + bc + `"`), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	v := fmt.Sprintf("%s******%s", bc[:6], bc[len(bc)-4:])
 | 
			
		||||
	return []byte(`"` + v + `"`), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (card IDCard) MarshalJSON() ([]byte, error) {
 | 
			
		||||
	if len(card) != 18 {
 | 
			
		||||
		return []byte(`"` + card + `"`), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	v := fmt.Sprintf("%s******%s", card[:1], card[len(card)-1:])
 | 
			
		||||
	return []byte(`"` + v + `"`), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (name IDName) MarshalJSON() ([]byte, error) {
 | 
			
		||||
	if len(name) < 1 {
 | 
			
		||||
		return []byte(`""`), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	nameRune := []rune(name)
 | 
			
		||||
	v := fmt.Sprintf("*%s", string(nameRune[1:]))
 | 
			
		||||
	return []byte(`"` + v + `"`), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pw PassWord) MarshalJSON() ([]byte, error) {
 | 
			
		||||
	v := "******"
 | 
			
		||||
	return []byte(`"` + v + `"`), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (e Email) MarshalJSON() ([]byte, error) {
 | 
			
		||||
	if !strings.Contains(string(e), "@") {
 | 
			
		||||
		return []byte(`"` + e + `"`), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	split := strings.Split(string(e), "@")
 | 
			
		||||
	if len(split[0]) < 1 || len(split[1]) < 1 {
 | 
			
		||||
		return []byte(`"` + e + `"`), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	v := fmt.Sprintf("%s***%s", split[0][:1], split[0][len(split[0])-1:])
 | 
			
		||||
	return []byte(`"` + v + "@" + split[1] + `"`), nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										19
									
								
								pkg/ddm/type.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								pkg/ddm/type.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,19 @@
 | 
			
		||||
package ddm
 | 
			
		||||
 | 
			
		||||
// 手机号 132****7986
 | 
			
		||||
type Mobile string
 | 
			
		||||
 | 
			
		||||
// 银行卡号 622888******5676
 | 
			
		||||
type BankCard string
 | 
			
		||||
 | 
			
		||||
// 身份证号 1******7
 | 
			
		||||
type IDCard string
 | 
			
		||||
 | 
			
		||||
// 姓名 *鸿章
 | 
			
		||||
type IDName string
 | 
			
		||||
 | 
			
		||||
// 密码 ******
 | 
			
		||||
type PassWord string
 | 
			
		||||
 | 
			
		||||
// 邮箱 l***w@gmail.com
 | 
			
		||||
type Email string
 | 
			
		||||
							
								
								
									
										264
									
								
								pkg/downloader/downloader.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										264
									
								
								pkg/downloader/downloader.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,264 @@
 | 
			
		||||
package downloader
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"git.bvbej.com/bvbej/base-golang/pkg/downloader/base"
 | 
			
		||||
	"git.bvbej.com/bvbej/base-golang/pkg/downloader/controller"
 | 
			
		||||
	"git.bvbej.com/bvbej/base-golang/pkg/downloader/fetcher"
 | 
			
		||||
	"git.bvbej.com/bvbej/base-golang/pkg/downloader/protocol/http"
 | 
			
		||||
	"git.bvbej.com/bvbej/base-golang/pkg/downloader/util"
 | 
			
		||||
	"github.com/google/uuid"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Listener func(event *Event)
 | 
			
		||||
 | 
			
		||||
type TaskInfo struct {
 | 
			
		||||
	ID       string
 | 
			
		||||
	Res      *base.Resource
 | 
			
		||||
	Opts     *base.Options
 | 
			
		||||
	Status   base.Status
 | 
			
		||||
	Progress *Progress
 | 
			
		||||
 | 
			
		||||
	fetcher fetcher.Fetcher
 | 
			
		||||
	timer   *util.Timer
 | 
			
		||||
	locker  *sync.Mutex
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Progress struct {
 | 
			
		||||
	// 下载耗时(纳秒)
 | 
			
		||||
	Used int64
 | 
			
		||||
	// 每秒下载字节数
 | 
			
		||||
	Speed int64
 | 
			
		||||
	// 已下载的字节数
 | 
			
		||||
	Downloaded int64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type downloader struct {
 | 
			
		||||
	*controller.DefaultController
 | 
			
		||||
	fetchBuilders map[string]func() fetcher.Fetcher
 | 
			
		||||
	task          *TaskInfo
 | 
			
		||||
	listener      Listener
 | 
			
		||||
	finished      bool
 | 
			
		||||
	finishedCh    chan error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newDownloader(f func() (protocols []string, builder func() fetcher.Fetcher), options ...controller.Option) *downloader {
 | 
			
		||||
	d := &downloader{
 | 
			
		||||
		DefaultController: controller.NewController(options...),
 | 
			
		||||
		finishedCh:        make(chan error, 1),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	d.fetchBuilders = make(map[string]func() fetcher.Fetcher)
 | 
			
		||||
	protocols, builder := f()
 | 
			
		||||
	for _, p := range protocols {
 | 
			
		||||
		d.fetchBuilders[strings.ToUpper(p)] = builder
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return d
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *downloader) buildFetcher(URL string) (fetcher.Fetcher, error) {
 | 
			
		||||
	parseURL, err := url.Parse(URL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	if fetchBuilder, ok := d.fetchBuilders[strings.ToUpper(parseURL.Scheme)]; ok {
 | 
			
		||||
		fetched := fetchBuilder()
 | 
			
		||||
		fetched.Setup(d.DefaultController)
 | 
			
		||||
		return fetched, nil
 | 
			
		||||
	}
 | 
			
		||||
	return nil, errors.New("unsupported protocol")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *downloader) Resolve(req *base.Request) (*base.Resource, error) {
 | 
			
		||||
	fetched, err := d.buildFetcher(req.URL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return fetched.Resolve(req)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *downloader) Create(res *base.Resource, opts *base.Options) (err error) {
 | 
			
		||||
	fetched, err := d.buildFetcher(res.Req.URL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if !res.Range || opts.Connections < 1 {
 | 
			
		||||
		opts.Connections = 1
 | 
			
		||||
	}
 | 
			
		||||
	err = fetched.Create(res, opts)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	task := &TaskInfo{
 | 
			
		||||
		ID:       uuid.New().String(),
 | 
			
		||||
		Res:      res,
 | 
			
		||||
		Opts:     opts,
 | 
			
		||||
		Status:   base.DownloadStatusStart,
 | 
			
		||||
		Progress: &Progress{},
 | 
			
		||||
		fetcher:  fetched,
 | 
			
		||||
		timer:    &util.Timer{},
 | 
			
		||||
		locker:   new(sync.Mutex),
 | 
			
		||||
	}
 | 
			
		||||
	d.task = task
 | 
			
		||||
	task.timer.Start()
 | 
			
		||||
	d.emit(EventKeyStart)
 | 
			
		||||
	err = fetched.Start()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	go func() {
 | 
			
		||||
		err = fetched.Wait()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			d.emit(EventKeyError, err)
 | 
			
		||||
			task.Status = base.DownloadStatusError
 | 
			
		||||
		} else {
 | 
			
		||||
			task.Progress.Used = task.timer.Used()
 | 
			
		||||
			if task.Res.TotalSize == 0 {
 | 
			
		||||
				task.Res.TotalSize = task.fetcher.Progress().TotalDownloaded()
 | 
			
		||||
			}
 | 
			
		||||
			used := task.Progress.Used / int64(time.Second)
 | 
			
		||||
			if used == 0 {
 | 
			
		||||
				used = 1
 | 
			
		||||
			}
 | 
			
		||||
			task.Progress.Speed = task.Res.TotalSize / used
 | 
			
		||||
			task.Progress.Downloaded = task.Res.TotalSize
 | 
			
		||||
			d.emit(EventKeyDone)
 | 
			
		||||
			task.Status = base.DownloadStatusDone
 | 
			
		||||
		}
 | 
			
		||||
		d.finished = true
 | 
			
		||||
		d.emit(EventKeyFinally, err)
 | 
			
		||||
		d.finishedCh <- err
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	// 每秒统计一次下载速度
 | 
			
		||||
	go func() {
 | 
			
		||||
		for !d.finished {
 | 
			
		||||
			if d.task.Status == base.DownloadStatusPause {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			current := d.task.fetcher.Progress().TotalDownloaded()
 | 
			
		||||
			d.task.Progress.Used = d.task.timer.Used()
 | 
			
		||||
			d.task.Progress.Speed = current - d.task.Progress.Downloaded
 | 
			
		||||
			d.task.Progress.Downloaded = current
 | 
			
		||||
			d.emit(EventKeyProgress)
 | 
			
		||||
 | 
			
		||||
			time.Sleep(time.Second)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *downloader) Pause() error {
 | 
			
		||||
	d.task.locker.Lock()
 | 
			
		||||
	defer d.task.locker.Unlock()
 | 
			
		||||
	d.task.timer.Pause()
 | 
			
		||||
	err := d.task.fetcher.Pause()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	d.emit(EventKeyPause)
 | 
			
		||||
	d.task.Status = base.DownloadStatusPause
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *downloader) Continue() error {
 | 
			
		||||
	d.task.locker.Lock()
 | 
			
		||||
	defer d.task.locker.Unlock()
 | 
			
		||||
	d.task.timer.Continue()
 | 
			
		||||
	err := d.task.fetcher.Continue()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	d.emit(EventKeyContinue)
 | 
			
		||||
	d.task.Status = base.DownloadStatusStart
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *downloader) Listener(fn Listener) {
 | 
			
		||||
	d.listener = fn
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *downloader) emit(eventKey EventKey, errs ...error) {
 | 
			
		||||
	if d.listener != nil {
 | 
			
		||||
		var err error
 | 
			
		||||
		if len(errs) > 0 {
 | 
			
		||||
			err = errs[0]
 | 
			
		||||
		}
 | 
			
		||||
		d.listener(&Event{
 | 
			
		||||
			Key:  eventKey,
 | 
			
		||||
			Task: d.task,
 | 
			
		||||
			Err:  err,
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var _ Boot = (*boot)(nil)
 | 
			
		||||
 | 
			
		||||
type Boot interface {
 | 
			
		||||
	URL(url string) Boot
 | 
			
		||||
	Extra(extra any) Boot
 | 
			
		||||
	Listener(listener Listener) Boot
 | 
			
		||||
	Create(opts *base.Options) <-chan error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type boot struct {
 | 
			
		||||
	url        string
 | 
			
		||||
	extra      any
 | 
			
		||||
	listener   Listener
 | 
			
		||||
	downloader *downloader
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (b *boot) resolve() (*base.Resource, error) {
 | 
			
		||||
	return b.downloader.Resolve(&base.Request{
 | 
			
		||||
		URL:   b.url,
 | 
			
		||||
		Extra: b.extra,
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (b *boot) URL(url string) Boot {
 | 
			
		||||
	b.url = url
 | 
			
		||||
	return b
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (b *boot) Extra(extra any) Boot {
 | 
			
		||||
	b.extra = extra
 | 
			
		||||
	return b
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (b *boot) Listener(listener Listener) Boot {
 | 
			
		||||
	b.listener = listener
 | 
			
		||||
	return b
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (b *boot) Create(opts *base.Options) <-chan error {
 | 
			
		||||
	res, err := b.resolve()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		b.downloader.finishedCh <- err
 | 
			
		||||
		return b.downloader.finishedCh
 | 
			
		||||
	}
 | 
			
		||||
	b.downloader.Listener(b.listener)
 | 
			
		||||
 | 
			
		||||
	err = b.downloader.Create(res, opts)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		b.downloader.finishedCh <- err
 | 
			
		||||
		return b.downloader.finishedCh
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return b.downloader.finishedCh
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// New 一个文件对应一个实例
 | 
			
		||||
func New(options ...controller.Option) Boot {
 | 
			
		||||
	return &boot{
 | 
			
		||||
		downloader: newDownloader(http.FetcherBuilder, options...),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										45
									
								
								pkg/downloader/downloader_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										45
									
								
								pkg/downloader/downloader_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,45 @@
 | 
			
		||||
package downloader
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"git.bvbej.com/bvbej/base-golang/pkg/downloader/base"
 | 
			
		||||
	"git.bvbej.com/bvbej/base-golang/pkg/downloader/controller"
 | 
			
		||||
	"git.bvbej.com/bvbej/base-golang/tool"
 | 
			
		||||
	"golang.org/x/net/proxy"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"runtime"
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestNewDownloader(t *testing.T) {
 | 
			
		||||
	dialer, err := proxy.SOCKS5("tcp", "127.0.0.1:1080", nil, proxy.Direct)
 | 
			
		||||
	parse, _ := url.Parse(`socks5://127.0.0.1:1080`)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err, dialer)
 | 
			
		||||
	}
 | 
			
		||||
	err = <-New(
 | 
			
		||||
		controller.WithDialer(dialer), // todo: use dialer proxy
 | 
			
		||||
		controller.WithCookie(nil),
 | 
			
		||||
		controller.WithProxy(http.ProxyURL(parse)), // todo: use http client proxy
 | 
			
		||||
		controller.WithTimeout(time.Second*3),
 | 
			
		||||
	).URL("http://10.0.1.34/com.tencent.tmgp.jxqy.apk").
 | 
			
		||||
		Listener(func(event *Event) {
 | 
			
		||||
			if event.Key == EventKeyFinally {
 | 
			
		||||
				fmt.Println("下载完成!")
 | 
			
		||||
			}
 | 
			
		||||
			if event.Key == EventKeyProgress {
 | 
			
		||||
				fmt.Printf("下载速度:%s/s 已下载:%s 已用时:%s \n",
 | 
			
		||||
					tool.ByteFmt(event.Task.Progress.Speed),
 | 
			
		||||
					tool.ByteFmt(event.Task.Progress.Downloaded),
 | 
			
		||||
					time.Duration(event.Task.Progress.Used),
 | 
			
		||||
				)
 | 
			
		||||
			}
 | 
			
		||||
		}).
 | 
			
		||||
		Create(&base.Options{
 | 
			
		||||
			Connections: runtime.NumCPU(),
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
	t.Log(err)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										19
									
								
								pkg/downloader/event.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								pkg/downloader/event.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,19 @@
 | 
			
		||||
package downloader
 | 
			
		||||
 | 
			
		||||
type EventKey string
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	EventKeyStart    EventKey = "start"
 | 
			
		||||
	EventKeyPause    EventKey = "pause"
 | 
			
		||||
	EventKeyContinue EventKey = "continue"
 | 
			
		||||
	EventKeyProgress EventKey = "progress"
 | 
			
		||||
	EventKeyError    EventKey = "error"
 | 
			
		||||
	EventKeyDone     EventKey = "done"
 | 
			
		||||
	EventKeyFinally  EventKey = "finally"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Event struct {
 | 
			
		||||
	Key  EventKey
 | 
			
		||||
	Task *TaskInfo
 | 
			
		||||
	Err  error
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										340
									
								
								pkg/duration_fmt/fmt.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										340
									
								
								pkg/duration_fmt/fmt.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,340 @@
 | 
			
		||||
package duration_fmt
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"regexp"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	//units, _   = DefaultUnitsCoder.Decode("year,week,day,hour,minute,second,millisecond,microsecond")
 | 
			
		||||
	units, _   = DefaultUnitsCoder.Decode("年,星期,天,小时,分钟,秒,毫秒,微秒")
 | 
			
		||||
	unitsShort = []string{"y", "w", "d", "h", "m", "s", "ms", "µs"}
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Durafmt holds the parsed duration and the original input duration.
 | 
			
		||||
type Durafmt struct {
 | 
			
		||||
	duration  time.Duration
 | 
			
		||||
	input     string // Used as reference.
 | 
			
		||||
	limitN    int    // Non-zero to limit only first N elements to output.
 | 
			
		||||
	limitUnit string // Non-empty to limit max unit
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// LimitToUnit sets the output format, you will not have unit bigger than the UNIT specified. UNIT = "" means no restriction.
 | 
			
		||||
func (d *Durafmt) LimitToUnit(unit string) *Durafmt {
 | 
			
		||||
	d.limitUnit = unit
 | 
			
		||||
	return d
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// LimitFirstN sets the output format, outputing only first N elements. n == 0 means no limit.
 | 
			
		||||
func (d *Durafmt) LimitFirstN(n int) *Durafmt {
 | 
			
		||||
	d.limitN = n
 | 
			
		||||
	return d
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *Durafmt) Duration() time.Duration {
 | 
			
		||||
	return d.duration
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Truncate sets precision
 | 
			
		||||
func (d *Durafmt) Truncate(unit time.Duration) *Durafmt {
 | 
			
		||||
	d.duration = d.duration.Truncate(unit)
 | 
			
		||||
	return d
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Parse creates a new *Durafmt struct, returns error if input is invalid.
 | 
			
		||||
func Parse(dinput time.Duration) *Durafmt {
 | 
			
		||||
	input := dinput.String()
 | 
			
		||||
	return &Durafmt{dinput, input, 0, ""}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ParseShort creates a new *Durafmt struct, short form, returns error if input is invalid.
 | 
			
		||||
// It's shortcut for `Parse(dur).LimitFirstN(1)`
 | 
			
		||||
func ParseShort(dinput time.Duration) *Durafmt {
 | 
			
		||||
	input := dinput.String()
 | 
			
		||||
	return &Durafmt{dinput, input, 1, ""}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ParseString creates a new *Durafmt struct from a string.
 | 
			
		||||
// returns an error if input is invalid.
 | 
			
		||||
func ParseString(input string) (*Durafmt, error) {
 | 
			
		||||
	if input == "0" || input == "-0" {
 | 
			
		||||
		return nil, errors.New("durafmt: missing unit in duration " + input)
 | 
			
		||||
	}
 | 
			
		||||
	duration, err := time.ParseDuration(input)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return &Durafmt{duration, input, 0, ""}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ParseStringShort creates a new *Durafmt struct from a string, short form
 | 
			
		||||
// returns an error if input is invalid.
 | 
			
		||||
// It's shortcut for `ParseString(durStr)` and then calling `LimitFirstN(1)`
 | 
			
		||||
func ParseStringShort(input string) (*Durafmt, error) {
 | 
			
		||||
	if input == "0" || input == "-0" {
 | 
			
		||||
		return nil, errors.New("durafmt: missing unit in duration " + input)
 | 
			
		||||
	}
 | 
			
		||||
	duration, err := time.ParseDuration(input)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return &Durafmt{duration, input, 1, ""}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// String parses d *Durafmt into a human readable duration with default units.
 | 
			
		||||
func (d *Durafmt) String() string {
 | 
			
		||||
	return d.Format(units)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Format parses d *Durafmt into a human readable duration with units.
 | 
			
		||||
func (d *Durafmt) Format(units Units) string {
 | 
			
		||||
	var duration string
 | 
			
		||||
 | 
			
		||||
	// Check for minus durations.
 | 
			
		||||
	if string(d.input[0]) == "-" {
 | 
			
		||||
		duration += "-"
 | 
			
		||||
		d.duration = -d.duration
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var microseconds int64
 | 
			
		||||
	var milliseconds int64
 | 
			
		||||
	var seconds int64
 | 
			
		||||
	var minutes int64
 | 
			
		||||
	var hours int64
 | 
			
		||||
	var days int64
 | 
			
		||||
	var weeks int64
 | 
			
		||||
	var years int64
 | 
			
		||||
	var shouldConvert = false
 | 
			
		||||
 | 
			
		||||
	remainingSecondsToConvert := int64(d.duration / time.Microsecond)
 | 
			
		||||
 | 
			
		||||
	// Convert duration.
 | 
			
		||||
	if d.limitUnit == "" {
 | 
			
		||||
		shouldConvert = true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if d.limitUnit == "years" || shouldConvert {
 | 
			
		||||
		years = remainingSecondsToConvert / (365 * 24 * 3600 * 1000000)
 | 
			
		||||
		remainingSecondsToConvert -= years * 365 * 24 * 3600 * 1000000
 | 
			
		||||
		shouldConvert = true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if d.limitUnit == "weeks" || shouldConvert {
 | 
			
		||||
		weeks = remainingSecondsToConvert / (7 * 24 * 3600 * 1000000)
 | 
			
		||||
		remainingSecondsToConvert -= weeks * 7 * 24 * 3600 * 1000000
 | 
			
		||||
		shouldConvert = true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if d.limitUnit == "days" || shouldConvert {
 | 
			
		||||
		days = remainingSecondsToConvert / (24 * 3600 * 1000000)
 | 
			
		||||
		remainingSecondsToConvert -= days * 24 * 3600 * 1000000
 | 
			
		||||
		shouldConvert = true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if d.limitUnit == "hours" || shouldConvert {
 | 
			
		||||
		hours = remainingSecondsToConvert / (3600 * 1000000)
 | 
			
		||||
		remainingSecondsToConvert -= hours * 3600 * 1000000
 | 
			
		||||
		shouldConvert = true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if d.limitUnit == "minutes" || shouldConvert {
 | 
			
		||||
		minutes = remainingSecondsToConvert / (60 * 1000000)
 | 
			
		||||
		remainingSecondsToConvert -= minutes * 60 * 1000000
 | 
			
		||||
		shouldConvert = true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if d.limitUnit == "seconds" || shouldConvert {
 | 
			
		||||
		seconds = remainingSecondsToConvert / 1000000
 | 
			
		||||
		remainingSecondsToConvert -= seconds * 1000000
 | 
			
		||||
		shouldConvert = true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if d.limitUnit == "milliseconds" || shouldConvert {
 | 
			
		||||
		milliseconds = remainingSecondsToConvert / 1000
 | 
			
		||||
		remainingSecondsToConvert -= milliseconds * 1000
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	microseconds = remainingSecondsToConvert
 | 
			
		||||
 | 
			
		||||
	// Create a map of the converted duration time.
 | 
			
		||||
	durationMap := []int64{
 | 
			
		||||
		microseconds,
 | 
			
		||||
		milliseconds,
 | 
			
		||||
		seconds,
 | 
			
		||||
		minutes,
 | 
			
		||||
		hours,
 | 
			
		||||
		days,
 | 
			
		||||
		weeks,
 | 
			
		||||
		years,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Construct duration string.
 | 
			
		||||
	for i, u := range units.Units() {
 | 
			
		||||
		v := durationMap[7-i]
 | 
			
		||||
		strval := strconv.FormatInt(v, 10)
 | 
			
		||||
		switch {
 | 
			
		||||
		// add to the duration string if v > 1.
 | 
			
		||||
		case v > 1:
 | 
			
		||||
			duration += strval + " " + u.Plural + " "
 | 
			
		||||
		// remove the plural 's', if v is 1.
 | 
			
		||||
		case v == 1:
 | 
			
		||||
			duration += strval + " " + u.Singular + " "
 | 
			
		||||
		// omit any value with 0s or 0.
 | 
			
		||||
		case d.duration.String() == "0" || d.duration.String() == "0s":
 | 
			
		||||
			pattern := fmt.Sprintf("^-?0%s$", unitsShort[i])
 | 
			
		||||
			isMatch, err := regexp.MatchString(pattern, d.input)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return ""
 | 
			
		||||
			}
 | 
			
		||||
			if isMatch {
 | 
			
		||||
				duration += strval + " " + u.Plural
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
		// omit any value with 0.
 | 
			
		||||
		case v == 0:
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	// trim any remaining spaces.
 | 
			
		||||
	duration = strings.TrimSpace(duration)
 | 
			
		||||
 | 
			
		||||
	// if more than 2 spaces present return the first 2 strings
 | 
			
		||||
	// if short version is requested
 | 
			
		||||
	if d.limitN > 0 {
 | 
			
		||||
		parts := strings.Split(duration, " ")
 | 
			
		||||
		if len(parts) > d.limitN*2 {
 | 
			
		||||
			duration = strings.Join(parts[:d.limitN*2], " ")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return duration
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *Durafmt) InternationalString() string {
 | 
			
		||||
	var duration string
 | 
			
		||||
 | 
			
		||||
	// Check for minus durations.
 | 
			
		||||
	if string(d.input[0]) == "-" {
 | 
			
		||||
		duration += "-"
 | 
			
		||||
		d.duration = -d.duration
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var microseconds int64
 | 
			
		||||
	var milliseconds int64
 | 
			
		||||
	var seconds int64
 | 
			
		||||
	var minutes int64
 | 
			
		||||
	var hours int64
 | 
			
		||||
	var days int64
 | 
			
		||||
	var weeks int64
 | 
			
		||||
	var years int64
 | 
			
		||||
	var shouldConvert = false
 | 
			
		||||
 | 
			
		||||
	remainingSecondsToConvert := int64(d.duration / time.Microsecond)
 | 
			
		||||
 | 
			
		||||
	// Convert duration.
 | 
			
		||||
	if d.limitUnit == "" {
 | 
			
		||||
		shouldConvert = true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if d.limitUnit == "years" || shouldConvert {
 | 
			
		||||
		years = remainingSecondsToConvert / (365 * 24 * 3600 * 1000000)
 | 
			
		||||
		remainingSecondsToConvert -= years * 365 * 24 * 3600 * 1000000
 | 
			
		||||
		shouldConvert = true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if d.limitUnit == "weeks" || shouldConvert {
 | 
			
		||||
		weeks = remainingSecondsToConvert / (7 * 24 * 3600 * 1000000)
 | 
			
		||||
		remainingSecondsToConvert -= weeks * 7 * 24 * 3600 * 1000000
 | 
			
		||||
		shouldConvert = true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if d.limitUnit == "days" || shouldConvert {
 | 
			
		||||
		days = remainingSecondsToConvert / (24 * 3600 * 1000000)
 | 
			
		||||
		remainingSecondsToConvert -= days * 24 * 3600 * 1000000
 | 
			
		||||
		shouldConvert = true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if d.limitUnit == "hours" || shouldConvert {
 | 
			
		||||
		hours = remainingSecondsToConvert / (3600 * 1000000)
 | 
			
		||||
		remainingSecondsToConvert -= hours * 3600 * 1000000
 | 
			
		||||
		shouldConvert = true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if d.limitUnit == "minutes" || shouldConvert {
 | 
			
		||||
		minutes = remainingSecondsToConvert / (60 * 1000000)
 | 
			
		||||
		remainingSecondsToConvert -= minutes * 60 * 1000000
 | 
			
		||||
		shouldConvert = true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if d.limitUnit == "seconds" || shouldConvert {
 | 
			
		||||
		seconds = remainingSecondsToConvert / 1000000
 | 
			
		||||
		remainingSecondsToConvert -= seconds * 1000000
 | 
			
		||||
		shouldConvert = true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if d.limitUnit == "milliseconds" || shouldConvert {
 | 
			
		||||
		milliseconds = remainingSecondsToConvert / 1000
 | 
			
		||||
		remainingSecondsToConvert -= milliseconds * 1000
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	microseconds = remainingSecondsToConvert
 | 
			
		||||
 | 
			
		||||
	// Create a map of the converted duration time.
 | 
			
		||||
	durationMap := map[string]int64{
 | 
			
		||||
		"µs": microseconds,
 | 
			
		||||
		"ms": milliseconds,
 | 
			
		||||
		"s":  seconds,
 | 
			
		||||
		"m":  minutes,
 | 
			
		||||
		"h":  hours,
 | 
			
		||||
		"d":  days,
 | 
			
		||||
		"w":  weeks,
 | 
			
		||||
		"y":  years,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Construct duration string.
 | 
			
		||||
	for i := range units.Units() {
 | 
			
		||||
		u := unitsShort[i]
 | 
			
		||||
		v := durationMap[u]
 | 
			
		||||
		strval := strconv.FormatInt(v, 10)
 | 
			
		||||
		switch {
 | 
			
		||||
		// add to the duration string if v > 0.
 | 
			
		||||
		case v > 0:
 | 
			
		||||
			duration += strval + " " + u + " "
 | 
			
		||||
		// omit any value with 0.
 | 
			
		||||
		case d.duration.String() == "0":
 | 
			
		||||
			pattern := fmt.Sprintf("^-?0%s$", unitsShort[i])
 | 
			
		||||
			isMatch, err := regexp.MatchString(pattern, d.input)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return ""
 | 
			
		||||
			}
 | 
			
		||||
			if isMatch {
 | 
			
		||||
				duration += strval + " " + u
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
		// omit any value with 0.
 | 
			
		||||
		case v == 0:
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	// trim any remaining spaces.
 | 
			
		||||
	duration = strings.TrimSpace(duration)
 | 
			
		||||
 | 
			
		||||
	// if more than 2 spaces present return the first 2 strings
 | 
			
		||||
	// if short version is requested
 | 
			
		||||
	if d.limitN > 0 {
 | 
			
		||||
		parts := strings.Split(duration, " ")
 | 
			
		||||
		if len(parts) > d.limitN*2 {
 | 
			
		||||
			duration = strings.Join(parts[:d.limitN*2], " ")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return duration
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *Durafmt) TrimSpace() string {
 | 
			
		||||
	return strings.Replace(d.String(), " ", "", -1)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										107
									
								
								pkg/duration_fmt/units.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										107
									
								
								pkg/duration_fmt/units.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,107 @@
 | 
			
		||||
package duration_fmt
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// DefaultUnitsCoder default units coder using `":"` as PluralSep and `","` as UnitsSep
 | 
			
		||||
var DefaultUnitsCoder = UnitsCoder{":", ","}
 | 
			
		||||
 | 
			
		||||
// Unit the pair of singular and plural units
 | 
			
		||||
type Unit struct {
 | 
			
		||||
	Singular, Plural string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Units duration units
 | 
			
		||||
type Units struct {
 | 
			
		||||
	Year, Week, Day, Hour, Minute,
 | 
			
		||||
	Second, Millisecond, Microsecond Unit
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Units return a slice of units
 | 
			
		||||
func (u Units) Units() []Unit {
 | 
			
		||||
	return []Unit{u.Year, u.Week, u.Day, u.Hour, u.Minute,
 | 
			
		||||
		u.Second, u.Millisecond, u.Microsecond}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UnitsCoder the units encoder and decoder
 | 
			
		||||
type UnitsCoder struct {
 | 
			
		||||
	// PluralSep char to sep singular and plural pair.
 | 
			
		||||
	// Example with char `":"`: `"year:year"` (english) or `"mês:meses"` (portuguese)
 | 
			
		||||
	PluralSep,
 | 
			
		||||
	// UnitsSep char to sep units (singular and plural pairs).
 | 
			
		||||
	// Example with char `","`: `"year:year,week:weeks"` (english) or `"mês:meses,semana:semanas"` (portuguese)
 | 
			
		||||
	UnitsSep string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Encode encodes input Units to string
 | 
			
		||||
// Examples with `UnitsCoder{PluralSep: ":", UnitsSep = ","}`
 | 
			
		||||
// 	- singular and plural pair units: `"year:wers,week:weeks,day:days,hour:hours,minute:minutes,second:seconds,millisecond:millliseconds,microsecond:microsseconds"`
 | 
			
		||||
func (coder UnitsCoder) Encode(units Units) string {
 | 
			
		||||
	var pairs = make([]string, 8)
 | 
			
		||||
	for i, u := range units.Units() {
 | 
			
		||||
		pairs[i] = u.Singular + coder.PluralSep + u.Plural
 | 
			
		||||
	}
 | 
			
		||||
	return strings.Join(pairs, coder.UnitsSep)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Decode decodes input string to Units.
 | 
			
		||||
// The input must follow the following formats:
 | 
			
		||||
// - Unit format (singular and plural pair)
 | 
			
		||||
// 	- must singular (the plural receives 's' character as suffix)
 | 
			
		||||
//	- singular and plural: separated by `PluralSep` char
 | 
			
		||||
//		Example with char `":"`: `"year:year"` (english) or `"mês:meses"` (portuguese)
 | 
			
		||||
// - Units format (pairs of  Year, Week, Day, Hour, Minute,
 | 
			
		||||
//	Second, Millisecond and Microsecond units) separated by `UnitsSep` char
 | 
			
		||||
// 	- Examples with `UnitsCoder{PluralSep: ":", UnitsSep = ","}`
 | 
			
		||||
// 		- must singular units: `"year,week,day,hour,minute,second,millisecond,microsecond"`
 | 
			
		||||
// 		- mixed units: `"year,week:weeks,day,hour,minute:minutes,second,millisecond,microsecond"`
 | 
			
		||||
// 		- singular and plural pair units: `"year:wers,week:weeks,day:days,hour:hours,minute:minutes,second:seconds,millisecond:millliseconds,microsecond:microsseconds"`
 | 
			
		||||
func (coder UnitsCoder) Decode(s string) (units Units, err error) {
 | 
			
		||||
	parts := strings.Split(s, coder.UnitsSep)
 | 
			
		||||
	if len(parts) != 8 {
 | 
			
		||||
		err = fmt.Errorf("bad parts length")
 | 
			
		||||
		return units, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var parse = func(name, part string, u *Unit) bool {
 | 
			
		||||
		ps := strings.Split(part, coder.PluralSep)
 | 
			
		||||
		switch len(ps) {
 | 
			
		||||
		case 1:
 | 
			
		||||
			u.Singular, u.Plural = ps[0], ps[0]
 | 
			
		||||
		case 2:
 | 
			
		||||
			u.Singular, u.Plural = ps[0], ps[1]
 | 
			
		||||
		default:
 | 
			
		||||
			err = fmt.Errorf("bad unit %q pair length", name)
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !parse("Year", parts[0], &units.Year) {
 | 
			
		||||
		return units, err
 | 
			
		||||
	}
 | 
			
		||||
	if !parse("Week", parts[1], &units.Week) {
 | 
			
		||||
		return units, err
 | 
			
		||||
	}
 | 
			
		||||
	if !parse("Day", parts[2], &units.Day) {
 | 
			
		||||
		return units, err
 | 
			
		||||
	}
 | 
			
		||||
	if !parse("Hour", parts[3], &units.Hour) {
 | 
			
		||||
		return units, err
 | 
			
		||||
	}
 | 
			
		||||
	if !parse("Minute", parts[4], &units.Minute) {
 | 
			
		||||
		return units, err
 | 
			
		||||
	}
 | 
			
		||||
	if !parse("Second", parts[5], &units.Second) {
 | 
			
		||||
		return units, err
 | 
			
		||||
	}
 | 
			
		||||
	if !parse("Millisecond", parts[6], &units.Millisecond) {
 | 
			
		||||
		return units, err
 | 
			
		||||
	}
 | 
			
		||||
	if !parse("Microsecond", parts[7], &units.Microsecond) {
 | 
			
		||||
		return units, err
 | 
			
		||||
	}
 | 
			
		||||
	return units, err
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										80
									
								
								pkg/env/env.go
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										80
									
								
								pkg/env/env.go
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@@ -0,0 +1,80 @@
 | 
			
		||||
package env
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	active Environment
 | 
			
		||||
	dev    Environment = &environment{value: "dev"} //开发环境
 | 
			
		||||
	fat    Environment = &environment{value: "fat"} //测试环境
 | 
			
		||||
	uat    Environment = &environment{value: "uat"} //预上线环境
 | 
			
		||||
	pro    Environment = &environment{value: "pro"} //正式环境
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var _ Environment = (*environment)(nil)
 | 
			
		||||
 | 
			
		||||
// Environment 环境配置
 | 
			
		||||
type Environment interface {
 | 
			
		||||
	Value() string
 | 
			
		||||
	IsDev() bool
 | 
			
		||||
	IsFat() bool
 | 
			
		||||
	IsUat() bool
 | 
			
		||||
	IsPro() bool
 | 
			
		||||
	t()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type environment struct {
 | 
			
		||||
	value string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (e *environment) Value() string {
 | 
			
		||||
	return e.value
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (e *environment) IsDev() bool {
 | 
			
		||||
	return e.value == "dev"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (e *environment) IsFat() bool {
 | 
			
		||||
	return e.value == "fat"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (e *environment) IsUat() bool {
 | 
			
		||||
	return e.value == "uat"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (e *environment) IsPro() bool {
 | 
			
		||||
	return e.value == "pro"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (e *environment) t() {}
 | 
			
		||||
 | 
			
		||||
func Set(env string) error {
 | 
			
		||||
	var err error
 | 
			
		||||
 | 
			
		||||
	switch strings.ToLower(strings.TrimSpace(env)) {
 | 
			
		||||
	case "dev":
 | 
			
		||||
		active = dev
 | 
			
		||||
	case "fat":
 | 
			
		||||
		active = fat
 | 
			
		||||
	case "uat":
 | 
			
		||||
		active = uat
 | 
			
		||||
	case "pro":
 | 
			
		||||
		active = pro
 | 
			
		||||
	default:
 | 
			
		||||
		err = fmt.Errorf("'%s' cannot be found, or it is illegal. enum:dev(开发环境),fat(测试环境),uat(预上线环境),pro(正式环境)", env)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Active 当前配置的env
 | 
			
		||||
func Active() Environment {
 | 
			
		||||
	if active == nil {
 | 
			
		||||
		fmt.Println("Warning: environment not set. The default 'dev' will be used.")
 | 
			
		||||
		active = dev
 | 
			
		||||
	}
 | 
			
		||||
	return active
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										74
									
								
								pkg/errno/errno.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										74
									
								
								pkg/errno/errno.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,74 @@
 | 
			
		||||
package errno
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var _ Error = (*err)(nil)
 | 
			
		||||
 | 
			
		||||
type Error interface {
 | 
			
		||||
	// WithErr 设置错误信息
 | 
			
		||||
	WithErr(err error) Error
 | 
			
		||||
	// GetBusinessCode 获取 Business Code
 | 
			
		||||
	GetBusinessCode() int
 | 
			
		||||
	// GetHttpCode 获取 HTTP Code
 | 
			
		||||
	GetHttpCode() int
 | 
			
		||||
	// GetMsg 获取 Msg
 | 
			
		||||
	GetMsg() string
 | 
			
		||||
	// GetErr 获取错误信息
 | 
			
		||||
	GetErr() error
 | 
			
		||||
	// ToString 返回 JSON 格式的错误详情
 | 
			
		||||
	ToString() string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type err struct {
 | 
			
		||||
	HttpCode     int    // HTTP Code
 | 
			
		||||
	BusinessCode int    // Business Code
 | 
			
		||||
	Message      string // 描述信息
 | 
			
		||||
	Err          error  // 错误信息
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewError(httpCode, businessCode int, msg string) Error {
 | 
			
		||||
	return &err{
 | 
			
		||||
		HttpCode:     httpCode,
 | 
			
		||||
		BusinessCode: businessCode,
 | 
			
		||||
		Message:      msg,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (e *err) WithErr(err error) Error {
 | 
			
		||||
	e.Err = err
 | 
			
		||||
	return e
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (e *err) GetHttpCode() int {
 | 
			
		||||
	return e.HttpCode
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (e *err) GetBusinessCode() int {
 | 
			
		||||
	return e.BusinessCode
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (e *err) GetMsg() string {
 | 
			
		||||
	return e.Message
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (e *err) GetErr() error {
 | 
			
		||||
	return e.Err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ToString 返回 JSON 格式的错误详情
 | 
			
		||||
func (e *err) ToString() string {
 | 
			
		||||
	err := &struct {
 | 
			
		||||
		HttpCode     int    `json:"http_code"`
 | 
			
		||||
		BusinessCode int    `json:"business_code"`
 | 
			
		||||
		Message      string `json:"message"`
 | 
			
		||||
	}{
 | 
			
		||||
		HttpCode:     e.HttpCode,
 | 
			
		||||
		BusinessCode: e.BusinessCode,
 | 
			
		||||
		Message:      e.Message,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	raw, _ := json.Marshal(err)
 | 
			
		||||
	return string(raw)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										111
									
								
								pkg/excel/export.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										111
									
								
								pkg/excel/export.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,111 @@
 | 
			
		||||
package excel
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/xuri/excelize/v2"
 | 
			
		||||
	"strconv"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const maxCharCount = 26
 | 
			
		||||
 | 
			
		||||
var DefaultColumnWidth float64 = 20
 | 
			
		||||
 | 
			
		||||
type ColumnOption struct {
 | 
			
		||||
	Field   string
 | 
			
		||||
	Comment string
 | 
			
		||||
	Width   float64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Export(sheetName, filepath string, columns []ColumnOption, rows []map[string]any) error {
 | 
			
		||||
	f := excelize.NewFile()
 | 
			
		||||
	sheetIndex, err := f.NewSheet(sheetName)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	_ = f.DeleteSheet("Sheet1")
 | 
			
		||||
	_ = f.SetColWidth(sheetName, "A", string(byte('A'+len(columns)-1)), DefaultColumnWidth)
 | 
			
		||||
	contentStyle, _ := f.NewStyle(&excelize.Style{
 | 
			
		||||
		Alignment: &excelize.Alignment{
 | 
			
		||||
			Horizontal: "center",
 | 
			
		||||
			Vertical:   "center",
 | 
			
		||||
			WrapText:   true,
 | 
			
		||||
		},
 | 
			
		||||
	})
 | 
			
		||||
	titleStyle, _ := f.NewStyle(&excelize.Style{
 | 
			
		||||
		Alignment: &excelize.Alignment{
 | 
			
		||||
			Horizontal: "center",
 | 
			
		||||
			Vertical:   "center",
 | 
			
		||||
			WrapText:   true,
 | 
			
		||||
		},
 | 
			
		||||
		Font: &excelize.Font{
 | 
			
		||||
			Bold: true,
 | 
			
		||||
			Size: 14,
 | 
			
		||||
		},
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	maxColumnRowNameLen := 1 + len(strconv.Itoa(len(rows)))
 | 
			
		||||
	columnCount := len(columns)
 | 
			
		||||
	if columnCount > maxCharCount {
 | 
			
		||||
		maxColumnRowNameLen++
 | 
			
		||||
	} else if columnCount > maxCharCount*maxCharCount {
 | 
			
		||||
		maxColumnRowNameLen += 2
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	//标题
 | 
			
		||||
	type columnItem struct {
 | 
			
		||||
		RowName   []byte
 | 
			
		||||
		FieldName string
 | 
			
		||||
	}
 | 
			
		||||
	columnNames := make([]columnItem, 0)
 | 
			
		||||
	for index, column := range columns {
 | 
			
		||||
		columnName := getColumnName(index, maxColumnRowNameLen)
 | 
			
		||||
		if column.Width > 0 {
 | 
			
		||||
			_ = f.SetColWidth(sheetName, string(columnName), string(columnName), column.Width)
 | 
			
		||||
		}
 | 
			
		||||
		columnNames = append(columnNames, columnItem{FieldName: column.Field, RowName: columnName})
 | 
			
		||||
		rowName := getColumnRowName(columnName, 1)
 | 
			
		||||
		err := f.SetCellValue(sheetName, rowName, column.Comment)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		_ = f.SetCellStyle(sheetName, rowName, rowName, titleStyle)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	//正文
 | 
			
		||||
	for rowIndex, row := range rows {
 | 
			
		||||
		for _, item := range columnNames {
 | 
			
		||||
			rowName := getColumnRowName(item.RowName, rowIndex+2)
 | 
			
		||||
			err := f.SetCellValue(sheetName, rowName, row[item.FieldName])
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			_ = f.SetCellStyle(sheetName, rowName, rowName, contentStyle)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	f.SetActiveSheet(sheetIndex)
 | 
			
		||||
 | 
			
		||||
	err = f.SaveAs(filepath)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getColumnName(column, maxColumnRowNameLen int) []byte {
 | 
			
		||||
	const A = 'A'
 | 
			
		||||
	if column < maxCharCount {
 | 
			
		||||
		slice := make([]byte, 0, maxColumnRowNameLen)
 | 
			
		||||
		return append(slice, byte(A+column))
 | 
			
		||||
	} else {
 | 
			
		||||
		return append(getColumnName(column/maxCharCount-1, maxColumnRowNameLen), byte(A+column%maxCharCount))
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getColumnRowName(columnName []byte, rowIndex int) (columnRowName string) {
 | 
			
		||||
	l := len(columnName)
 | 
			
		||||
	columnName = strconv.AppendInt(columnName, int64(rowIndex), 10)
 | 
			
		||||
	columnRowName = string(columnName)
 | 
			
		||||
	columnName = columnName[:l]
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										189
									
								
								pkg/file/file.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										189
									
								
								pkg/file/file.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,189 @@
 | 
			
		||||
package file
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"os"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	buffSize = 1 << 20
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// ReadLineFromEnd --
 | 
			
		||||
type ReadLineFromEnd struct {
 | 
			
		||||
	f *os.File
 | 
			
		||||
 | 
			
		||||
	fileSize int
 | 
			
		||||
	bwr      *bytes.Buffer
 | 
			
		||||
	lineBuff []byte
 | 
			
		||||
	swapBuff []byte
 | 
			
		||||
 | 
			
		||||
	isFirst bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Exists
 | 
			
		||||
func IsExists(path string) (os.FileInfo, bool) {
 | 
			
		||||
	f, err := os.Stat(path)
 | 
			
		||||
	return f, err == nil || os.IsExist(err)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewReadLineFromEnd
 | 
			
		||||
func NewReadLineFromEnd(filename string) (rd *ReadLineFromEnd, err error) {
 | 
			
		||||
	f, err := os.Open(filename)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	info, err := f.Stat()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if info.IsDir() {
 | 
			
		||||
		return nil, fmt.Errorf("not file")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	fileSize := int(info.Size())
 | 
			
		||||
 | 
			
		||||
	rd = &ReadLineFromEnd{
 | 
			
		||||
		f:        f,
 | 
			
		||||
		fileSize: fileSize,
 | 
			
		||||
		bwr:      bytes.NewBuffer([]byte{}),
 | 
			
		||||
		lineBuff: make([]byte, 0),
 | 
			
		||||
		swapBuff: make([]byte, buffSize),
 | 
			
		||||
		isFirst:  true,
 | 
			
		||||
	}
 | 
			
		||||
	return rd, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ReadLine 结尾包含'\n'
 | 
			
		||||
func (c *ReadLineFromEnd) ReadLine() (line []byte, err error) {
 | 
			
		||||
	var ok bool
 | 
			
		||||
	for {
 | 
			
		||||
		ok, err = c.buff()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		if ok {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	line, err = c.bwr.ReadBytes('\n')
 | 
			
		||||
	if err == io.EOF && c.fileSize > 0 {
 | 
			
		||||
		err = nil
 | 
			
		||||
	}
 | 
			
		||||
	return line, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Close --
 | 
			
		||||
func (c *ReadLineFromEnd) Close() (err error) {
 | 
			
		||||
	return c.f.Close()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *ReadLineFromEnd) buff() (ok bool, err error) {
 | 
			
		||||
	if c.fileSize == 0 {
 | 
			
		||||
		return true, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if c.bwr.Len() >= buffSize {
 | 
			
		||||
		return true, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	offset := 0
 | 
			
		||||
	if c.fileSize > buffSize {
 | 
			
		||||
		offset = c.fileSize - buffSize
 | 
			
		||||
	}
 | 
			
		||||
	_, err = c.f.Seek(int64(offset), 0)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return false, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	n, err := c.f.Read(c.swapBuff)
 | 
			
		||||
	if err != nil && err != io.EOF {
 | 
			
		||||
		return false, err
 | 
			
		||||
	}
 | 
			
		||||
	if c.fileSize < n {
 | 
			
		||||
		n = c.fileSize
 | 
			
		||||
	}
 | 
			
		||||
	if n == 0 {
 | 
			
		||||
		return true, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for {
 | 
			
		||||
		m := bytes.LastIndex(c.swapBuff[:n], []byte{'\n'})
 | 
			
		||||
		if m == -1 {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
		if m < n-1 {
 | 
			
		||||
			err = c.writeLine(c.swapBuff[m+1 : n])
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return false, err
 | 
			
		||||
			}
 | 
			
		||||
			ok = true
 | 
			
		||||
		} else if m == n-1 && !c.isFirst {
 | 
			
		||||
			err = c.writeLine(nil)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return false, err
 | 
			
		||||
			}
 | 
			
		||||
			ok = true
 | 
			
		||||
		}
 | 
			
		||||
		n = m
 | 
			
		||||
		if n == 0 {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if n > 0 {
 | 
			
		||||
		reverseBytes(c.swapBuff[:n])
 | 
			
		||||
		c.lineBuff = append(c.lineBuff, c.swapBuff[:n]...)
 | 
			
		||||
	}
 | 
			
		||||
	if offset == 0 {
 | 
			
		||||
		err = c.writeLine(nil)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return false, err
 | 
			
		||||
		}
 | 
			
		||||
		ok = true
 | 
			
		||||
	}
 | 
			
		||||
	c.fileSize = offset
 | 
			
		||||
	if c.isFirst {
 | 
			
		||||
		c.isFirst = false
 | 
			
		||||
	}
 | 
			
		||||
	return ok, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *ReadLineFromEnd) writeLine(b []byte) (err error) {
 | 
			
		||||
	if len(b) > 0 {
 | 
			
		||||
		_, err = c.bwr.Write(b)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if len(c.lineBuff) > 0 {
 | 
			
		||||
		reverseBytes(c.lineBuff)
 | 
			
		||||
		_, err = c.bwr.Write(c.lineBuff)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		c.lineBuff = c.lineBuff[:0]
 | 
			
		||||
	}
 | 
			
		||||
	_, err = c.bwr.Write([]byte{'\n'})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func reverseBytes(b []byte) {
 | 
			
		||||
	n := len(b)
 | 
			
		||||
	if n <= 1 {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	for i := 0; i < n; i++ {
 | 
			
		||||
		k := n - 1
 | 
			
		||||
		if k != i {
 | 
			
		||||
			b[i], b[k] = b[k], b[i]
 | 
			
		||||
		}
 | 
			
		||||
		n--
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										98
									
								
								pkg/grpclient/client.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										98
									
								
								pkg/grpclient/client.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,98 @@
 | 
			
		||||
package grpclient
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"google.golang.org/grpc"
 | 
			
		||||
	"google.golang.org/grpc/credentials"
 | 
			
		||||
	"google.golang.org/grpc/credentials/insecure"
 | 
			
		||||
	"google.golang.org/grpc/keepalive"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	defaultDialTimeout = time.Second * 2
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Option func(*option)
 | 
			
		||||
 | 
			
		||||
type option struct {
 | 
			
		||||
	credential       credentials.TransportCredentials
 | 
			
		||||
	keepalive        *keepalive.ClientParameters
 | 
			
		||||
	dialTimeout      time.Duration
 | 
			
		||||
	unaryInterceptor grpc.UnaryClientInterceptor
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithCredential setup credential for tls
 | 
			
		||||
func WithCredential(credential credentials.TransportCredentials) Option {
 | 
			
		||||
	return func(opt *option) {
 | 
			
		||||
		opt.credential = credential
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithKeepAlive setup keepalive parameters
 | 
			
		||||
func WithKeepAlive(keepalive *keepalive.ClientParameters) Option {
 | 
			
		||||
	return func(opt *option) {
 | 
			
		||||
		opt.keepalive = keepalive
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithDialTimeout set up the dial timeout
 | 
			
		||||
func WithDialTimeout(timeout time.Duration) Option {
 | 
			
		||||
	return func(opt *option) {
 | 
			
		||||
		opt.dialTimeout = timeout
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func WithUnaryInterceptor(unaryInterceptor grpc.UnaryClientInterceptor) Option {
 | 
			
		||||
	return func(opt *option) {
 | 
			
		||||
		opt.unaryInterceptor = unaryInterceptor
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func New(target string, options ...Option) (*grpc.ClientConn, error) {
 | 
			
		||||
	if target == "" {
 | 
			
		||||
		return nil, errors.New("target required")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	opt := new(option)
 | 
			
		||||
	for _, f := range options {
 | 
			
		||||
		f(opt)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	kacp := defaultKeepAlive
 | 
			
		||||
	if opt.keepalive != nil {
 | 
			
		||||
		kacp = opt.keepalive
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	dialTimeout := defaultDialTimeout
 | 
			
		||||
	if opt.dialTimeout > 0 {
 | 
			
		||||
		dialTimeout = opt.dialTimeout
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	dialOptions := []grpc.DialOption{
 | 
			
		||||
		grpc.WithBlock(),
 | 
			
		||||
		grpc.WithKeepaliveParams(*kacp),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if opt.unaryInterceptor != nil {
 | 
			
		||||
		dialOptions = append(dialOptions, grpc.WithUnaryInterceptor(opt.unaryInterceptor))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if opt.credential == nil {
 | 
			
		||||
		dialOptions = append(dialOptions, grpc.WithTransportCredentials(insecure.NewCredentials()))
 | 
			
		||||
	} else {
 | 
			
		||||
		dialOptions = append(dialOptions, grpc.WithTransportCredentials(opt.credential))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ctx, cancel := context.WithTimeout(context.Background(), dialTimeout)
 | 
			
		||||
	defer cancel()
 | 
			
		||||
 | 
			
		||||
	conn, err := grpc.DialContext(ctx, target, dialOptions...)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return conn, nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										15
									
								
								pkg/grpclient/keepalive.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								pkg/grpclient/keepalive.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,15 @@
 | 
			
		||||
package grpclient
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"google.golang.org/grpc/keepalive"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	defaultKeepAlive = &keepalive.ClientParameters{
 | 
			
		||||
		Time:                10 * time.Second,
 | 
			
		||||
		Timeout:             time.Second,
 | 
			
		||||
		PermitWithoutStream: true,
 | 
			
		||||
	}
 | 
			
		||||
)
 | 
			
		||||
							
								
								
									
										25
									
								
								pkg/hash/hash.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										25
									
								
								pkg/hash/hash.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,25 @@
 | 
			
		||||
package hash
 | 
			
		||||
 | 
			
		||||
var _ Hash = (*hash)(nil)
 | 
			
		||||
 | 
			
		||||
type Hash interface {
 | 
			
		||||
	i()
 | 
			
		||||
 | 
			
		||||
	// hashids
 | 
			
		||||
	HashidsEncode(params []int) (string, error)
 | 
			
		||||
	HashidsDecode(hash string) ([]int, error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type hash struct {
 | 
			
		||||
	secret string
 | 
			
		||||
	length int
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func New(secret string, length int) Hash {
 | 
			
		||||
	return &hash{
 | 
			
		||||
		secret: secret,
 | 
			
		||||
		length: length,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *hash) i() {}
 | 
			
		||||
							
								
								
									
										41
									
								
								pkg/hash/hash_hashids.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								pkg/hash/hash_hashids.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,41 @@
 | 
			
		||||
package hash
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/speps/go-hashids"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func (h *hash) HashidsEncode(params []int) (string, error) {
 | 
			
		||||
	hd := hashids.NewData()
 | 
			
		||||
	hd.Salt = h.secret
 | 
			
		||||
	hd.MinLength = h.length
 | 
			
		||||
 | 
			
		||||
	hashID, err := hashids.NewWithData(hd)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	hashStr, err := hashID.Encode(params)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return hashStr, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *hash) HashidsDecode(hash string) ([]int, error) {
 | 
			
		||||
	hd := hashids.NewData()
 | 
			
		||||
	hd.Salt = h.secret
 | 
			
		||||
	hd.MinLength = h.length
 | 
			
		||||
 | 
			
		||||
	hashID, err := hashids.NewWithData(hd)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ids, err := hashID.DecodeWithError(hash)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return ids, nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										55
									
								
								pkg/hmac/hmac.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								pkg/hmac/hmac.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,55 @@
 | 
			
		||||
package hmac
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	baseHmac "crypto/hmac"
 | 
			
		||||
	"crypto/sha1"
 | 
			
		||||
	"crypto/sha256"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"encoding/hex"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var _ HMAC = (*hmac)(nil)
 | 
			
		||||
 | 
			
		||||
type HMAC interface {
 | 
			
		||||
	i()
 | 
			
		||||
 | 
			
		||||
	Sha1ToString(data string) string
 | 
			
		||||
	Sha1ToBase64String(data string) string
 | 
			
		||||
 | 
			
		||||
	Sha256ToString(data string) string
 | 
			
		||||
	Sha256ToBase64String(data string) string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type hmac struct {
 | 
			
		||||
	secret string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func New(secret string) HMAC {
 | 
			
		||||
	return &hmac{secret: secret}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *hmac) i() {}
 | 
			
		||||
 | 
			
		||||
func (m *hmac) Sha256ToString(data string) string {
 | 
			
		||||
	h := baseHmac.New(sha256.New, []byte(m.secret))
 | 
			
		||||
	h.Write([]byte(data))
 | 
			
		||||
	return hex.EncodeToString(h.Sum(nil))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *hmac) Sha256ToBase64String(data string) string {
 | 
			
		||||
	h := baseHmac.New(sha256.New, []byte(m.secret))
 | 
			
		||||
	h.Write([]byte(data))
 | 
			
		||||
	return base64.StdEncoding.EncodeToString(h.Sum(nil))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *hmac) Sha1ToString(data string) string {
 | 
			
		||||
	h := baseHmac.New(sha1.New, []byte(m.secret))
 | 
			
		||||
	h.Write([]byte(data))
 | 
			
		||||
	return hex.EncodeToString(h.Sum(nil))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *hmac) Sha1ToBase64String(data string) string {
 | 
			
		||||
	h := baseHmac.New(sha1.New, []byte(m.secret))
 | 
			
		||||
	h.Write([]byte(data))
 | 
			
		||||
	return base64.StdEncoding.EncodeToString(h.Sum(nil))
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										29
									
								
								pkg/httpclient/alarm.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								pkg/httpclient/alarm.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,29 @@
 | 
			
		||||
package httpclient
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"bytes"
 | 
			
		||||
 | 
			
		||||
	"go.uber.org/zap"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Verify parse the body and verify that it is correct
 | 
			
		||||
type AlarmVerify func(body []byte) (shouldAlarm bool)
 | 
			
		||||
 | 
			
		||||
type AlarmObject interface {
 | 
			
		||||
	Send(subject, body string) error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func onFailedAlarm(title string, raw []byte, logger *zap.Logger, alarmObject AlarmObject) {
 | 
			
		||||
	buf := bytes.NewBuffer(nil)
 | 
			
		||||
 | 
			
		||||
	scanner := bufio.NewScanner(bytes.NewReader(raw))
 | 
			
		||||
	for scanner.Scan() {
 | 
			
		||||
		buf.WriteString(scanner.Text())
 | 
			
		||||
		buf.WriteString(" <br/>")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := alarmObject.Send(title, buf.String()); err != nil && logger != nil {
 | 
			
		||||
		logger.Error("calls failed alarm err", zap.Error(err))
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										395
									
								
								pkg/httpclient/client.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										395
									
								
								pkg/httpclient/client.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,395 @@
 | 
			
		||||
package httpclient
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/spf13/cast"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	httpURL "net/url"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.bvbej.com/bvbej/base-golang/pkg/trace"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	// DefaultTTL 一次http请求最长执行1分钟
 | 
			
		||||
	DefaultTTL = time.Minute
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Get 请求
 | 
			
		||||
func Get(url string, form httpURL.Values, options ...Option) (body []byte, err error) {
 | 
			
		||||
	return withoutBody(http.MethodGet, url, form, options...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Delete delete 请求
 | 
			
		||||
func Delete(url string, form httpURL.Values, options ...Option) (body []byte, err error) {
 | 
			
		||||
	return withoutBody(http.MethodDelete, url, form, options...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func withoutBody(method, url string, form httpURL.Values, options ...Option) (body []byte, err error) {
 | 
			
		||||
	if url == "" {
 | 
			
		||||
		return nil, errors.New("url required")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(form) > 0 {
 | 
			
		||||
		if url, err = addFormValuesIntoURL(url, form); err != nil {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ts := time.Now()
 | 
			
		||||
 | 
			
		||||
	opt := getOption()
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if opt.trace != nil {
 | 
			
		||||
			opt.dialog.Success = err == nil
 | 
			
		||||
			opt.dialog.CostSeconds = time.Since(ts).Seconds()
 | 
			
		||||
			opt.trace.AppendDialog(opt.dialog)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		releaseOption(opt)
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	for _, f := range options {
 | 
			
		||||
		f(opt)
 | 
			
		||||
	}
 | 
			
		||||
	opt.header["Content-Type"] = []string{"application/x-www-form-urlencoded; charset=utf-8"}
 | 
			
		||||
	if opt.trace != nil {
 | 
			
		||||
		opt.header[trace.Header] = []string{opt.trace.ID()}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ttl := opt.ttl
 | 
			
		||||
	if ttl <= 0 {
 | 
			
		||||
		ttl = DefaultTTL
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ctx, cancel := context.WithTimeout(context.Background(), ttl)
 | 
			
		||||
	defer cancel()
 | 
			
		||||
 | 
			
		||||
	if opt.dialog != nil {
 | 
			
		||||
		decodedURL, _ := httpURL.QueryUnescape(url)
 | 
			
		||||
		opt.dialog.Request = &trace.Request{
 | 
			
		||||
			TTL:        ttl.String(),
 | 
			
		||||
			Method:     method,
 | 
			
		||||
			DecodedURL: decodedURL,
 | 
			
		||||
			Header:     opt.header,
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	retryTimes := opt.retryTimes
 | 
			
		||||
	if retryTimes <= 0 {
 | 
			
		||||
		retryTimes = DefaultRetryTimes
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	retryDelay := opt.retryDelay
 | 
			
		||||
	if retryDelay <= 0 {
 | 
			
		||||
		retryDelay = DefaultRetryDelay
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var httpCode int
 | 
			
		||||
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if opt.alarmObject == nil {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if opt.alarmVerify != nil && !opt.alarmVerify(body) && err == nil {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		info := &struct {
 | 
			
		||||
			TraceID string `json:"trace_id"`
 | 
			
		||||
			Request struct {
 | 
			
		||||
				Method string `json:"method"`
 | 
			
		||||
				URL    string `json:"url"`
 | 
			
		||||
			} `json:"request"`
 | 
			
		||||
			Response struct {
 | 
			
		||||
				HTTPCode int    `json:"http_code"`
 | 
			
		||||
				Body     string `json:"body"`
 | 
			
		||||
			} `json:"response"`
 | 
			
		||||
			Error string `json:"error"`
 | 
			
		||||
		}{}
 | 
			
		||||
 | 
			
		||||
		if opt.trace != nil {
 | 
			
		||||
			info.TraceID = opt.trace.ID()
 | 
			
		||||
		}
 | 
			
		||||
		info.Request.Method = method
 | 
			
		||||
		info.Request.URL = url
 | 
			
		||||
		info.Response.HTTPCode = httpCode
 | 
			
		||||
		info.Response.Body = string(body)
 | 
			
		||||
		info.Error = ""
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			info.Error = fmt.Sprintf("%+v", err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		raw, _ := json.MarshalIndent(info, "", " ")
 | 
			
		||||
		onFailedAlarm(opt.alarmTitle, raw, opt.logger, opt.alarmObject)
 | 
			
		||||
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	for k := 0; k < retryTimes; k++ {
 | 
			
		||||
		body, httpCode, err = doHTTP(ctx, method, url, nil, opt)
 | 
			
		||||
		if shouldRetry(ctx, httpCode) || (opt.retryVerify != nil && opt.retryVerify(body)) {
 | 
			
		||||
			time.Sleep(retryDelay)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// PostForm post form 请求
 | 
			
		||||
func PostForm(url string, form httpURL.Values, options ...Option) (body []byte, err error) {
 | 
			
		||||
	return withFormBody(http.MethodPost, url, form, options...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// PostJSON post json 请求
 | 
			
		||||
func PostJSON(url string, raw json.RawMessage, options ...Option) (body []byte, err error) {
 | 
			
		||||
	return withJSONBody(http.MethodPost, url, raw, options...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// PutForm put form 请求
 | 
			
		||||
func PutForm(url string, form httpURL.Values, options ...Option) (body []byte, err error) {
 | 
			
		||||
	return withFormBody(http.MethodPut, url, form, options...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// PutJSON put json 请求
 | 
			
		||||
func PutJSON(url string, raw json.RawMessage, options ...Option) (body []byte, err error) {
 | 
			
		||||
	return withJSONBody(http.MethodPut, url, raw, options...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// PatchFrom patch form 请求
 | 
			
		||||
func PatchFrom(url string, form httpURL.Values, options ...Option) (body []byte, err error) {
 | 
			
		||||
	return withFormBody(http.MethodPatch, url, form, options...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// PatchJSON patch json 请求
 | 
			
		||||
func PatchJSON(url string, raw json.RawMessage, options ...Option) (body []byte, err error) {
 | 
			
		||||
	return withJSONBody(http.MethodPatch, url, raw, options...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func withFormBody(method, url string, form httpURL.Values, options ...Option) (body []byte, err error) {
 | 
			
		||||
	if url == "" {
 | 
			
		||||
		return nil, errors.New("url required")
 | 
			
		||||
	}
 | 
			
		||||
	if len(form) == 0 {
 | 
			
		||||
		return nil, errors.New("form required")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ts := time.Now()
 | 
			
		||||
 | 
			
		||||
	opt := getOption()
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if opt.trace != nil {
 | 
			
		||||
			opt.dialog.Success = err == nil
 | 
			
		||||
			opt.dialog.CostSeconds = time.Since(ts).Seconds()
 | 
			
		||||
			opt.trace.AppendDialog(opt.dialog)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		releaseOption(opt)
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	for _, f := range options {
 | 
			
		||||
		f(opt)
 | 
			
		||||
	}
 | 
			
		||||
	opt.header["Content-Type"] = []string{"application/x-www-form-urlencoded; charset=utf-8"}
 | 
			
		||||
	if opt.trace != nil {
 | 
			
		||||
		opt.header[trace.Header] = []string{opt.trace.ID()}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ttl := opt.ttl
 | 
			
		||||
	if ttl <= 0 {
 | 
			
		||||
		ttl = DefaultTTL
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ctx, cancel := context.WithTimeout(context.Background(), ttl)
 | 
			
		||||
	defer cancel()
 | 
			
		||||
 | 
			
		||||
	formValue := form.Encode()
 | 
			
		||||
	if opt.dialog != nil {
 | 
			
		||||
		decodedURL, _ := httpURL.QueryUnescape(url)
 | 
			
		||||
		opt.dialog.Request = &trace.Request{
 | 
			
		||||
			TTL:        ttl.String(),
 | 
			
		||||
			Method:     method,
 | 
			
		||||
			DecodedURL: decodedURL,
 | 
			
		||||
			Header:     opt.header,
 | 
			
		||||
			Body:       formValue,
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	retryTimes := opt.retryTimes
 | 
			
		||||
	if retryTimes <= 0 {
 | 
			
		||||
		retryTimes = DefaultRetryTimes
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	retryDelay := opt.retryDelay
 | 
			
		||||
	if retryDelay <= 0 {
 | 
			
		||||
		retryDelay = DefaultRetryDelay
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var httpCode int
 | 
			
		||||
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if opt.alarmObject == nil {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if opt.alarmVerify != nil && !opt.alarmVerify(body) && err == nil {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		info := &struct {
 | 
			
		||||
			TraceID string `json:"trace_id"`
 | 
			
		||||
			Request struct {
 | 
			
		||||
				Method string `json:"method"`
 | 
			
		||||
				URL    string `json:"url"`
 | 
			
		||||
			} `json:"request"`
 | 
			
		||||
			Response struct {
 | 
			
		||||
				HTTPCode int    `json:"http_code"`
 | 
			
		||||
				Body     string `json:"body"`
 | 
			
		||||
			} `json:"response"`
 | 
			
		||||
			Error string `json:"error"`
 | 
			
		||||
		}{}
 | 
			
		||||
 | 
			
		||||
		if opt.trace != nil {
 | 
			
		||||
			info.TraceID = opt.trace.ID()
 | 
			
		||||
		}
 | 
			
		||||
		info.Request.Method = method
 | 
			
		||||
		info.Request.URL = url
 | 
			
		||||
		info.Response.HTTPCode = httpCode
 | 
			
		||||
		info.Response.Body = string(body)
 | 
			
		||||
		info.Error = ""
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			info.Error = fmt.Sprintf("%+v", err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		raw, _ := json.MarshalIndent(info, "", " ")
 | 
			
		||||
		onFailedAlarm(opt.alarmTitle, raw, opt.logger, opt.alarmObject)
 | 
			
		||||
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	for k := 0; k < retryTimes; k++ {
 | 
			
		||||
		body, httpCode, err = doHTTP(ctx, method, url, []byte(formValue), opt)
 | 
			
		||||
		if shouldRetry(ctx, httpCode) || (opt.retryVerify != nil && opt.retryVerify(body)) {
 | 
			
		||||
			time.Sleep(retryDelay)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func withJSONBody(method, url string, raw json.RawMessage, options ...Option) (body []byte, err error) {
 | 
			
		||||
	if url == "" {
 | 
			
		||||
		return nil, errors.New("url required")
 | 
			
		||||
	}
 | 
			
		||||
	if len(raw) == 0 {
 | 
			
		||||
		return nil, errors.New("raw required")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ts := time.Now()
 | 
			
		||||
 | 
			
		||||
	opt := getOption()
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if opt.trace != nil {
 | 
			
		||||
			opt.dialog.Success = err == nil
 | 
			
		||||
			opt.dialog.CostSeconds = time.Since(ts).Seconds()
 | 
			
		||||
			opt.trace.AppendDialog(opt.dialog)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		releaseOption(opt)
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	for _, f := range options {
 | 
			
		||||
		f(opt)
 | 
			
		||||
	}
 | 
			
		||||
	opt.header["Content-Type"] = []string{"application/json; charset=utf-8"}
 | 
			
		||||
	if opt.trace != nil {
 | 
			
		||||
		opt.header[trace.Header] = []string{opt.trace.ID()}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ttl := opt.ttl
 | 
			
		||||
	if ttl <= 0 {
 | 
			
		||||
		ttl = DefaultTTL
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ctx, cancel := context.WithTimeout(context.Background(), ttl)
 | 
			
		||||
	defer cancel()
 | 
			
		||||
 | 
			
		||||
	if opt.dialog != nil {
 | 
			
		||||
		decodedURL, _ := httpURL.QueryUnescape(url)
 | 
			
		||||
		opt.dialog.Request = &trace.Request{
 | 
			
		||||
			TTL:        ttl.String(),
 | 
			
		||||
			Method:     method,
 | 
			
		||||
			DecodedURL: decodedURL,
 | 
			
		||||
			Header:     opt.header,
 | 
			
		||||
			Body:       cast.ToString(raw),
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	retryTimes := opt.retryTimes
 | 
			
		||||
	if retryTimes <= 0 {
 | 
			
		||||
		retryTimes = DefaultRetryTimes
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	retryDelay := opt.retryDelay
 | 
			
		||||
	if retryDelay <= 0 {
 | 
			
		||||
		retryDelay = DefaultRetryDelay
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var httpCode int
 | 
			
		||||
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if opt.alarmObject == nil {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if opt.alarmVerify != nil && !opt.alarmVerify(body) && err == nil {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		info := &struct {
 | 
			
		||||
			TraceID string `json:"trace_id"`
 | 
			
		||||
			Request struct {
 | 
			
		||||
				Method string `json:"method"`
 | 
			
		||||
				URL    string `json:"url"`
 | 
			
		||||
			} `json:"request"`
 | 
			
		||||
			Response struct {
 | 
			
		||||
				HTTPCode int    `json:"http_code"`
 | 
			
		||||
				Body     string `json:"body"`
 | 
			
		||||
			} `json:"response"`
 | 
			
		||||
			Error string `json:"error"`
 | 
			
		||||
		}{}
 | 
			
		||||
 | 
			
		||||
		if opt.trace != nil {
 | 
			
		||||
			info.TraceID = opt.trace.ID()
 | 
			
		||||
		}
 | 
			
		||||
		info.Request.Method = method
 | 
			
		||||
		info.Request.URL = url
 | 
			
		||||
		info.Response.HTTPCode = httpCode
 | 
			
		||||
		info.Response.Body = string(body)
 | 
			
		||||
		info.Error = ""
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			info.Error = fmt.Sprintf("%+v", err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		raw, _ := json.MarshalIndent(info, "", " ")
 | 
			
		||||
		onFailedAlarm(opt.alarmTitle, raw, opt.logger, opt.alarmObject)
 | 
			
		||||
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	for k := 0; k < retryTimes; k++ {
 | 
			
		||||
		body, httpCode, err = doHTTP(ctx, method, url, raw, opt)
 | 
			
		||||
		if shouldRetry(ctx, httpCode) || (opt.retryVerify != nil && opt.retryVerify(body)) {
 | 
			
		||||
			time.Sleep(retryDelay)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										46
									
								
								pkg/httpclient/error.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										46
									
								
								pkg/httpclient/error.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,46 @@
 | 
			
		||||
package httpclient
 | 
			
		||||
 | 
			
		||||
var _ ReplyErr = (*replyErr)(nil)
 | 
			
		||||
 | 
			
		||||
// ReplyErr 错误响应,当 resp.StatusCode != http.StatusOK 时用来包装返回的 httpcode 和 body 。
 | 
			
		||||
type ReplyErr interface {
 | 
			
		||||
	error
 | 
			
		||||
	StatusCode() int
 | 
			
		||||
	Body() []byte
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type replyErr struct {
 | 
			
		||||
	err        error
 | 
			
		||||
	statusCode int
 | 
			
		||||
	body       []byte
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *replyErr) Error() string {
 | 
			
		||||
	return r.err.Error()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *replyErr) StatusCode() int {
 | 
			
		||||
	return r.statusCode
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *replyErr) Body() []byte {
 | 
			
		||||
	return r.body
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newReplyErr(statusCode int, body []byte, err error) ReplyErr {
 | 
			
		||||
	return &replyErr{
 | 
			
		||||
		statusCode: statusCode,
 | 
			
		||||
		body:       body,
 | 
			
		||||
		err:        err,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ToReplyErr 尝试将 err 转换为 ReplyErr
 | 
			
		||||
func ToReplyErr(err error) (ReplyErr, bool) {
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		return nil, false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	e, ok := err.(ReplyErr)
 | 
			
		||||
	return e, ok
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										138
									
								
								pkg/httpclient/option.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										138
									
								
								pkg/httpclient/option.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,138 @@
 | 
			
		||||
package httpclient
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.bvbej.com/bvbej/base-golang/pkg/trace"
 | 
			
		||||
 | 
			
		||||
	"go.uber.org/zap"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	cache = &sync.Pool{
 | 
			
		||||
		New: func() any {
 | 
			
		||||
			return &option{
 | 
			
		||||
				header: make(map[string][]string),
 | 
			
		||||
			}
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Mock 定义接口Mock数据
 | 
			
		||||
type Mock func() (body []byte)
 | 
			
		||||
 | 
			
		||||
// Option 自定义设置http请求
 | 
			
		||||
type Option func(*option)
 | 
			
		||||
 | 
			
		||||
type basicAuth struct {
 | 
			
		||||
	username string
 | 
			
		||||
	password string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type option struct {
 | 
			
		||||
	ttl         time.Duration
 | 
			
		||||
	basicAuth   *basicAuth
 | 
			
		||||
	header      map[string][]string
 | 
			
		||||
	trace       *trace.Trace
 | 
			
		||||
	dialog      *trace.Dialog
 | 
			
		||||
	logger      *zap.Logger
 | 
			
		||||
	retryTimes  int
 | 
			
		||||
	retryDelay  time.Duration
 | 
			
		||||
	retryVerify RetryVerify
 | 
			
		||||
	alarmTitle  string
 | 
			
		||||
	alarmObject AlarmObject
 | 
			
		||||
	alarmVerify AlarmVerify
 | 
			
		||||
	mock        Mock
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (o *option) reset() {
 | 
			
		||||
	o.ttl = 0
 | 
			
		||||
	o.basicAuth = nil
 | 
			
		||||
	o.header = make(map[string][]string)
 | 
			
		||||
	o.trace = nil
 | 
			
		||||
	o.dialog = nil
 | 
			
		||||
	o.logger = nil
 | 
			
		||||
	o.retryTimes = 0
 | 
			
		||||
	o.retryDelay = 0
 | 
			
		||||
	o.retryVerify = nil
 | 
			
		||||
	o.alarmTitle = ""
 | 
			
		||||
	o.alarmObject = nil
 | 
			
		||||
	o.alarmVerify = nil
 | 
			
		||||
	o.mock = nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getOption() *option {
 | 
			
		||||
	return cache.Get().(*option)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func releaseOption(opt *option) {
 | 
			
		||||
	opt.reset()
 | 
			
		||||
	cache.Put(opt)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithTTL 本次http请求最长执行时间
 | 
			
		||||
func WithTTL(ttl time.Duration) Option {
 | 
			
		||||
	return func(opt *option) {
 | 
			
		||||
		opt.ttl = ttl
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithHeader 设置http header,可以调用多次设置多对key-value
 | 
			
		||||
func WithHeader(key, value string) Option {
 | 
			
		||||
	return func(opt *option) {
 | 
			
		||||
		opt.header[key] = []string{value}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithBasicAuth 设置基础认证权限
 | 
			
		||||
func WithBasicAuth(username, password string) Option {
 | 
			
		||||
	return func(opt *option) {
 | 
			
		||||
		opt.basicAuth = &basicAuth{
 | 
			
		||||
			username: username,
 | 
			
		||||
			password: password,
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithTrace 设置trace信息
 | 
			
		||||
func WithTrace(t trace.T) Option {
 | 
			
		||||
	return func(opt *option) {
 | 
			
		||||
		if t != nil {
 | 
			
		||||
			opt.trace = t.(*trace.Trace)
 | 
			
		||||
			opt.dialog = new(trace.Dialog)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithLogger 设置logger以便打印关键日志
 | 
			
		||||
func WithLogger(logger *zap.Logger) Option {
 | 
			
		||||
	return func(opt *option) {
 | 
			
		||||
		opt.logger = logger
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithMock 设置 mock 数据
 | 
			
		||||
func WithMock(m Mock) Option {
 | 
			
		||||
	return func(opt *option) {
 | 
			
		||||
		opt.mock = m
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithOnFailedAlarm 设置告警通知
 | 
			
		||||
func WithOnFailedAlarm(alarmTitle string, alarmObject AlarmObject, alarmVerify AlarmVerify) Option {
 | 
			
		||||
	return func(opt *option) {
 | 
			
		||||
		opt.alarmTitle = alarmTitle
 | 
			
		||||
		opt.alarmObject = alarmObject
 | 
			
		||||
		opt.alarmVerify = alarmVerify
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithOnFailedRetry 设置失败重试
 | 
			
		||||
func WithOnFailedRetry(retryTimes int, retryDelay time.Duration, retryVerify RetryVerify) Option {
 | 
			
		||||
	return func(opt *option) {
 | 
			
		||||
		opt.retryTimes = retryTimes
 | 
			
		||||
		opt.retryDelay = retryDelay
 | 
			
		||||
		opt.retryVerify = retryVerify
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										44
									
								
								pkg/httpclient/retry.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										44
									
								
								pkg/httpclient/retry.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,44 @@
 | 
			
		||||
package httpclient
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	// DefaultRetryTimes 如果请求失败,最多重试3次
 | 
			
		||||
	DefaultRetryTimes = 3
 | 
			
		||||
	// DefaultRetryDelay 在重试前,延迟等待100毫秒
 | 
			
		||||
	DefaultRetryDelay = time.Millisecond * 100
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Verify parse the body and verify that it is correct
 | 
			
		||||
type RetryVerify func(body []byte) (shouldRetry bool)
 | 
			
		||||
 | 
			
		||||
func shouldRetry(ctx context.Context, httpCode int) bool {
 | 
			
		||||
	select {
 | 
			
		||||
	case <-ctx.Done():
 | 
			
		||||
		return false
 | 
			
		||||
	default:
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	switch httpCode {
 | 
			
		||||
	case
 | 
			
		||||
		_StatusReadRespErr,
 | 
			
		||||
		_StatusDoReqErr,
 | 
			
		||||
 | 
			
		||||
		http.StatusRequestTimeout,
 | 
			
		||||
		http.StatusLocked,
 | 
			
		||||
		http.StatusTooEarly,
 | 
			
		||||
		http.StatusTooManyRequests,
 | 
			
		||||
 | 
			
		||||
		http.StatusServiceUnavailable,
 | 
			
		||||
		http.StatusGatewayTimeout:
 | 
			
		||||
 | 
			
		||||
		return true
 | 
			
		||||
 | 
			
		||||
	default:
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										147
									
								
								pkg/httpclient/util.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										147
									
								
								pkg/httpclient/util.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,147 @@
 | 
			
		||||
package httpclient
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"context"
 | 
			
		||||
	"crypto/tls"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.bvbej.com/bvbej/base-golang/pkg/trace"
 | 
			
		||||
	"go.uber.org/zap"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	// _StatusReadRespErr read resp body err, should re-call doHTTP again.
 | 
			
		||||
	_StatusReadRespErr = -204
 | 
			
		||||
	// _StatusDoReqErr do req err, should re-call doHTTP again.
 | 
			
		||||
	_StatusDoReqErr = -500
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var defaultClient = &http.Client{
 | 
			
		||||
	Transport: &http.Transport{
 | 
			
		||||
		DisableKeepAlives:  true,
 | 
			
		||||
		DisableCompression: true,
 | 
			
		||||
		TLSClientConfig: &tls.Config{
 | 
			
		||||
			InsecureSkipVerify: true,
 | 
			
		||||
		},
 | 
			
		||||
		MaxIdleConns:        100,
 | 
			
		||||
		MaxConnsPerHost:     100,
 | 
			
		||||
		MaxIdleConnsPerHost: 100,
 | 
			
		||||
	},
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func doHTTP(ctx context.Context, method, url string, payload []byte, opt *option) ([]byte, int, error) {
 | 
			
		||||
	ts := time.Now()
 | 
			
		||||
 | 
			
		||||
	if mock := opt.mock; mock != nil {
 | 
			
		||||
		if opt.dialog != nil {
 | 
			
		||||
			opt.dialog.AppendResponse(&trace.Response{
 | 
			
		||||
				HttpCode:    http.StatusOK,
 | 
			
		||||
				HttpCodeMsg: http.StatusText(http.StatusOK),
 | 
			
		||||
				Body:        string(mock()),
 | 
			
		||||
				CostSeconds: time.Since(ts).Seconds(),
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
		return mock(), http.StatusOK, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	req, err := http.NewRequestWithContext(ctx, method, url, bytes.NewReader(payload))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, -1, errors.Join(err, fmt.Errorf("new request [%s %s] err", method, url))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for key, value := range opt.header {
 | 
			
		||||
		req.Header.Set(key, value[0])
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if opt.basicAuth != nil {
 | 
			
		||||
		req.SetBasicAuth(opt.basicAuth.username, opt.basicAuth.password)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp, err := defaultClient.Do(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		err = errors.Join(err, fmt.Errorf("do request [%s %s] err", method, url))
 | 
			
		||||
		if opt.dialog != nil {
 | 
			
		||||
			opt.dialog.AppendResponse(&trace.Response{
 | 
			
		||||
				Body:        err.Error(),
 | 
			
		||||
				CostSeconds: time.Since(ts).Seconds(),
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if opt.logger != nil {
 | 
			
		||||
			opt.logger.Warn("doHTTP got err", zap.Error(err))
 | 
			
		||||
		}
 | 
			
		||||
		return nil, _StatusDoReqErr, err
 | 
			
		||||
	}
 | 
			
		||||
	defer func() {
 | 
			
		||||
		_ = resp.Body.Close()
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	body, err := io.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		err = errors.Join(err, fmt.Errorf("read resp body from [%s %s] err", method, url))
 | 
			
		||||
		if opt.dialog != nil {
 | 
			
		||||
			opt.dialog.AppendResponse(&trace.Response{
 | 
			
		||||
				Body:        err.Error(),
 | 
			
		||||
				CostSeconds: time.Since(ts).Seconds(),
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if opt.logger != nil {
 | 
			
		||||
			opt.logger.Warn("doHTTP got err", zap.Error(err))
 | 
			
		||||
		}
 | 
			
		||||
		return nil, _StatusReadRespErr, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if opt.dialog != nil {
 | 
			
		||||
			opt.dialog.AppendResponse(&trace.Response{
 | 
			
		||||
				Header:      resp.Header,
 | 
			
		||||
				HttpCode:    resp.StatusCode,
 | 
			
		||||
				HttpCodeMsg: resp.Status,
 | 
			
		||||
				Body:        string(body), // unsafe
 | 
			
		||||
				CostSeconds: time.Since(ts).Seconds(),
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	if resp.StatusCode != http.StatusOK {
 | 
			
		||||
		return nil, resp.StatusCode, newReplyErr(
 | 
			
		||||
			resp.StatusCode,
 | 
			
		||||
			body,
 | 
			
		||||
			fmt.Errorf("do [%s %s] return code: %d message: %s", method, url, resp.StatusCode, string(body)),
 | 
			
		||||
		)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return body, http.StatusOK, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// addFormValuesIntoURL append url.Values into url string
 | 
			
		||||
func addFormValuesIntoURL(rawURL string, form url.Values) (string, error) {
 | 
			
		||||
	if rawURL == "" {
 | 
			
		||||
		return "", errors.New("rawURL required")
 | 
			
		||||
	}
 | 
			
		||||
	if len(form) == 0 {
 | 
			
		||||
		return "", errors.New("form required")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	target, err := url.Parse(rawURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", errors.Join(err, fmt.Errorf("parse rawURL `%s` err", rawURL))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	urlValues := target.Query()
 | 
			
		||||
	for key, values := range form {
 | 
			
		||||
		for _, value := range values {
 | 
			
		||||
			urlValues.Add(key, value)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	target.RawQuery = urlValues.Encode()
 | 
			
		||||
	return target.String(), nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										71
									
								
								pkg/limiter/rate.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										71
									
								
								pkg/limiter/rate.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,71 @@
 | 
			
		||||
package limiter
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"git.bvbej.com/bvbej/base-golang/pkg/ticker"
 | 
			
		||||
	"golang.org/x/time/rate"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var _ RateLimiter = (*rateLimiter)(nil)
 | 
			
		||||
 | 
			
		||||
type item struct {
 | 
			
		||||
	lastTime time.Time
 | 
			
		||||
	limiter  *rate.Limiter
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type RateLimiter interface {
 | 
			
		||||
	set(key string) *rate.Limiter
 | 
			
		||||
	get(key string) *rate.Limiter
 | 
			
		||||
	Allow(key string) bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type rateLimiter struct {
 | 
			
		||||
	limit   rate.Limit
 | 
			
		||||
	burst   int
 | 
			
		||||
	list    *sync.Map
 | 
			
		||||
	recycle ticker.Ticker
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewRateLimiter(limit rate.Limit, burst int) RateLimiter {
 | 
			
		||||
	list := new(sync.Map)
 | 
			
		||||
	t := ticker.New(time.Minute)
 | 
			
		||||
	t.Process(func() {
 | 
			
		||||
		list.Range(func(key, value any) bool {
 | 
			
		||||
			if value.(*item).lastTime.Before(time.Now().Add(-time.Hour)) {
 | 
			
		||||
				list.Delete(key)
 | 
			
		||||
			}
 | 
			
		||||
			return true
 | 
			
		||||
		})
 | 
			
		||||
	})
 | 
			
		||||
	return &rateLimiter{
 | 
			
		||||
		list:    list,
 | 
			
		||||
		limit:   limit,
 | 
			
		||||
		recycle: t,
 | 
			
		||||
		burst:   burst,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (i *rateLimiter) set(key string) *rate.Limiter {
 | 
			
		||||
	store := &item{
 | 
			
		||||
		lastTime: time.Now(),
 | 
			
		||||
		limiter:  rate.NewLimiter(i.limit, i.burst),
 | 
			
		||||
	}
 | 
			
		||||
	i.list.Store(key, store)
 | 
			
		||||
	return store.limiter
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (i *rateLimiter) get(key string) *rate.Limiter {
 | 
			
		||||
	value, ok := i.list.Load(key)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return i.set(key)
 | 
			
		||||
	}
 | 
			
		||||
	value.(*item).lastTime = time.Now()
 | 
			
		||||
	i.list.Store(key, value)
 | 
			
		||||
	return value.(*item).limiter
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (i *rateLimiter) Allow(key string) bool {
 | 
			
		||||
	limiter := i.get(key)
 | 
			
		||||
	return limiter.Allow()
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										47
									
								
								pkg/lock/lock.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										47
									
								
								pkg/lock/lock.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,47 @@
 | 
			
		||||
package lock
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"sync"
 | 
			
		||||
	"sync/atomic"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var _ Locker = (*locker)(nil)
 | 
			
		||||
 | 
			
		||||
type Locker interface {
 | 
			
		||||
	condition() bool
 | 
			
		||||
	Lock()
 | 
			
		||||
	Unlock()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type locker struct {
 | 
			
		||||
	lock *sync.Mutex
 | 
			
		||||
	cond *sync.Cond
 | 
			
		||||
	v    *atomic.Bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewLocker() Locker {
 | 
			
		||||
	lock := new(sync.Mutex)
 | 
			
		||||
	return &locker{
 | 
			
		||||
		lock: lock,
 | 
			
		||||
		cond: sync.NewCond(lock),
 | 
			
		||||
		v:    new(atomic.Bool),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (l *locker) condition() bool {
 | 
			
		||||
	return l.v.Load()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (l *locker) Lock() {
 | 
			
		||||
	l.cond.L.Lock()
 | 
			
		||||
	for l.condition() {
 | 
			
		||||
		l.cond.Wait()
 | 
			
		||||
	}
 | 
			
		||||
	l.v.Store(true)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (l *locker) Unlock() {
 | 
			
		||||
	l.v.Store(false)
 | 
			
		||||
	l.cond.L.Unlock()
 | 
			
		||||
	l.cond.Signal()
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										275
									
								
								pkg/logger/logger.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										275
									
								
								pkg/logger/logger.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,275 @@
 | 
			
		||||
package logger
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"io"
 | 
			
		||||
	"os"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"go.uber.org/zap"
 | 
			
		||||
	"go.uber.org/zap/zapcore"
 | 
			
		||||
	"gopkg.in/natefinch/lumberjack.v2"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	// DefaultLevel the default log level
 | 
			
		||||
	DefaultLevel = zapcore.InfoLevel
 | 
			
		||||
 | 
			
		||||
	// DefaultTimeLayout the default time layout;
 | 
			
		||||
	DefaultTimeLayout = time.DateTime
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Option custom setup config
 | 
			
		||||
type Option func(*option)
 | 
			
		||||
 | 
			
		||||
type option struct {
 | 
			
		||||
	level          zapcore.Level
 | 
			
		||||
	fields         map[string]string
 | 
			
		||||
	file           io.Writer
 | 
			
		||||
	timeLayout     string
 | 
			
		||||
	disableConsole bool
 | 
			
		||||
	disableCaller  bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithDebugLevel only greater than 'level' will output
 | 
			
		||||
func WithDebugLevel() Option {
 | 
			
		||||
	return func(opt *option) {
 | 
			
		||||
		opt.level = zapcore.DebugLevel
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithInfoLevel only greater than 'level' will output
 | 
			
		||||
func WithInfoLevel() Option {
 | 
			
		||||
	return func(opt *option) {
 | 
			
		||||
		opt.level = zapcore.InfoLevel
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithWarnLevel only greater than 'level' will output
 | 
			
		||||
func WithWarnLevel() Option {
 | 
			
		||||
	return func(opt *option) {
 | 
			
		||||
		opt.level = zapcore.WarnLevel
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithErrorLevel only greater than 'level' will output
 | 
			
		||||
func WithErrorLevel() Option {
 | 
			
		||||
	return func(opt *option) {
 | 
			
		||||
		opt.level = zapcore.ErrorLevel
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithField add some field(s) to log
 | 
			
		||||
func WithField(key, value string) Option {
 | 
			
		||||
	return func(opt *option) {
 | 
			
		||||
		opt.fields[key] = value
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithFileP write log to some file
 | 
			
		||||
func WithFileP(file string) Option {
 | 
			
		||||
	dir := filepath.Dir(file)
 | 
			
		||||
	if err := os.MkdirAll(dir, 0766); err != nil {
 | 
			
		||||
		panic(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	f, err := os.OpenFile(file, os.O_CREATE|os.O_APPEND|os.O_RDWR, 0766)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		panic(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return func(opt *option) {
 | 
			
		||||
		opt.file = zapcore.Lock(f)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithFileRotationP write log to some file with rotation
 | 
			
		||||
func WithFileRotationP(file string) Option {
 | 
			
		||||
	dir := filepath.Dir(file)
 | 
			
		||||
	if err := os.MkdirAll(dir, 0766); err != nil {
 | 
			
		||||
		panic(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return func(opt *option) {
 | 
			
		||||
		opt.file = &lumberjack.Logger{ // concurrent-safed
 | 
			
		||||
			Filename:   file, // 文件路径
 | 
			
		||||
			MaxSize:    128,  // 单个文件最大尺寸,默认单位 M
 | 
			
		||||
			MaxBackups: 300,  // 最多保留 300 个备份
 | 
			
		||||
			MaxAge:     30,   // 最大时间,默认单位 day
 | 
			
		||||
			LocalTime:  true, // 使用本地时间
 | 
			
		||||
			Compress:   true, // 是否压缩 disabled by default
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithTimeLayout custom time format
 | 
			
		||||
func WithTimeLayout(timeLayout string) Option {
 | 
			
		||||
	return func(opt *option) {
 | 
			
		||||
		opt.timeLayout = timeLayout
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithDisableConsole WithEnableConsole write log to os.Stdout or os.Stderr
 | 
			
		||||
func WithDisableConsole() Option {
 | 
			
		||||
	return func(opt *option) {
 | 
			
		||||
		opt.disableConsole = true
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithDisableCaller disable caller field
 | 
			
		||||
func WithDisableCaller() Option {
 | 
			
		||||
	return func(opt *option) {
 | 
			
		||||
		opt.disableCaller = true
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewJSONLogger return a json-encoder zap logger,
 | 
			
		||||
func NewJSONLogger(opts ...Option) (*zap.Logger, error) {
 | 
			
		||||
	opt := &option{level: DefaultLevel, fields: make(map[string]string)}
 | 
			
		||||
	for _, f := range opts {
 | 
			
		||||
		f(opt)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	timeLayout := DefaultTimeLayout
 | 
			
		||||
	if opt.timeLayout != "" {
 | 
			
		||||
		timeLayout = opt.timeLayout
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// similar to zap.NewProductionEncoderConfig()
 | 
			
		||||
	encoderConfig := zapcore.EncoderConfig{
 | 
			
		||||
		TimeKey:       "time",
 | 
			
		||||
		LevelKey:      "level",
 | 
			
		||||
		NameKey:       "logger", // used by logger.Named(key); optional; useless
 | 
			
		||||
		CallerKey:     "caller",
 | 
			
		||||
		MessageKey:    "msg",
 | 
			
		||||
		StacktraceKey: "stacktrace", // use by zap.AddStacktrace; optional; useless
 | 
			
		||||
		LineEnding:    zapcore.DefaultLineEnding,
 | 
			
		||||
		EncodeLevel:   zapcore.LowercaseLevelEncoder, // 小写编码器
 | 
			
		||||
		EncodeTime: func(t time.Time, enc zapcore.PrimitiveArrayEncoder) {
 | 
			
		||||
			enc.AppendString(t.Format(timeLayout))
 | 
			
		||||
		},
 | 
			
		||||
		EncodeDuration: zapcore.MillisDurationEncoder,
 | 
			
		||||
		EncodeCaller:   zapcore.ShortCallerEncoder, // 全路径编码器
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	jsonEncoder := zapcore.NewJSONEncoder(encoderConfig)
 | 
			
		||||
 | 
			
		||||
	// lowPriority usd by info\debug\warn
 | 
			
		||||
	lowPriority := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool {
 | 
			
		||||
		return lvl >= opt.level && lvl < zapcore.ErrorLevel
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	// highPriority usd by error\panic\fatal
 | 
			
		||||
	highPriority := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool {
 | 
			
		||||
		return lvl >= opt.level && lvl >= zapcore.ErrorLevel
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	stdout := zapcore.Lock(os.Stdout) // lock for concurrent safe
 | 
			
		||||
	stderr := zapcore.Lock(os.Stderr) // lock for concurrent safe
 | 
			
		||||
 | 
			
		||||
	core := zapcore.NewTee()
 | 
			
		||||
 | 
			
		||||
	if !opt.disableConsole {
 | 
			
		||||
		core = zapcore.NewTee(
 | 
			
		||||
			zapcore.NewCore(jsonEncoder,
 | 
			
		||||
				zapcore.NewMultiWriteSyncer(stdout),
 | 
			
		||||
				lowPriority,
 | 
			
		||||
			),
 | 
			
		||||
			zapcore.NewCore(jsonEncoder,
 | 
			
		||||
				zapcore.NewMultiWriteSyncer(stderr),
 | 
			
		||||
				highPriority,
 | 
			
		||||
			),
 | 
			
		||||
		)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if opt.file != nil {
 | 
			
		||||
		core = zapcore.NewTee(core,
 | 
			
		||||
			zapcore.NewCore(jsonEncoder,
 | 
			
		||||
				zapcore.AddSync(opt.file),
 | 
			
		||||
				zap.LevelEnablerFunc(func(lvl zapcore.Level) bool {
 | 
			
		||||
					return lvl >= opt.level
 | 
			
		||||
				}),
 | 
			
		||||
			),
 | 
			
		||||
		)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	logger := zap.New(core,
 | 
			
		||||
		zap.WithCaller(!opt.disableCaller),
 | 
			
		||||
		zap.ErrorOutput(stderr),
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	for key, value := range opt.fields {
 | 
			
		||||
		logger = logger.WithOptions(zap.Fields(zapcore.Field{Key: key, Type: zapcore.StringType, String: value}))
 | 
			
		||||
	}
 | 
			
		||||
	return logger, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var _ Meta = (*meta)(nil)
 | 
			
		||||
 | 
			
		||||
// Meta key-value
 | 
			
		||||
type Meta interface {
 | 
			
		||||
	Key() string
 | 
			
		||||
	Value() any
 | 
			
		||||
	meta()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type meta struct {
 | 
			
		||||
	key   string
 | 
			
		||||
	value any
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *meta) Key() string {
 | 
			
		||||
	return m.key
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *meta) Value() any {
 | 
			
		||||
	return m.value
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *meta) meta() {}
 | 
			
		||||
 | 
			
		||||
// NewMeta create meat
 | 
			
		||||
func NewMeta(key string, value any) Meta {
 | 
			
		||||
	return &meta{key: key, value: value}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WrapMeta wrap meta to zap fields
 | 
			
		||||
func WrapMeta(err error, metas ...Meta) (fields []zap.Field) {
 | 
			
		||||
	capacity := len(metas) + 1 // namespace meta
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		capacity++
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	fields = make([]zap.Field, 0, capacity)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		fields = append(fields, zap.Error(err))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	fields = append(fields, zap.Namespace("meta"))
 | 
			
		||||
	for _, meta := range metas {
 | 
			
		||||
		fields = append(fields, zap.Any(meta.Key(), meta.Value()))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RestyClientLogger use by resty.Client
 | 
			
		||||
func RestyClientLogger(logger *zap.Logger) *RestyClientLog {
 | 
			
		||||
	return &RestyClientLog{logger: logger}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type RestyClientLog struct {
 | 
			
		||||
	logger *zap.Logger
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *RestyClientLog) Errorf(format string, v ...any) {
 | 
			
		||||
	r.logger.Sugar().Errorf(format, v)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *RestyClientLog) Warnf(format string, v ...any) {
 | 
			
		||||
	r.logger.Sugar().Warnf(format, v)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *RestyClientLog) Debugf(format string, v ...any) {
 | 
			
		||||
	r.logger.Sugar().Debugf(format, v)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										36
									
								
								pkg/mail/mail.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								pkg/mail/mail.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,36 @@
 | 
			
		||||
package mail
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"gopkg.in/gomail.v2"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Options struct {
 | 
			
		||||
	MailHost string
 | 
			
		||||
	MailPort int
 | 
			
		||||
	MailUser string   // 发件人
 | 
			
		||||
	MailPass string   // 发件人密码
 | 
			
		||||
	MailTo   []string // 多个收件人
 | 
			
		||||
	Subject  string   // 邮件主题
 | 
			
		||||
	Body     string   // 邮件内容
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Send(o *Options) error {
 | 
			
		||||
 | 
			
		||||
	m := gomail.NewMessage()
 | 
			
		||||
 | 
			
		||||
	//设置发件人
 | 
			
		||||
	m.SetHeader("From", o.MailUser)
 | 
			
		||||
 | 
			
		||||
	//设置发送给多个用户
 | 
			
		||||
	m.SetHeader("To", o.MailTo...)
 | 
			
		||||
 | 
			
		||||
	//设置邮件主题
 | 
			
		||||
	m.SetHeader("Subject", o.Subject)
 | 
			
		||||
 | 
			
		||||
	//设置邮件正文
 | 
			
		||||
	m.SetBody("text/html", o.Body)
 | 
			
		||||
 | 
			
		||||
	d := gomail.NewDialer(o.MailHost, o.MailPort, o.MailUser, o.MailPass)
 | 
			
		||||
 | 
			
		||||
	return d.DialAndSend(m)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										28
									
								
								pkg/md5/md5.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										28
									
								
								pkg/md5/md5.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,28 @@
 | 
			
		||||
package md5
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	cryptoMD5 "crypto/md5"
 | 
			
		||||
	"encoding/hex"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var _ MD5 = (*md5)(nil)
 | 
			
		||||
 | 
			
		||||
type MD5 interface {
 | 
			
		||||
	i()
 | 
			
		||||
	// Encrypt 加密
 | 
			
		||||
	Encrypt(encryptStr string) string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type md5 struct{}
 | 
			
		||||
 | 
			
		||||
func New() MD5 {
 | 
			
		||||
	return &md5{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *md5) i() {}
 | 
			
		||||
 | 
			
		||||
func (m *md5) Encrypt(encryptStr string) string {
 | 
			
		||||
	s := cryptoMD5.New()
 | 
			
		||||
	s.Write([]byte(encryptStr))
 | 
			
		||||
	return hex.EncodeToString(s.Sum(nil))
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										51
									
								
								pkg/metrics/metrics.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										51
									
								
								pkg/metrics/metrics.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,51 @@
 | 
			
		||||
package metrics
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/prometheus/client_golang/prometheus"
 | 
			
		||||
	"github.com/spf13/cast"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	namespace = "bvbej"
 | 
			
		||||
	subsystem = "api"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// metricsRequestsTotal metrics for request total 计数器(Counter)
 | 
			
		||||
var metricsRequestsTotal = prometheus.NewCounterVec(
 | 
			
		||||
	prometheus.CounterOpts{
 | 
			
		||||
		Namespace: namespace,
 | 
			
		||||
		Subsystem: subsystem,
 | 
			
		||||
		Name:      "requests_total",
 | 
			
		||||
		Help:      "request(ms) total",
 | 
			
		||||
	},
 | 
			
		||||
	[]string{"method", "path"},
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// metricsRequestsCost metrics for requests cost 累积直方图(Histogram)
 | 
			
		||||
var metricsRequestsCost = prometheus.NewHistogramVec(
 | 
			
		||||
	prometheus.HistogramOpts{
 | 
			
		||||
		Namespace: namespace,
 | 
			
		||||
		Subsystem: subsystem,
 | 
			
		||||
		Name:      "requests_cost",
 | 
			
		||||
		Help:      "request(ms) cost milliseconds",
 | 
			
		||||
	},
 | 
			
		||||
	[]string{"method", "path", "success"},
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	prometheus.MustRegister(metricsRequestsTotal, metricsRequestsCost)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RecordMetrics 记录指标
 | 
			
		||||
func RecordMetrics(method, uri string, success bool, costSeconds float64) {
 | 
			
		||||
	metricsRequestsTotal.With(prometheus.Labels{
 | 
			
		||||
		"method": method,
 | 
			
		||||
		"path":   uri,
 | 
			
		||||
	}).Inc()
 | 
			
		||||
 | 
			
		||||
	metricsRequestsCost.With(prometheus.Labels{
 | 
			
		||||
		"method":  method,
 | 
			
		||||
		"path":    uri,
 | 
			
		||||
		"success": cast.ToString(success),
 | 
			
		||||
	}).Observe(costSeconds)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										406
									
								
								pkg/mux/context.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										406
									
								
								pkg/mux/context.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,406 @@
 | 
			
		||||
package mux
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	stdCtx "context"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
 | 
			
		||||
	"git.bvbej.com/bvbej/base-golang/pkg/errno"
 | 
			
		||||
	"git.bvbej.com/bvbej/base-golang/pkg/trace"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/gin-gonic/gin/binding"
 | 
			
		||||
	"go.uber.org/zap"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type HandlerFunc func(c Context)
 | 
			
		||||
 | 
			
		||||
type Trace = trace.T
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	_Alias            = "_alias_"
 | 
			
		||||
	_TraceName        = "_trace_"
 | 
			
		||||
	_LoggerName       = "_logger_"
 | 
			
		||||
	_BodyName         = "_body_"
 | 
			
		||||
	_PayloadName      = "_payload_"
 | 
			
		||||
	_GraphPayloadName = "_graph_payload_"
 | 
			
		||||
	_AbortErrorName   = "_abort_error_"
 | 
			
		||||
	_Auth             = "_auth_"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var contextPool = &sync.Pool{
 | 
			
		||||
	New: func() any {
 | 
			
		||||
		return new(context)
 | 
			
		||||
	},
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newContext(ctx *gin.Context) Context {
 | 
			
		||||
	getContext := contextPool.Get().(*context)
 | 
			
		||||
	getContext.ctx = ctx
 | 
			
		||||
	return getContext
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func releaseContext(ctx Context) {
 | 
			
		||||
	c := ctx.(*context)
 | 
			
		||||
	c.ctx = nil
 | 
			
		||||
	contextPool.Put(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var _ Context = (*context)(nil)
 | 
			
		||||
 | 
			
		||||
type Context interface {
 | 
			
		||||
	init()
 | 
			
		||||
 | 
			
		||||
	Context() *gin.Context
 | 
			
		||||
 | 
			
		||||
	// ShouldBindQuery 反序列化 query
 | 
			
		||||
	// tag: `form:"xxx"` (注:不要写成 query)
 | 
			
		||||
	ShouldBindQuery(obj any) error
 | 
			
		||||
 | 
			
		||||
	// ShouldBindPostForm 反序列化 x-www-from-urlencoded
 | 
			
		||||
	// tag: `form:"xxx"`
 | 
			
		||||
	ShouldBindPostForm(obj any) error
 | 
			
		||||
 | 
			
		||||
	// ShouldBindForm 同时反序列化 form-data;
 | 
			
		||||
	// tag: `form:"xxx"`
 | 
			
		||||
	ShouldBindForm(obj any) error
 | 
			
		||||
 | 
			
		||||
	// ShouldBindJSON 反序列化 post-json
 | 
			
		||||
	// tag: `json:"xxx"`
 | 
			
		||||
	ShouldBindJSON(obj any) error
 | 
			
		||||
 | 
			
		||||
	// ShouldBindURI 反序列化 path 参数(如路由路径为 /user/:name)
 | 
			
		||||
	// tag: `uri:"xxx"`
 | 
			
		||||
	ShouldBindURI(obj any) error
 | 
			
		||||
 | 
			
		||||
	// Redirect 重定向
 | 
			
		||||
	Redirect(code int, location string)
 | 
			
		||||
 | 
			
		||||
	// Trace 获取 Trace 对象
 | 
			
		||||
	Trace() Trace
 | 
			
		||||
	setTrace(trace Trace)
 | 
			
		||||
	disableTrace()
 | 
			
		||||
 | 
			
		||||
	// Logger 获取 Logger 对象
 | 
			
		||||
	Logger() *zap.Logger
 | 
			
		||||
	setLogger(logger *zap.Logger)
 | 
			
		||||
 | 
			
		||||
	// Payload 正确返回
 | 
			
		||||
	Payload(payload any)
 | 
			
		||||
	getPayload() any
 | 
			
		||||
 | 
			
		||||
	// GraphPayload GraphQL返回值 与 api 返回结构不同
 | 
			
		||||
	GraphPayload(payload any)
 | 
			
		||||
	getGraphPayload() any
 | 
			
		||||
 | 
			
		||||
	// HTML 返回界面
 | 
			
		||||
	HTML(name string, obj any)
 | 
			
		||||
 | 
			
		||||
	// AbortWithError 错误返回
 | 
			
		||||
	AbortWithError(err errno.Error)
 | 
			
		||||
	abortError() errno.Error
 | 
			
		||||
 | 
			
		||||
	// Header 获取 Header 对象
 | 
			
		||||
	Header() http.Header
 | 
			
		||||
	// GetHeader 获取 Header
 | 
			
		||||
	GetHeader(key string) string
 | 
			
		||||
	// SetHeader 设置 Header
 | 
			
		||||
	SetHeader(key, value string)
 | 
			
		||||
 | 
			
		||||
	Auth() any
 | 
			
		||||
	SetAuth(auth any)
 | 
			
		||||
 | 
			
		||||
	// Authorization 获取请求认证信息
 | 
			
		||||
	Authorization() string
 | 
			
		||||
 | 
			
		||||
	// Platform 平台标识
 | 
			
		||||
	Platform() string
 | 
			
		||||
 | 
			
		||||
	// Alias 设置路由别名 for metrics uri
 | 
			
		||||
	Alias() string
 | 
			
		||||
	setAlias(path string)
 | 
			
		||||
 | 
			
		||||
	// RequestInputParams 获取所有参数
 | 
			
		||||
	RequestInputParams() url.Values
 | 
			
		||||
	// RequestQueryParams 获取 Query 参数
 | 
			
		||||
	RequestQueryParams() url.Values
 | 
			
		||||
	// RequestPostFormParams  获取 PostForm 参数
 | 
			
		||||
	RequestPostFormParams() url.Values
 | 
			
		||||
	// Request 获取 Request 对象
 | 
			
		||||
	Request() *http.Request
 | 
			
		||||
	// RawData 获取 Request.Body
 | 
			
		||||
	RawData() []byte
 | 
			
		||||
	// Method 获取 Request.Method
 | 
			
		||||
	Method() string
 | 
			
		||||
	// Host 获取 Request.Host
 | 
			
		||||
	Host() string
 | 
			
		||||
	// Path 获取 请求的路径 Request.URL.Path (不附带 querystring)
 | 
			
		||||
	Path() string
 | 
			
		||||
	// URI 获取 unescape 后的 Request.URL.RequestURI()
 | 
			
		||||
	URI() string
 | 
			
		||||
	// RequestContext 获取请求的 context (当 client 关闭后,会自动 canceled)
 | 
			
		||||
	RequestContext() StdContext
 | 
			
		||||
	// ResponseWriter 获取 ResponseWriter 对象
 | 
			
		||||
	ResponseWriter() gin.ResponseWriter
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type context struct {
 | 
			
		||||
	ctx *gin.Context
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type StdContext struct {
 | 
			
		||||
	stdCtx.Context
 | 
			
		||||
	Trace
 | 
			
		||||
	*zap.Logger
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *context) init() {
 | 
			
		||||
	body, err := c.ctx.GetRawData()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		panic(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	c.ctx.Set(_BodyName, body)                               // cache body是为了trace使用
 | 
			
		||||
	c.ctx.Request.Body = io.NopCloser(bytes.NewBuffer(body)) // re-construct req body
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *context) Context() *gin.Context {
 | 
			
		||||
	return c.ctx
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ShouldBindQuery 反序列化querystring
 | 
			
		||||
// tag: `form:"xxx"` (注:不要写成query)
 | 
			
		||||
func (c *context) ShouldBindQuery(obj any) error {
 | 
			
		||||
	return c.ctx.ShouldBindWith(obj, binding.Query)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ShouldBindPostForm 反序列化 postform (querystring 会被忽略)
 | 
			
		||||
// tag: `form:"xxx"`
 | 
			
		||||
func (c *context) ShouldBindPostForm(obj any) error {
 | 
			
		||||
	return c.ctx.ShouldBindWith(obj, binding.FormPost)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ShouldBindForm 同时反序列化querystring和postform;
 | 
			
		||||
// 当querystring和postform存在相同字段时,postform优先使用。
 | 
			
		||||
// tag: `form:"xxx"`
 | 
			
		||||
func (c *context) ShouldBindForm(obj any) error {
 | 
			
		||||
	return c.ctx.ShouldBindWith(obj, binding.Form)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ShouldBindJSON 反序列化postjson
 | 
			
		||||
// tag: `json:"xxx"`
 | 
			
		||||
func (c *context) ShouldBindJSON(obj any) error {
 | 
			
		||||
	return c.ctx.ShouldBindWith(obj, binding.JSON)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ShouldBindURI 反序列化path参数(如路由路径为 /user/:name)
 | 
			
		||||
// tag: `uri:"xxx"`
 | 
			
		||||
func (c *context) ShouldBindURI(obj any) error {
 | 
			
		||||
	return c.ctx.ShouldBindUri(obj)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Redirect 重定向
 | 
			
		||||
func (c *context) Redirect(code int, location string) {
 | 
			
		||||
	c.ctx.Redirect(code, location)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *context) Trace() Trace {
 | 
			
		||||
	t, ok := c.ctx.Get(_TraceName)
 | 
			
		||||
	if !ok || t == nil {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return t.(Trace)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *context) setTrace(trace Trace) {
 | 
			
		||||
	c.ctx.Set(_TraceName, trace)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *context) disableTrace() {
 | 
			
		||||
	c.setTrace(nil)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *context) Logger() *zap.Logger {
 | 
			
		||||
	logger, ok := c.ctx.Get(_LoggerName)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return logger.(*zap.Logger)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *context) setLogger(logger *zap.Logger) {
 | 
			
		||||
	c.ctx.Set(_LoggerName, logger)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *context) getPayload() any {
 | 
			
		||||
	if payload, ok := c.ctx.Get(_PayloadName); ok != false {
 | 
			
		||||
		return payload
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *context) Payload(payload any) {
 | 
			
		||||
	c.ctx.Set(_PayloadName, payload)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *context) getGraphPayload() any {
 | 
			
		||||
	if payload, ok := c.ctx.Get(_GraphPayloadName); ok != false {
 | 
			
		||||
		return payload
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *context) GraphPayload(payload any) {
 | 
			
		||||
	c.ctx.Set(_GraphPayloadName, payload)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *context) HTML(name string, obj any) {
 | 
			
		||||
	c.ctx.HTML(200, name+".html", obj)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *context) Header() http.Header {
 | 
			
		||||
	header := c.ctx.Request.Header
 | 
			
		||||
 | 
			
		||||
	clone := make(http.Header, len(header))
 | 
			
		||||
	for k, v := range header {
 | 
			
		||||
		value := make([]string, len(v))
 | 
			
		||||
		copy(value, v)
 | 
			
		||||
 | 
			
		||||
		clone[k] = value
 | 
			
		||||
	}
 | 
			
		||||
	return clone
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *context) GetHeader(key string) string {
 | 
			
		||||
	return c.ctx.GetHeader(key)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *context) SetHeader(key, value string) {
 | 
			
		||||
	c.ctx.Header(key, value)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *context) Auth() any {
 | 
			
		||||
	val, ok := c.ctx.Get(_Auth)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	return val
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *context) SetAuth(auth any) {
 | 
			
		||||
	c.ctx.Set(_Auth, auth)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *context) Authorization() string {
 | 
			
		||||
	return c.ctx.GetHeader("Authorization")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *context) Platform() string {
 | 
			
		||||
	return c.ctx.GetHeader("X-Platform")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *context) AbortWithError(err errno.Error) {
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		httpCode := err.GetHttpCode()
 | 
			
		||||
		if httpCode == 0 {
 | 
			
		||||
			httpCode = http.StatusInternalServerError
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		c.ctx.AbortWithStatus(httpCode)
 | 
			
		||||
		c.ctx.Set(_AbortErrorName, err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *context) abortError() errno.Error {
 | 
			
		||||
	err, _ := c.ctx.Get(_AbortErrorName)
 | 
			
		||||
	return err.(errno.Error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *context) Alias() string {
 | 
			
		||||
	path, ok := c.ctx.Get(_Alias)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return ""
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return path.(string)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *context) setAlias(path string) {
 | 
			
		||||
	if path = strings.TrimSpace(path); path != "" {
 | 
			
		||||
		c.ctx.Set(_Alias, path)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RequestInputParams 获取所有参数
 | 
			
		||||
func (c *context) RequestInputParams() url.Values {
 | 
			
		||||
	_ = c.ctx.Request.ParseForm()
 | 
			
		||||
	return c.ctx.Request.Form
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RequestQueryParams 获取Query参数
 | 
			
		||||
func (c *context) RequestQueryParams() url.Values {
 | 
			
		||||
	query, _ := url.ParseQuery(c.ctx.Request.URL.RawQuery)
 | 
			
		||||
	return query
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RequestPostFormParams 获取 PostForm 参数
 | 
			
		||||
func (c *context) RequestPostFormParams() url.Values {
 | 
			
		||||
	_ = c.ctx.Request.ParseForm()
 | 
			
		||||
	return c.ctx.Request.PostForm
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Request 获取 Request
 | 
			
		||||
func (c *context) Request() *http.Request {
 | 
			
		||||
	return c.ctx.Request
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *context) RawData() []byte {
 | 
			
		||||
	body, ok := c.ctx.Get(_BodyName)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return body.([]byte)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Method 请求的method
 | 
			
		||||
func (c *context) Method() string {
 | 
			
		||||
	return c.ctx.Request.Method
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Host 请求的host
 | 
			
		||||
func (c *context) Host() string {
 | 
			
		||||
	return c.ctx.Request.Host
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Path 请求的路径(不附带querystring)
 | 
			
		||||
func (c *context) Path() string {
 | 
			
		||||
	return c.ctx.Request.URL.Path
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// URI unescape后的uri
 | 
			
		||||
func (c *context) URI() string {
 | 
			
		||||
	uri, _ := url.QueryUnescape(c.ctx.Request.URL.RequestURI())
 | 
			
		||||
	return uri
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RequestContext (包装 Trace + Logger) 获取请求的 context (当client关闭后,会自动canceled)
 | 
			
		||||
func (c *context) RequestContext() StdContext {
 | 
			
		||||
	return StdContext{
 | 
			
		||||
		//c.ctx.Request.Context(),
 | 
			
		||||
		stdCtx.Background(),
 | 
			
		||||
		c.Trace(),
 | 
			
		||||
		c.Logger(),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ResponseWriter 获取 ResponseWriter
 | 
			
		||||
func (c *context) ResponseWriter() gin.ResponseWriter {
 | 
			
		||||
	return c.ctx.Writer
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										466
									
								
								pkg/mux/core.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										466
									
								
								pkg/mux/core.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,466 @@
 | 
			
		||||
package mux
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"runtime/debug"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.bvbej.com/bvbej/base-golang/pkg/color"
 | 
			
		||||
	"git.bvbej.com/bvbej/base-golang/pkg/env"
 | 
			
		||||
	"git.bvbej.com/bvbej/base-golang/pkg/errno"
 | 
			
		||||
	"git.bvbej.com/bvbej/base-golang/pkg/limiter"
 | 
			
		||||
	"git.bvbej.com/bvbej/base-golang/pkg/trace"
 | 
			
		||||
	"git.bvbej.com/bvbej/base-golang/pkg/validator"
 | 
			
		||||
	"github.com/gin-contrib/pprof"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/gin-gonic/gin/binding"
 | 
			
		||||
	"github.com/prometheus/client_golang/prometheus/promhttp"
 | 
			
		||||
	cors "github.com/rs/cors/wrapper/gin"
 | 
			
		||||
	"go.uber.org/multierr"
 | 
			
		||||
	"go.uber.org/zap"
 | 
			
		||||
	"golang.org/x/time/rate"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Option func(*option)
 | 
			
		||||
 | 
			
		||||
type option struct {
 | 
			
		||||
	enableCors        bool
 | 
			
		||||
	enablePProf       bool
 | 
			
		||||
	enablePrometheus  bool
 | 
			
		||||
	enableOpenBrowser string
 | 
			
		||||
	staticDirs        []string
 | 
			
		||||
	panicNotify       OnPanicNotify
 | 
			
		||||
	recordMetrics     RecordMetrics
 | 
			
		||||
	rateLimiter       limiter.RateLimiter
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const SuccessCode = 0
 | 
			
		||||
 | 
			
		||||
type Failure struct {
 | 
			
		||||
	ResultCode int    `json:"result_code"` // 业务码
 | 
			
		||||
	ResultInfo string `json:"result_info"` // 描述信息
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Success struct {
 | 
			
		||||
	ResultCode int `json:"result_code"` // 业务码
 | 
			
		||||
	ResultData any `json:"result_data"` //返回数据
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/******************************************************************************/
 | 
			
		||||
 | 
			
		||||
// OnPanicNotify 发生panic时通知用
 | 
			
		||||
type OnPanicNotify func(ctx Context, err any, stackInfo string)
 | 
			
		||||
 | 
			
		||||
// RecordMetrics 记录prometheus指标用
 | 
			
		||||
// 如果使用AliasForRecordMetrics配置了别名,uri将被替换为别名。
 | 
			
		||||
type RecordMetrics func(method, uri string, success bool, costSeconds float64)
 | 
			
		||||
 | 
			
		||||
// DisableTrace 禁用追踪链
 | 
			
		||||
func DisableTrace(ctx Context) {
 | 
			
		||||
	ctx.disableTrace()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithPanicNotify 设置panic时的通知回调
 | 
			
		||||
func WithPanicNotify(notify OnPanicNotify) Option {
 | 
			
		||||
	return func(opt *option) {
 | 
			
		||||
		opt.panicNotify = notify
 | 
			
		||||
		fmt.Println(color.Green("* [register panic notify]"))
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithRecordMetrics 设置记录prometheus记录指标回调
 | 
			
		||||
func WithRecordMetrics(record RecordMetrics) Option {
 | 
			
		||||
	return func(opt *option) {
 | 
			
		||||
		opt.recordMetrics = record
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithEnableCors 开启CORS
 | 
			
		||||
func WithEnableCors() Option {
 | 
			
		||||
	return func(opt *option) {
 | 
			
		||||
		opt.enableCors = true
 | 
			
		||||
		fmt.Println(color.Green("* [register cors]"))
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithEnableRate 开启限流
 | 
			
		||||
func WithEnableRate(limit rate.Limit, burst int) Option {
 | 
			
		||||
	return func(opt *option) {
 | 
			
		||||
		opt.rateLimiter = limiter.NewRateLimiter(limit, burst)
 | 
			
		||||
		fmt.Println(color.Green("* [register rate]"))
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithStaticDir 设置静态文件目录
 | 
			
		||||
func WithStaticDir(dirs []string) Option {
 | 
			
		||||
	return func(opt *option) {
 | 
			
		||||
		opt.staticDirs = dirs
 | 
			
		||||
		fmt.Println(color.Green("* [register rate]"))
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// AliasForRecordMetrics 对请求uri起个别名,用于prometheus记录指标。
 | 
			
		||||
// 如:Get /user/:username 这样的uri,因为username会有非常多的情况,这样记录prometheus指标会非常的不有好。
 | 
			
		||||
func AliasForRecordMetrics(path string) HandlerFunc {
 | 
			
		||||
	return func(ctx Context) {
 | 
			
		||||
		ctx.setAlias(path)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/******************************************************************************/
 | 
			
		||||
 | 
			
		||||
// RouterGroup 包装gin的RouterGroup
 | 
			
		||||
type RouterGroup interface {
 | 
			
		||||
	Group(string, ...HandlerFunc) RouterGroup
 | 
			
		||||
	IRoutes
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var _ IRoutes = (*router)(nil)
 | 
			
		||||
 | 
			
		||||
// IRoutes 包装gin的IRoutes
 | 
			
		||||
type IRoutes interface {
 | 
			
		||||
	Any(string, ...HandlerFunc)
 | 
			
		||||
	GET(string, ...HandlerFunc)
 | 
			
		||||
	POST(string, ...HandlerFunc)
 | 
			
		||||
	DELETE(string, ...HandlerFunc)
 | 
			
		||||
	PATCH(string, ...HandlerFunc)
 | 
			
		||||
	PUT(string, ...HandlerFunc)
 | 
			
		||||
	OPTIONS(string, ...HandlerFunc)
 | 
			
		||||
	HEAD(string, ...HandlerFunc)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type router struct {
 | 
			
		||||
	group *gin.RouterGroup
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *router) Group(relativePath string, handlers ...HandlerFunc) RouterGroup {
 | 
			
		||||
	group := r.group.Group(relativePath, wrapHandlers(handlers...)...)
 | 
			
		||||
	return &router{group: group}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *router) Any(relativePath string, handlers ...HandlerFunc) {
 | 
			
		||||
	r.group.Any(relativePath, wrapHandlers(handlers...)...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *router) GET(relativePath string, handlers ...HandlerFunc) {
 | 
			
		||||
	r.group.GET(relativePath, wrapHandlers(handlers...)...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *router) POST(relativePath string, handlers ...HandlerFunc) {
 | 
			
		||||
	r.group.POST(relativePath, wrapHandlers(handlers...)...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *router) DELETE(relativePath string, handlers ...HandlerFunc) {
 | 
			
		||||
	r.group.DELETE(relativePath, wrapHandlers(handlers...)...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *router) PATCH(relativePath string, handlers ...HandlerFunc) {
 | 
			
		||||
	r.group.PATCH(relativePath, wrapHandlers(handlers...)...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *router) PUT(relativePath string, handlers ...HandlerFunc) {
 | 
			
		||||
	r.group.PUT(relativePath, wrapHandlers(handlers...)...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *router) OPTIONS(relativePath string, handlers ...HandlerFunc) {
 | 
			
		||||
	r.group.OPTIONS(relativePath, wrapHandlers(handlers...)...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *router) HEAD(relativePath string, handlers ...HandlerFunc) {
 | 
			
		||||
	r.group.HEAD(relativePath, wrapHandlers(handlers...)...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func wrapHandlers(handlers ...HandlerFunc) []gin.HandlerFunc {
 | 
			
		||||
	list := make([]gin.HandlerFunc, len(handlers))
 | 
			
		||||
	for i, handler := range handlers {
 | 
			
		||||
		fn := handler
 | 
			
		||||
		list[i] = func(c *gin.Context) {
 | 
			
		||||
			ctx := newContext(c)
 | 
			
		||||
			defer releaseContext(ctx)
 | 
			
		||||
 | 
			
		||||
			fn(ctx)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return list
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/******************************************************************************/
 | 
			
		||||
 | 
			
		||||
var _ Mux = (*mux)(nil)
 | 
			
		||||
 | 
			
		||||
type Mux interface {
 | 
			
		||||
	http.Handler
 | 
			
		||||
	Group(relativePath string, handlers ...HandlerFunc) RouterGroup
 | 
			
		||||
	Routes() gin.RoutesInfo
 | 
			
		||||
	HandlerFunc(relativePath string, handlerFunc gin.HandlerFunc)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type mux struct {
 | 
			
		||||
	engine *gin.Engine
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *mux) ServeHTTP(w http.ResponseWriter, req *http.Request) {
 | 
			
		||||
	m.engine.ServeHTTP(w, req)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *mux) Group(relativePath string, handlers ...HandlerFunc) RouterGroup {
 | 
			
		||||
	return &router{
 | 
			
		||||
		group: m.engine.Group(relativePath, wrapHandlers(handlers...)...),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *mux) Routes() gin.RoutesInfo {
 | 
			
		||||
	return m.engine.Routes()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *mux) HandlerFunc(relativePath string, handlerFunc gin.HandlerFunc) {
 | 
			
		||||
	m.engine.GET(relativePath, handlerFunc)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func New(logger *zap.Logger, options ...Option) (Mux, error) {
 | 
			
		||||
	if logger == nil {
 | 
			
		||||
		return nil, errors.New("logger required")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	gin.SetMode(gin.ReleaseMode)
 | 
			
		||||
	binding.Validator = validator.Validator
 | 
			
		||||
	newMux := &mux{
 | 
			
		||||
		engine: gin.New(),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	fmt.Println(color.Green(fmt.Sprintf("* [register env %s]", env.Active().Value())))
 | 
			
		||||
 | 
			
		||||
	// withoutLogPaths 这些请求,默认不记录日志
 | 
			
		||||
	withoutTracePaths := map[string]bool{
 | 
			
		||||
		"/metrics":       true,
 | 
			
		||||
		"/favicon.ico":   true,
 | 
			
		||||
		"/system/health": true,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	opt := new(option)
 | 
			
		||||
	for _, f := range options {
 | 
			
		||||
		f(opt)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if opt.enablePProf {
 | 
			
		||||
		pprof.Register(newMux.engine)
 | 
			
		||||
		fmt.Println(color.Green("* [register pprof]"))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if opt.enablePrometheus {
 | 
			
		||||
		newMux.engine.GET("/metrics", gin.WrapH(promhttp.Handler()))
 | 
			
		||||
		fmt.Println(color.Green("* [register prometheus]"))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if opt.enableCors {
 | 
			
		||||
		newMux.engine.Use(cors.AllowAll())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if opt.staticDirs != nil {
 | 
			
		||||
		for _, dir := range opt.staticDirs {
 | 
			
		||||
			newMux.engine.StaticFS(dir, gin.Dir(dir, false))
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// recover两次,防止处理时发生panic,尤其是在OnPanicNotify中。
 | 
			
		||||
	newMux.engine.Use(func(ctx *gin.Context) {
 | 
			
		||||
		defer func() {
 | 
			
		||||
			if err := recover(); err != nil {
 | 
			
		||||
				logger.Error("got panic", zap.String("panic", fmt.Sprintf("%+v", err)), zap.String("stack", string(debug.Stack())))
 | 
			
		||||
			}
 | 
			
		||||
		}()
 | 
			
		||||
 | 
			
		||||
		ctx.Next()
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	newMux.engine.Use(func(ctx *gin.Context) {
 | 
			
		||||
		ts := time.Now()
 | 
			
		||||
 | 
			
		||||
		newCtx := newContext(ctx)
 | 
			
		||||
		defer releaseContext(newCtx)
 | 
			
		||||
 | 
			
		||||
		newCtx.init()
 | 
			
		||||
		newCtx.setLogger(logger)
 | 
			
		||||
 | 
			
		||||
		if !withoutTracePaths[ctx.Request.URL.Path] {
 | 
			
		||||
			if traceId := newCtx.GetHeader(trace.Header); traceId != "" {
 | 
			
		||||
				newCtx.setTrace(trace.New(traceId))
 | 
			
		||||
			} else {
 | 
			
		||||
				newCtx.setTrace(trace.New(""))
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		defer func() {
 | 
			
		||||
			if err := recover(); err != nil {
 | 
			
		||||
				stackInfo := string(debug.Stack())
 | 
			
		||||
				logger.Error("got panic", zap.String("panic", fmt.Sprintf("%+v", err)), zap.String("stack", stackInfo))
 | 
			
		||||
				newCtx.AbortWithError(errno.NewError(
 | 
			
		||||
					http.StatusInternalServerError,
 | 
			
		||||
					http.StatusInternalServerError,
 | 
			
		||||
					http.StatusText(http.StatusInternalServerError)),
 | 
			
		||||
				)
 | 
			
		||||
 | 
			
		||||
				if notify := opt.panicNotify; notify != nil {
 | 
			
		||||
					notify(newCtx, err, stackInfo)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if ctx.Writer.Status() == http.StatusNotFound {
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			var (
 | 
			
		||||
				response        any
 | 
			
		||||
				businessCode    int
 | 
			
		||||
				businessCodeMsg string
 | 
			
		||||
				abortErr        error
 | 
			
		||||
				graphResponse   any
 | 
			
		||||
			)
 | 
			
		||||
 | 
			
		||||
			if ctx.IsAborted() {
 | 
			
		||||
				for i := range ctx.Errors { // gin error
 | 
			
		||||
					multierr.AppendInto(&abortErr, ctx.Errors[i])
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				if err := newCtx.abortError(); err != nil { // customer err
 | 
			
		||||
					multierr.AppendInto(&abortErr, err.GetErr())
 | 
			
		||||
					response = err
 | 
			
		||||
					businessCode = err.GetBusinessCode()
 | 
			
		||||
					businessCodeMsg = err.GetMsg()
 | 
			
		||||
 | 
			
		||||
					if x := newCtx.Trace(); x != nil {
 | 
			
		||||
						newCtx.SetHeader(trace.Header, x.ID())
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					ctx.JSON(err.GetHttpCode(), &Failure{
 | 
			
		||||
						ResultCode: businessCode,
 | 
			
		||||
						ResultInfo: businessCodeMsg,
 | 
			
		||||
					})
 | 
			
		||||
				}
 | 
			
		||||
			} else {
 | 
			
		||||
				response = newCtx.getPayload()
 | 
			
		||||
				if response != nil {
 | 
			
		||||
					if x := newCtx.Trace(); x != nil {
 | 
			
		||||
						newCtx.SetHeader(trace.Header, x.ID())
 | 
			
		||||
					}
 | 
			
		||||
					ctx.JSON(http.StatusOK, response)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			graphResponse = newCtx.getGraphPayload()
 | 
			
		||||
 | 
			
		||||
			if opt.recordMetrics != nil {
 | 
			
		||||
				uri := newCtx.Path()
 | 
			
		||||
				if alias := newCtx.Alias(); alias != "" {
 | 
			
		||||
					uri = alias
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				opt.recordMetrics(
 | 
			
		||||
					newCtx.Method(),
 | 
			
		||||
					uri,
 | 
			
		||||
					!ctx.IsAborted() && ctx.Writer.Status() == http.StatusOK,
 | 
			
		||||
					time.Since(ts).Seconds(),
 | 
			
		||||
				)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			var t *trace.Trace
 | 
			
		||||
			if x := newCtx.Trace(); x != nil {
 | 
			
		||||
				t = x.(*trace.Trace)
 | 
			
		||||
			} else {
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			decodedURL, _ := url.QueryUnescape(ctx.Request.URL.RequestURI())
 | 
			
		||||
 | 
			
		||||
			t.WithRequest(&trace.Request{
 | 
			
		||||
				TTL:        "un-limit",
 | 
			
		||||
				Method:     ctx.Request.Method,
 | 
			
		||||
				DecodedURL: decodedURL,
 | 
			
		||||
				Header:     ctx.Request.Header,
 | 
			
		||||
				Body:       string(newCtx.RawData()),
 | 
			
		||||
			})
 | 
			
		||||
 | 
			
		||||
			var responseBody any
 | 
			
		||||
 | 
			
		||||
			if response != nil {
 | 
			
		||||
				responseBody = response
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if graphResponse != nil {
 | 
			
		||||
				responseBody = graphResponse
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			t.WithResponse(&trace.Response{
 | 
			
		||||
				Header:          ctx.Writer.Header(),
 | 
			
		||||
				HttpCode:        ctx.Writer.Status(),
 | 
			
		||||
				HttpCodeMsg:     http.StatusText(ctx.Writer.Status()),
 | 
			
		||||
				BusinessCode:    businessCode,
 | 
			
		||||
				BusinessCodeMsg: businessCodeMsg,
 | 
			
		||||
				Body:            responseBody,
 | 
			
		||||
				CostSeconds:     time.Since(ts).Seconds(),
 | 
			
		||||
			})
 | 
			
		||||
 | 
			
		||||
			t.Success = !ctx.IsAborted() && ctx.Writer.Status() == http.StatusOK
 | 
			
		||||
			t.CostSeconds = time.Since(ts).Seconds()
 | 
			
		||||
 | 
			
		||||
			logger.Info("core-interceptor",
 | 
			
		||||
				zap.Any("method", ctx.Request.Method),
 | 
			
		||||
				zap.Any("path", decodedURL),
 | 
			
		||||
				zap.Any("http_code", ctx.Writer.Status()),
 | 
			
		||||
				zap.Any("business_code", businessCode),
 | 
			
		||||
				zap.Any("success", t.Success),
 | 
			
		||||
				zap.Any("cost_seconds", t.CostSeconds),
 | 
			
		||||
				zap.Any("trace_id", t.Identifier),
 | 
			
		||||
				zap.Any("trace_info", t),
 | 
			
		||||
				zap.Error(abortErr),
 | 
			
		||||
			)
 | 
			
		||||
		}()
 | 
			
		||||
 | 
			
		||||
		ctx.Next()
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	if opt.rateLimiter != nil {
 | 
			
		||||
		newMux.engine.Use(func(ctx *gin.Context) {
 | 
			
		||||
			newCtx := newContext(ctx)
 | 
			
		||||
			defer releaseContext(newCtx)
 | 
			
		||||
 | 
			
		||||
			if !opt.rateLimiter.Allow(ctx.ClientIP()) {
 | 
			
		||||
				newCtx.AbortWithError(errno.NewError(
 | 
			
		||||
					http.StatusTooManyRequests,
 | 
			
		||||
					http.StatusTooManyRequests,
 | 
			
		||||
					http.StatusText(http.StatusTooManyRequests)),
 | 
			
		||||
				)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			ctx.Next()
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	newMux.engine.NoMethod(wrapHandlers(DisableTrace)...)
 | 
			
		||||
	newMux.engine.NoRoute(wrapHandlers(DisableTrace)...)
 | 
			
		||||
	system := newMux.Group("/system")
 | 
			
		||||
	{
 | 
			
		||||
		// 健康检查
 | 
			
		||||
		system.GET("/health", func(ctx Context) {
 | 
			
		||||
			resp := &struct {
 | 
			
		||||
				Timestamp   time.Time `json:"timestamp"`
 | 
			
		||||
				Environment string    `json:"environment"`
 | 
			
		||||
				Host        string    `json:"host"`
 | 
			
		||||
				Status      string    `json:"status"`
 | 
			
		||||
			}{
 | 
			
		||||
				Timestamp:   time.Now(),
 | 
			
		||||
				Environment: env.Active().Value(),
 | 
			
		||||
				Host:        ctx.Host(),
 | 
			
		||||
				Status:      "ok",
 | 
			
		||||
			}
 | 
			
		||||
			ctx.Payload(resp)
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return newMux, nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										3
									
								
								pkg/observable/iterable.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								pkg/observable/iterable.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,3 @@
 | 
			
		||||
package observable
 | 
			
		||||
 | 
			
		||||
type Iterable <-chan any
 | 
			
		||||
							
								
								
									
										67
									
								
								pkg/observable/observable.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										67
									
								
								pkg/observable/observable.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,67 @@
 | 
			
		||||
package observable
 | 
			
		||||
 | 
			
		||||
// Ref: github.com/Dreamacro/clash/common/observable
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"sync"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Observable struct {
 | 
			
		||||
	iterable Iterable
 | 
			
		||||
	listener map[Subscription]*Subscriber
 | 
			
		||||
	mux      sync.Mutex
 | 
			
		||||
	done     bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (o *Observable) process() {
 | 
			
		||||
	for item := range o.iterable {
 | 
			
		||||
		o.mux.Lock()
 | 
			
		||||
		for _, sub := range o.listener {
 | 
			
		||||
			sub.Emit(item)
 | 
			
		||||
		}
 | 
			
		||||
		o.mux.Unlock()
 | 
			
		||||
	}
 | 
			
		||||
	o.close()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (o *Observable) close() {
 | 
			
		||||
	o.mux.Lock()
 | 
			
		||||
	defer o.mux.Unlock()
 | 
			
		||||
 | 
			
		||||
	o.done = true
 | 
			
		||||
	for _, sub := range o.listener {
 | 
			
		||||
		sub.Close()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (o *Observable) Subscribe() (Subscription, error) {
 | 
			
		||||
	o.mux.Lock()
 | 
			
		||||
	defer o.mux.Unlock()
 | 
			
		||||
	if o.done {
 | 
			
		||||
		return nil, errors.New("observable is closed")
 | 
			
		||||
	}
 | 
			
		||||
	subscriber := newSubscriber()
 | 
			
		||||
	o.listener[subscriber.Out()] = subscriber
 | 
			
		||||
	return subscriber.Out(), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (o *Observable) UnSubscribe(sub Subscription) {
 | 
			
		||||
	o.mux.Lock()
 | 
			
		||||
	defer o.mux.Unlock()
 | 
			
		||||
	subscriber, exist := o.listener[sub]
 | 
			
		||||
	if !exist {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	delete(o.listener, sub)
 | 
			
		||||
	subscriber.Close()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewObservable(any Iterable) *Observable {
 | 
			
		||||
	observable := &Observable{
 | 
			
		||||
		iterable: any,
 | 
			
		||||
		listener: map[Subscription]*Subscriber{},
 | 
			
		||||
	}
 | 
			
		||||
	go observable.process()
 | 
			
		||||
	return observable
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										33
									
								
								pkg/observable/subscriber.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										33
									
								
								pkg/observable/subscriber.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,33 @@
 | 
			
		||||
package observable
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"sync"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Subscription <-chan any
 | 
			
		||||
 | 
			
		||||
type Subscriber struct {
 | 
			
		||||
	buffer chan any
 | 
			
		||||
	once   sync.Once
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Subscriber) Emit(item any) {
 | 
			
		||||
	s.buffer <- item
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Subscriber) Out() Subscription {
 | 
			
		||||
	return s.buffer
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Subscriber) Close() {
 | 
			
		||||
	s.once.Do(func() {
 | 
			
		||||
		close(s.buffer)
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newSubscriber() *Subscriber {
 | 
			
		||||
	sub := &Subscriber{
 | 
			
		||||
		buffer: make(chan any, 200),
 | 
			
		||||
	}
 | 
			
		||||
	return sub
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										50
									
								
								pkg/p/print.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										50
									
								
								pkg/p/print.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,50 @@
 | 
			
		||||
package p
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.bvbej.com/bvbej/base-golang/pkg/trace"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Option func(*option)
 | 
			
		||||
 | 
			
		||||
type Trace = trace.T
 | 
			
		||||
 | 
			
		||||
type option struct {
 | 
			
		||||
	Trace *trace.Trace
 | 
			
		||||
	Debug *trace.Debug
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newOption() *option {
 | 
			
		||||
	return &option{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Println(key string, value any, options ...Option) {
 | 
			
		||||
	ts := time.Now()
 | 
			
		||||
	opt := newOption()
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if opt.Trace != nil {
 | 
			
		||||
			opt.Debug.Key = key
 | 
			
		||||
			opt.Debug.Value = value
 | 
			
		||||
			opt.Debug.CostSeconds = time.Since(ts).Seconds()
 | 
			
		||||
			opt.Trace.AppendDebug(opt.Debug)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	for _, f := range options {
 | 
			
		||||
		f(opt)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	fmt.Println(fmt.Sprintf("KEY: %s | VALUE: %v", key, value))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithTrace 设置trace信息
 | 
			
		||||
func WithTrace(t Trace) Option {
 | 
			
		||||
	return func(opt *option) {
 | 
			
		||||
		if t != nil {
 | 
			
		||||
			opt.Trace = t.(*trace.Trace)
 | 
			
		||||
			opt.Debug = new(trace.Debug)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										142
									
								
								pkg/proxy/io.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										142
									
								
								pkg/proxy/io.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,142 @@
 | 
			
		||||
package proxy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"golang.org/x/time/rate"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net"
 | 
			
		||||
	"runtime/debug"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const burstLimit = 1000 * 1000 * 1000
 | 
			
		||||
 | 
			
		||||
type Reader struct {
 | 
			
		||||
	r       io.Reader
 | 
			
		||||
	limiter *rate.Limiter
 | 
			
		||||
	ctx     context.Context
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewReader(r io.Reader) *Reader {
 | 
			
		||||
	return &Reader{
 | 
			
		||||
		r:   r,
 | 
			
		||||
		ctx: context.Background(),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Reader) SetRateLimit(bytesPerSec float64) {
 | 
			
		||||
	s.limiter = rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit)
 | 
			
		||||
	s.limiter.AllowN(time.Now(), burstLimit) // spend initial burst
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Reader) Read(p []byte) (int, error) {
 | 
			
		||||
	if s.limiter == nil {
 | 
			
		||||
		return s.r.Read(p)
 | 
			
		||||
	}
 | 
			
		||||
	n, err := s.r.Read(p)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return n, err
 | 
			
		||||
	}
 | 
			
		||||
	if err := s.limiter.WaitN(s.ctx, n); err != nil {
 | 
			
		||||
		return n, err
 | 
			
		||||
	}
 | 
			
		||||
	return n, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ConnectHost(hostAndPort string, timeout int) (conn net.Conn, err error) {
 | 
			
		||||
	conn, err = net.DialTimeout("tcp", hostAndPort, time.Duration(timeout)*time.Millisecond)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func CloseConn(conn net.Conn) {
 | 
			
		||||
	if conn != nil {
 | 
			
		||||
		_ = conn.SetDeadline(time.Now().Add(time.Millisecond))
 | 
			
		||||
		_ = conn.Close()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func IoBind(dst io.ReadWriter, src io.ReadWriter, fn func(isSrcErr bool, err error), cfn func(count int, isPositive bool), bytesPreSec float64) {
 | 
			
		||||
	var one = &sync.Once{}
 | 
			
		||||
	go func() {
 | 
			
		||||
		defer func() {
 | 
			
		||||
			if e := recover(); e != nil {
 | 
			
		||||
				logger.Sugar().Errorf("IoBind crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
 | 
			
		||||
			}
 | 
			
		||||
		}()
 | 
			
		||||
		var err error
 | 
			
		||||
		var isSrcErr bool
 | 
			
		||||
		if bytesPreSec > 0 {
 | 
			
		||||
			newReader := NewReader(src)
 | 
			
		||||
			newReader.SetRateLimit(bytesPreSec)
 | 
			
		||||
			_, isSrcErr, err = IoCopy(dst, newReader, func(c int) {
 | 
			
		||||
				cfn(c, false)
 | 
			
		||||
			})
 | 
			
		||||
 | 
			
		||||
		} else {
 | 
			
		||||
			_, isSrcErr, err = IoCopy(dst, src, func(c int) {
 | 
			
		||||
				cfn(c, false)
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			one.Do(func() {
 | 
			
		||||
				fn(isSrcErr, err)
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
	go func() {
 | 
			
		||||
		defer func() {
 | 
			
		||||
			if e := recover(); e != nil {
 | 
			
		||||
				logger.Sugar().Errorf("IoBind crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
 | 
			
		||||
			}
 | 
			
		||||
		}()
 | 
			
		||||
		var err error
 | 
			
		||||
		var isSrcErr bool
 | 
			
		||||
		if bytesPreSec > 0 {
 | 
			
		||||
			newReader := NewReader(dst)
 | 
			
		||||
			newReader.SetRateLimit(bytesPreSec)
 | 
			
		||||
			_, isSrcErr, err = IoCopy(src, newReader, func(c int) {
 | 
			
		||||
				cfn(c, true)
 | 
			
		||||
			})
 | 
			
		||||
		} else {
 | 
			
		||||
			_, isSrcErr, err = IoCopy(src, dst, func(c int) {
 | 
			
		||||
				cfn(c, true)
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			one.Do(func() {
 | 
			
		||||
				fn(isSrcErr, err)
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func IoCopy(dst io.Writer, src io.Reader, fn ...func(count int)) (written int64, isSrcErr bool, err error) {
 | 
			
		||||
	buf := make([]byte, 32*1024)
 | 
			
		||||
	for {
 | 
			
		||||
		nr, er := src.Read(buf)
 | 
			
		||||
		if nr > 0 {
 | 
			
		||||
			nw, ew := dst.Write(buf[0:nr])
 | 
			
		||||
			if nw > 0 {
 | 
			
		||||
				written += int64(nw)
 | 
			
		||||
				if len(fn) == 1 {
 | 
			
		||||
					fn[0](nw)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			if ew != nil {
 | 
			
		||||
				err = ew
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
			if nr != nw {
 | 
			
		||||
				err = io.ErrShortWrite
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		if er != nil {
 | 
			
		||||
			err = er
 | 
			
		||||
			isSrcErr = true
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return written, isSrcErr, err
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										105
									
								
								pkg/proxy/service.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										105
									
								
								pkg/proxy/service.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,105 @@
 | 
			
		||||
package proxy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"go.uber.org/zap"
 | 
			
		||||
	"net"
 | 
			
		||||
	"runtime/debug"
 | 
			
		||||
	"sync"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	servicesMap = new(sync.Map)
 | 
			
		||||
	logger      *zap.Logger
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Service struct {
 | 
			
		||||
	TCPConn TCP
 | 
			
		||||
	Name    string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) Stop() {
 | 
			
		||||
	servicesMap.Delete(s.Name)
 | 
			
		||||
	s.TCPConn.Close()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Run(name string, args TCPArgs, zapLogger *zap.Logger) *Service {
 | 
			
		||||
	logger = zapLogger
 | 
			
		||||
	service := &Service{
 | 
			
		||||
		TCPConn: &tcp{cfg: args},
 | 
			
		||||
		Name:    name,
 | 
			
		||||
	}
 | 
			
		||||
	store, loaded := servicesMap.LoadOrStore(name, service)
 | 
			
		||||
	if loaded {
 | 
			
		||||
		service = store.(*Service)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	go func() {
 | 
			
		||||
		defer func() {
 | 
			
		||||
			recoverErr := recover()
 | 
			
		||||
			if recoverErr != nil {
 | 
			
		||||
				logger.Sugar().Errorf("%s servcie crashed, ERR: %s\ntrace:%s", name, recoverErr, string(debug.Stack()))
 | 
			
		||||
			}
 | 
			
		||||
		}()
 | 
			
		||||
		startErr := service.TCPConn.Start()
 | 
			
		||||
		if startErr != nil {
 | 
			
		||||
			logger.Sugar().Errorf("%s servcie fail, ERR: %s", name, startErr)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	return service
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
///////////////////////////////////////////////////////////////////////////
 | 
			
		||||
 | 
			
		||||
type Listener struct {
 | 
			
		||||
	ip               string
 | 
			
		||||
	port             int
 | 
			
		||||
	Listener         net.Listener
 | 
			
		||||
	errAcceptHandler func(err error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewListener(ip string, port int) Listener {
 | 
			
		||||
	return Listener{
 | 
			
		||||
		ip:   ip,
 | 
			
		||||
		port: port,
 | 
			
		||||
		errAcceptHandler: func(err error) {
 | 
			
		||||
			logger.Sugar().Errorf("accept error , ERR:%s", err)
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (sc *Listener) ListenTCP(fn func(conn net.Conn)) (err error) {
 | 
			
		||||
	sc.Listener, err = net.Listen("tcp", fmt.Sprintf("%s:%d", sc.ip, sc.port))
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		go func() {
 | 
			
		||||
			defer func() {
 | 
			
		||||
				if e := recover(); e != nil {
 | 
			
		||||
					logger.Sugar().Infof("ListenTCP crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
 | 
			
		||||
				}
 | 
			
		||||
			}()
 | 
			
		||||
			for {
 | 
			
		||||
				var conn net.Conn
 | 
			
		||||
				conn, err = sc.Listener.Accept()
 | 
			
		||||
				if err == nil {
 | 
			
		||||
					go func() {
 | 
			
		||||
						defer func() {
 | 
			
		||||
							if e := recover(); e != nil {
 | 
			
		||||
								logger.Sugar().Infof("connection handler crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
 | 
			
		||||
							}
 | 
			
		||||
						}()
 | 
			
		||||
						fn(conn)
 | 
			
		||||
					}()
 | 
			
		||||
				} else {
 | 
			
		||||
					sc.errAcceptHandler(err)
 | 
			
		||||
					break
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}()
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (sc *Listener) CloseListen() error {
 | 
			
		||||
	return sc.Listener.Close()
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										94
									
								
								pkg/proxy/tcp.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										94
									
								
								pkg/proxy/tcp.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,94 @@
 | 
			
		||||
package proxy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"net"
 | 
			
		||||
	"strconv"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var _ TCP = (*tcp)(nil)
 | 
			
		||||
 | 
			
		||||
type TCPArgs struct {
 | 
			
		||||
	Local       string      //监听地址
 | 
			
		||||
	Parent      string      //被代理地址
 | 
			
		||||
	Timeout     int         //拨号超时(毫秒)
 | 
			
		||||
	OutCallback func() bool //回调 //是否允许代理
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type TCP interface {
 | 
			
		||||
	Start() (err error)
 | 
			
		||||
	Close()
 | 
			
		||||
 | 
			
		||||
	callback(inConn net.Conn)
 | 
			
		||||
	outToTCP(inConn net.Conn) (err error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type tcp struct {
 | 
			
		||||
	inConn  net.Conn
 | 
			
		||||
	outConn net.Conn
 | 
			
		||||
	listen  Listener
 | 
			
		||||
	cfg     TCPArgs
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *tcp) Start() (err error) {
 | 
			
		||||
	host, port, _ := net.SplitHostPort(s.cfg.Local)
 | 
			
		||||
	p, _ := strconv.Atoi(port)
 | 
			
		||||
	s.listen = NewListener(host, p)
 | 
			
		||||
	err = s.listen.ListenTCP(s.callback)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *tcp) Close() {
 | 
			
		||||
	if s.inConn != nil {
 | 
			
		||||
		CloseConn(s.inConn)
 | 
			
		||||
	}
 | 
			
		||||
	if s.outConn != nil {
 | 
			
		||||
		CloseConn(s.outConn)
 | 
			
		||||
	}
 | 
			
		||||
	_ = s.listen.CloseListen()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *tcp) callback(inConn net.Conn) {
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if err := recover(); err != nil {
 | 
			
		||||
			logger.Sugar().Infof("conn handler crashed with err : %s", err)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	if s.cfg.OutCallback != nil {
 | 
			
		||||
		if !s.cfg.OutCallback() {
 | 
			
		||||
			CloseConn(inConn)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err := s.outToTCP(inConn)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		CloseConn(inConn)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	s.inConn = inConn
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *tcp) outToTCP(inConn net.Conn) error {
 | 
			
		||||
	outConn, err := ConnectHost(s.cfg.Parent, s.cfg.Timeout)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	inAddr := inConn.RemoteAddr().String()
 | 
			
		||||
	inLocalAddr := inConn.LocalAddr().String()
 | 
			
		||||
	outAddr := outConn.RemoteAddr().String()
 | 
			
		||||
	outLocalAddr := outConn.LocalAddr().String()
 | 
			
		||||
	IoBind(inConn, outConn, func(isSrcErr bool, err error) {
 | 
			
		||||
		CloseConn(inConn)
 | 
			
		||||
		CloseConn(outConn)
 | 
			
		||||
		logger.Sugar().Infof("conn %s - %s - %s -%s released", inAddr, inLocalAddr, outLocalAddr, outAddr)
 | 
			
		||||
	}, func(n int, d bool) {}, 0)
 | 
			
		||||
 | 
			
		||||
	logger.Sugar().Infof("conn %s - %s - %s -%s connected", inAddr, inLocalAddr, outLocalAddr, outAddr)
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										236
									
								
								pkg/qiniu/storage.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										236
									
								
								pkg/qiniu/storage.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,236 @@
 | 
			
		||||
package qiniu
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"git.bvbej.com/bvbej/base-golang/pkg/md5"
 | 
			
		||||
	"git.bvbej.com/bvbej/base-golang/tool"
 | 
			
		||||
	"github.com/qiniu/go-sdk/v7/auth/qbox"
 | 
			
		||||
	"github.com/qiniu/go-sdk/v7/storage"
 | 
			
		||||
	"github.com/tidwall/gjson"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"path"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var _ QiNiu = (*qiNiu)(nil)
 | 
			
		||||
 | 
			
		||||
type QiNiu interface {
 | 
			
		||||
	i()
 | 
			
		||||
	SetDefaultUploadTokenTTL(ttl uint64)
 | 
			
		||||
	GetCallbackUploadToken(ttl uint64, callbackURL string) string
 | 
			
		||||
	GetUploadToken(ttl uint64) string
 | 
			
		||||
	GetPrivateURL(key string, ttl uint64) string
 | 
			
		||||
	VerifyCallback(req *http.Request) (bool, error)
 | 
			
		||||
	UploadFile(key, localFile string) (*PutRet, error)
 | 
			
		||||
	ResumeUploadFile(key, localFile string) (*PutRet, error)
 | 
			
		||||
	DelFile(key string) error
 | 
			
		||||
	TimestampSecuritySign(path string, ttl time.Duration) string
 | 
			
		||||
	GetFileInfo(key string) *storage.FileInfo
 | 
			
		||||
	ListFiles(prefix, delimiter, marker string, limit int) (entries []storage.ListItem, commonPrefixes []string, nextMarker string, hasNext bool, err error)
 | 
			
		||||
	GetFileHash(path, qhash string) (hash string, err error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type qiNiu struct {
 | 
			
		||||
	mac            *qbox.Mac
 | 
			
		||||
	bucketManager  *storage.BucketManager
 | 
			
		||||
	conf           *storage.Config
 | 
			
		||||
	bucket         string
 | 
			
		||||
	domain         string
 | 
			
		||||
	securityKey    string
 | 
			
		||||
	md5            md5.MD5
 | 
			
		||||
	uploadTokenTTL uint64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type PutRet struct {
 | 
			
		||||
	Key    string `json:"key"`
 | 
			
		||||
	Hash   string `json:"hash"`
 | 
			
		||||
	Fsize  string `json:"fsize"`
 | 
			
		||||
	Fname  string `json:"fname"`
 | 
			
		||||
	Ext    string `json:"ext"`
 | 
			
		||||
	Unique string `json:"unique"`
 | 
			
		||||
	User   string `json:"user"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func New(accessKey, secretKey, bucket, domain, securityKey string) QiNiu {
 | 
			
		||||
	region, _ := storage.GetRegion(accessKey, bucket)
 | 
			
		||||
	mac := qbox.NewMac(accessKey, secretKey)
 | 
			
		||||
	conf := &storage.Config{
 | 
			
		||||
		Region:        region, //空间所在的存储区域
 | 
			
		||||
		UseHTTPS:      true,   //是否使用https域名
 | 
			
		||||
		UseCdnDomains: true,   //上传是否使用CDN上传加速
 | 
			
		||||
	}
 | 
			
		||||
	return &qiNiu{
 | 
			
		||||
		mac:            mac,
 | 
			
		||||
		bucketManager:  storage.NewBucketManager(mac, conf),
 | 
			
		||||
		bucket:         bucket,
 | 
			
		||||
		domain:         domain,
 | 
			
		||||
		securityKey:    securityKey,
 | 
			
		||||
		conf:           conf,
 | 
			
		||||
		md5:            md5.New(),
 | 
			
		||||
		uploadTokenTTL: 3600,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (q *qiNiu) i() {}
 | 
			
		||||
 | 
			
		||||
func (q *qiNiu) SetDefaultUploadTokenTTL(ttl uint64) {
 | 
			
		||||
	q.uploadTokenTTL = ttl
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (q *qiNiu) GetUploadToken(ttl uint64) string {
 | 
			
		||||
	putPolicy := storage.PutPolicy{
 | 
			
		||||
		Scope:   q.bucket,
 | 
			
		||||
		Expires: ttl,
 | 
			
		||||
	}
 | 
			
		||||
	return putPolicy.UploadToken(q.mac)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (q *qiNiu) GetCallbackUploadToken(ttl uint64, callbackURL string) string {
 | 
			
		||||
	putPolicy := storage.PutPolicy{
 | 
			
		||||
		Scope:            q.bucket,
 | 
			
		||||
		CallbackURL:      callbackURL,
 | 
			
		||||
		CallbackBody:     `{"key":"$(key)","hash":"$(etag)","fname":"$(fname)","fsize":"$(fsize)","ext":"$(ext)","unique":"$(x:unique)","user":"$(x:user)"}`,
 | 
			
		||||
		CallbackBodyType: "application/json",
 | 
			
		||||
		Expires:          ttl,
 | 
			
		||||
	}
 | 
			
		||||
	return putPolicy.UploadToken(q.mac)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (q *qiNiu) GetPrivateURL(key string, ttl uint64) string {
 | 
			
		||||
	deadline := time.Now().Add(time.Second * time.Duration(ttl)).Unix()
 | 
			
		||||
	return storage.MakePrivateURL(q.mac, q.domain, key, deadline)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (q *qiNiu) VerifyCallback(req *http.Request) (bool, error) {
 | 
			
		||||
	return q.mac.VerifyCallback(req)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (q *qiNiu) UploadFile(key, localFile string) (*PutRet, error) {
 | 
			
		||||
	upToken := q.GetUploadToken(q.uploadTokenTTL)
 | 
			
		||||
 | 
			
		||||
	//构建表单上传的对象
 | 
			
		||||
	formUploader := storage.NewFormUploader(q.conf)
 | 
			
		||||
 | 
			
		||||
	//请求参数
 | 
			
		||||
	filename := path.Base(key)
 | 
			
		||||
	fileSuffix := path.Ext(key)
 | 
			
		||||
	filePrefix := filename[0 : len(filename)-len(fileSuffix)]
 | 
			
		||||
	putExtra := &storage.PutExtra{
 | 
			
		||||
		Params: map[string]string{
 | 
			
		||||
			"x:unique": filePrefix,
 | 
			
		||||
			"x:user":   "-",
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	//自定义返回body
 | 
			
		||||
	ret := new(PutRet)
 | 
			
		||||
 | 
			
		||||
	err := formUploader.PutFile(context.Background(), ret, upToken, key, localFile, putExtra)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return ret, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (q *qiNiu) ResumeUploadFile(key, localFile string) (*PutRet, error) {
 | 
			
		||||
	upToken := q.GetUploadToken(q.uploadTokenTTL)
 | 
			
		||||
 | 
			
		||||
	//构建分片上传的对象
 | 
			
		||||
	resumeUploader := storage.NewResumeUploaderV2(q.conf)
 | 
			
		||||
 | 
			
		||||
	//请求参数
 | 
			
		||||
	filename := path.Base(key)
 | 
			
		||||
	fileSuffix := path.Ext(key)
 | 
			
		||||
	filePrefix := filename[0 : len(filename)-len(fileSuffix)]
 | 
			
		||||
	putExtra := &storage.RputV2Extra{
 | 
			
		||||
		CustomVars: map[string]string{
 | 
			
		||||
			"x:unique": filePrefix,
 | 
			
		||||
			"x:user":   "-",
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	//自定义返回body
 | 
			
		||||
	ret := new(PutRet)
 | 
			
		||||
 | 
			
		||||
	err := resumeUploader.PutFile(context.Background(), ret, upToken, key, localFile, putExtra)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return ret, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (q *qiNiu) DelFile(key string) error {
 | 
			
		||||
	err := q.bucketManager.Delete(q.bucket, key)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (q *qiNiu) TimestampSecuritySign(path string, ttl time.Duration) string {
 | 
			
		||||
	sep := "/"
 | 
			
		||||
	path = strings.Trim(path, sep)
 | 
			
		||||
	splits := strings.Split(path, sep)
 | 
			
		||||
	for i, split := range splits {
 | 
			
		||||
		splits[i] = url.QueryEscape(split)
 | 
			
		||||
	}
 | 
			
		||||
	path = sep + strings.Join(splits, sep)
 | 
			
		||||
 | 
			
		||||
	unix := time.Now().Add(ttl).Unix()
 | 
			
		||||
	hex := fmt.Sprintf("%x", unix)
 | 
			
		||||
 | 
			
		||||
	encrypt := q.md5.Encrypt(q.securityKey + path + hex)
 | 
			
		||||
 | 
			
		||||
	param := make(url.Values)
 | 
			
		||||
	param.Set("sign", encrypt)
 | 
			
		||||
	param.Set("t", hex)
 | 
			
		||||
 | 
			
		||||
	return param.Encode()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (q *qiNiu) GetFileInfo(key string) *storage.FileInfo {
 | 
			
		||||
	fileInfo, sErr := q.bucketManager.Stat(q.bucket, key)
 | 
			
		||||
	if sErr != nil {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &fileInfo
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (q *qiNiu) ListFiles(prefix, delimiter, marker string, limit int) (entries []storage.ListItem,
 | 
			
		||||
	commonPrefixes []string, nextMarker string, hasNext bool, err error) {
 | 
			
		||||
	return q.bucketManager.ListFiles(q.bucket, prefix, delimiter, marker, limit)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (q *qiNiu) GetFileHash(path, qhash string) (hash string, err error) {
 | 
			
		||||
	if !tool.InArray(qhash, []string{"sha1", "md5", "sha256"}) {
 | 
			
		||||
		return "", errors.New("qhash invalid")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	sign := q.TimestampSecuritySign(path, time.Second*5)
 | 
			
		||||
	addr := fmt.Sprintf("https://cdn.mogume.com/%s?%s&qhash/%s", path, sign, qhash)
 | 
			
		||||
 | 
			
		||||
	resp, err := http.Get(addr)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
	defer func() { _ = resp.Body.Close() }()
 | 
			
		||||
	if resp.StatusCode != http.StatusOK {
 | 
			
		||||
		return "", errors.New(resp.Status)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	body, err := io.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return gjson.GetBytes(body, "hash").String(), nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										131
									
								
								pkg/rsa/rsa.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										131
									
								
								pkg/rsa/rsa.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,131 @@
 | 
			
		||||
package rsa
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto/rand"
 | 
			
		||||
	"crypto/rsa"
 | 
			
		||||
	"crypto/x509"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"encoding/pem"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var _ Public = (*rsaPub)(nil)
 | 
			
		||||
var _ Private = (*rsaPri)(nil)
 | 
			
		||||
 | 
			
		||||
type Public interface {
 | 
			
		||||
	i()
 | 
			
		||||
	EncryptURLEncoding(encryptStr string) (string, error)
 | 
			
		||||
	Encrypt(encryptStr string) (string, error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Private interface {
 | 
			
		||||
	i()
 | 
			
		||||
 | 
			
		||||
	Decrypt(decryptStr string) (string, error)
 | 
			
		||||
	DecryptURLEncoding(decryptStr string) (string, error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type rsaPub struct {
 | 
			
		||||
	PublicKey string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type rsaPri struct {
 | 
			
		||||
	PrivateKey string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewPublic(publicKey string) Public {
 | 
			
		||||
	return &rsaPub{
 | 
			
		||||
		PublicKey: publicKey,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewPrivate(privateKey string) Private {
 | 
			
		||||
	return &rsaPri{
 | 
			
		||||
		PrivateKey: privateKey,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pub *rsaPub) i() {}
 | 
			
		||||
 | 
			
		||||
func (pub *rsaPub) Encrypt(encryptStr string) (string, error) {
 | 
			
		||||
	// pem 解码
 | 
			
		||||
	block, _ := pem.Decode([]byte(pub.PublicKey))
 | 
			
		||||
 | 
			
		||||
	// x509 解码
 | 
			
		||||
	publicKeyInterface, err := x509.ParsePKIXPublicKey(block.Bytes)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 类型断言
 | 
			
		||||
	publicKey := publicKeyInterface.(*rsa.PublicKey)
 | 
			
		||||
 | 
			
		||||
	//对明文进行加密
 | 
			
		||||
	encryptedStr, err := rsa.EncryptPKCS1v15(rand.Reader, publicKey, []byte(encryptStr))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	//返回密文
 | 
			
		||||
	return base64.StdEncoding.EncodeToString(encryptedStr), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pub *rsaPub) EncryptURLEncoding(encryptStr string) (string, error) {
 | 
			
		||||
	// pem 解码
 | 
			
		||||
	block, _ := pem.Decode([]byte(pub.PublicKey))
 | 
			
		||||
 | 
			
		||||
	// x509 解码
 | 
			
		||||
	publicKeyInterface, err := x509.ParsePKIXPublicKey(block.Bytes)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 类型断言
 | 
			
		||||
	publicKey := publicKeyInterface.(*rsa.PublicKey)
 | 
			
		||||
 | 
			
		||||
	//对明文进行加密
 | 
			
		||||
	encryptedStr, err := rsa.EncryptPKCS1v15(rand.Reader, publicKey, []byte(encryptStr))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	//返回密文
 | 
			
		||||
	return base64.URLEncoding.EncodeToString(encryptedStr), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pri *rsaPri) i() {}
 | 
			
		||||
 | 
			
		||||
func (pri *rsaPri) Decrypt(decryptStr string) (string, error) {
 | 
			
		||||
	// pem 解码
 | 
			
		||||
	block, _ := pem.Decode([]byte(pri.PrivateKey))
 | 
			
		||||
 | 
			
		||||
	// X509 解码
 | 
			
		||||
	privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
	decryptBytes, err := base64.StdEncoding.DecodeString(decryptStr)
 | 
			
		||||
 | 
			
		||||
	//对密文进行解密
 | 
			
		||||
	decrypted, _ := rsa.DecryptPKCS1v15(rand.Reader, privateKey, decryptBytes)
 | 
			
		||||
 | 
			
		||||
	//返回明文
 | 
			
		||||
	return string(decrypted), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pri *rsaPri) DecryptURLEncoding(decryptStr string) (string, error) {
 | 
			
		||||
	// pem 解码
 | 
			
		||||
	block, _ := pem.Decode([]byte(pri.PrivateKey))
 | 
			
		||||
 | 
			
		||||
	// X509 解码
 | 
			
		||||
	privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
	decryptBytes, err := base64.URLEncoding.DecodeString(decryptStr)
 | 
			
		||||
 | 
			
		||||
	//对密文进行解密
 | 
			
		||||
	decrypted, _ := rsa.DecryptPKCS1v15(rand.Reader, privateKey, decryptBytes)
 | 
			
		||||
 | 
			
		||||
	//返回明文
 | 
			
		||||
	return string(decrypted), nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										50
									
								
								pkg/shutdown/shutdown.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										50
									
								
								pkg/shutdown/shutdown.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,50 @@
 | 
			
		||||
package shutdown
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"os"
 | 
			
		||||
	"os/signal"
 | 
			
		||||
	"syscall"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var _ Hook = (*hook)(nil)
 | 
			
		||||
 | 
			
		||||
// Hook a graceful shutdown hook, default with signals of SIGINT and SIGTERM
 | 
			
		||||
type Hook interface {
 | 
			
		||||
	// WithSignals add more signals into hook
 | 
			
		||||
	WithSignals(signals ...syscall.Signal) Hook
 | 
			
		||||
 | 
			
		||||
	// Close register shutdown handles
 | 
			
		||||
	Close(funcs ...func())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type hook struct {
 | 
			
		||||
	ctx chan os.Signal
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewHook create a Hook instance
 | 
			
		||||
func NewHook() Hook {
 | 
			
		||||
	hook := &hook{
 | 
			
		||||
		ctx: make(chan os.Signal, 1),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return hook.WithSignals(syscall.SIGINT, syscall.SIGTERM, syscall.SIGKILL, syscall.SIGQUIT)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *hook) WithSignals(signals ...syscall.Signal) Hook {
 | 
			
		||||
	for _, s := range signals {
 | 
			
		||||
		signal.Notify(h.ctx, s)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return h
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *hook) Close(funcs ...func()) {
 | 
			
		||||
	select {
 | 
			
		||||
	case <-h.ctx:
 | 
			
		||||
	}
 | 
			
		||||
	signal.Stop(h.ctx)
 | 
			
		||||
 | 
			
		||||
	for _, f := range funcs {
 | 
			
		||||
		f()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										52
									
								
								pkg/signature/signature.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										52
									
								
								pkg/signature/signature.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,52 @@
 | 
			
		||||
package signature
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var _ Signature = (*signature)(nil)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	delimiter = "|"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// 合法的 Methods
 | 
			
		||||
var methods = map[string]bool{
 | 
			
		||||
	http.MethodGet:     true,
 | 
			
		||||
	http.MethodPost:    true,
 | 
			
		||||
	http.MethodHead:    true,
 | 
			
		||||
	http.MethodPut:     true,
 | 
			
		||||
	http.MethodPatch:   true,
 | 
			
		||||
	http.MethodDelete:  true,
 | 
			
		||||
	http.MethodConnect: true,
 | 
			
		||||
	http.MethodOptions: true,
 | 
			
		||||
	http.MethodTrace:   true,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Signature interface {
 | 
			
		||||
	i()
 | 
			
		||||
 | 
			
		||||
	// Generate 生成签名
 | 
			
		||||
	Generate(path string, method string, params url.Values) (authorization, date string, err error)
 | 
			
		||||
 | 
			
		||||
	// Verify 验证签名
 | 
			
		||||
	Verify(authorization, date string, path string, method string, params url.Values) (ok bool, err error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type signature struct {
 | 
			
		||||
	key    string
 | 
			
		||||
	secret string
 | 
			
		||||
	ttl    time.Duration
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func New(key, secret string, ttl time.Duration) Signature {
 | 
			
		||||
	return &signature{
 | 
			
		||||
		key:    key,
 | 
			
		||||
		secret: secret,
 | 
			
		||||
		ttl:    ttl,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *signature) i() {}
 | 
			
		||||
							
								
								
									
										62
									
								
								pkg/signature/signature_generate.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										62
									
								
								pkg/signature/signature_generate.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,62 @@
 | 
			
		||||
package signature
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"crypto/hmac"
 | 
			
		||||
	"crypto/sha256"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"git.bvbej.com/bvbej/base-golang/pkg/time_parse"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Generate
 | 
			
		||||
// path 请求的路径 (不附带 querystring)
 | 
			
		||||
func (s *signature) Generate(path string, method string, params url.Values) (authorization, date string, err error) {
 | 
			
		||||
	if path == "" {
 | 
			
		||||
		err = errors.New("path required")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if method == "" {
 | 
			
		||||
		err = errors.New("method required")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	methodName := strings.ToUpper(method)
 | 
			
		||||
	if !methods[methodName] {
 | 
			
		||||
		err = errors.New("method param error")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Date
 | 
			
		||||
	date = time_parse.CSTLayoutString()
 | 
			
		||||
 | 
			
		||||
	// Encode() 方法中自带 sorted by key
 | 
			
		||||
	sortParamsEncode, err := url.QueryUnescape(params.Encode())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		err = fmt.Errorf("url QueryUnescape error [%v]", err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 加密字符串规则
 | 
			
		||||
	buffer := bytes.NewBuffer(nil)
 | 
			
		||||
	buffer.WriteString(path)
 | 
			
		||||
	buffer.WriteString(delimiter)
 | 
			
		||||
	buffer.WriteString(methodName)
 | 
			
		||||
	buffer.WriteString(delimiter)
 | 
			
		||||
	buffer.WriteString(sortParamsEncode)
 | 
			
		||||
	buffer.WriteString(delimiter)
 | 
			
		||||
	buffer.WriteString(date)
 | 
			
		||||
 | 
			
		||||
	// 对数据进行 sha256 加密,并进行 base64 encode
 | 
			
		||||
	hash := hmac.New(sha256.New, []byte(s.secret))
 | 
			
		||||
	hash.Write(buffer.Bytes())
 | 
			
		||||
	digest := base64.StdEncoding.EncodeToString(hash.Sum(nil))
 | 
			
		||||
 | 
			
		||||
	authorization = fmt.Sprintf("%s %s", s.key, digest)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										73
									
								
								pkg/signature/signature_verify.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										73
									
								
								pkg/signature/signature_verify.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,73 @@
 | 
			
		||||
package signature
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"crypto/hmac"
 | 
			
		||||
	"crypto/sha256"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.bvbej.com/bvbej/base-golang/pkg/time_parse"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func (s *signature) Verify(authorization, date string, path string, method string, params url.Values) (ok bool, err error) {
 | 
			
		||||
	if date == "" {
 | 
			
		||||
		err = errors.New("date required")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if path == "" {
 | 
			
		||||
		err = errors.New("path required")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if method == "" {
 | 
			
		||||
		err = errors.New("method required")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	methodName := strings.ToUpper(method)
 | 
			
		||||
	if !methods[methodName] {
 | 
			
		||||
		err = errors.New("method param error")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ts, err := time_parse.ParseCSTInLocation(date)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		err = errors.New("date must follow '2006-01-02 15:04:05'")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if time_parse.SubInLocation(ts) > float64(s.ttl/time.Second) {
 | 
			
		||||
		err = fmt.Errorf("date exceeds limit [%v]", s.ttl)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Encode() 方法中自带 sorted by key
 | 
			
		||||
	sortParamsEncode, err := url.QueryUnescape(params.Encode())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		err = fmt.Errorf("url QueryUnescape error [%v]", err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	buffer := bytes.NewBuffer(nil)
 | 
			
		||||
	buffer.WriteString(path)
 | 
			
		||||
	buffer.WriteString(delimiter)
 | 
			
		||||
	buffer.WriteString(methodName)
 | 
			
		||||
	buffer.WriteString(delimiter)
 | 
			
		||||
	buffer.WriteString(sortParamsEncode)
 | 
			
		||||
	buffer.WriteString(delimiter)
 | 
			
		||||
	buffer.WriteString(date)
 | 
			
		||||
 | 
			
		||||
	// 对数据进行 hmac 加密,并进行 base64 encode
 | 
			
		||||
	hash := hmac.New(sha256.New, []byte(s.secret))
 | 
			
		||||
	hash.Write(buffer.Bytes())
 | 
			
		||||
	digest := base64.StdEncoding.EncodeToString(hash.Sum(nil))
 | 
			
		||||
 | 
			
		||||
	ok = authorization == fmt.Sprintf("%s %s", s.key, digest)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										26
									
								
								pkg/sse/client.html
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								pkg/sse/client.html
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,26 @@
 | 
			
		||||
<!doctype html>
 | 
			
		||||
<html lang="en">
 | 
			
		||||
 | 
			
		||||
<head>
 | 
			
		||||
    <meta charset="UTF-8">
 | 
			
		||||
    <title>Server Sent Event</title>
 | 
			
		||||
</head>
 | 
			
		||||
 | 
			
		||||
<body>
 | 
			
		||||
<div id="event-data"></div>
 | 
			
		||||
</body>
 | 
			
		||||
 | 
			
		||||
<script>
 | 
			
		||||
    const $stream = new EventSource("/stream");
 | 
			
		||||
    const $log = document.querySelector('#event-data')
 | 
			
		||||
    $stream.addEventListener("message", function (e) {
 | 
			
		||||
        log(e.data)
 | 
			
		||||
    });
 | 
			
		||||
 | 
			
		||||
    function log(msg) {
 | 
			
		||||
        $log.innerHTML += `<p>消息: ${msg}</p>`
 | 
			
		||||
        $log.scrollTop += 1000
 | 
			
		||||
    }
 | 
			
		||||
</script>
 | 
			
		||||
 | 
			
		||||
</html>
 | 
			
		||||
							
								
								
									
										118
									
								
								pkg/sse/server.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										118
									
								
								pkg/sse/server.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,118 @@
 | 
			
		||||
package sse
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"sync/atomic"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var _ Server = (*event)(nil)
 | 
			
		||||
 | 
			
		||||
type Server interface {
 | 
			
		||||
	HandlerFunc(auth func(c *gin.Context) (string, error)) gin.HandlerFunc
 | 
			
		||||
	Push(user, name, msg string) bool
 | 
			
		||||
	Broadcast(name, msg string)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type clientChan struct {
 | 
			
		||||
	User string
 | 
			
		||||
	Chan chan msgChan
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type msgChan struct {
 | 
			
		||||
	Name    string
 | 
			
		||||
	Message string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type event struct {
 | 
			
		||||
	SessionList sync.Map
 | 
			
		||||
	Count       atomic.Int32
 | 
			
		||||
 | 
			
		||||
	Register   chan clientChan
 | 
			
		||||
	Unregister chan string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewServer() Server {
 | 
			
		||||
	e := &event{
 | 
			
		||||
		SessionList: sync.Map{},
 | 
			
		||||
		Count:       atomic.Int32{},
 | 
			
		||||
		Register:    make(chan clientChan),
 | 
			
		||||
		Unregister:  make(chan string),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	go e.listen()
 | 
			
		||||
 | 
			
		||||
	return e
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (stream *event) listen() {
 | 
			
		||||
	for {
 | 
			
		||||
		select {
 | 
			
		||||
		case client := <-stream.Register:
 | 
			
		||||
			stream.SessionList.Store(client.User, client.Chan)
 | 
			
		||||
			stream.Count.Add(1)
 | 
			
		||||
		case user := <-stream.Unregister:
 | 
			
		||||
			value, ok := stream.SessionList.Load(user)
 | 
			
		||||
			if ok {
 | 
			
		||||
				event := value.(chan msgChan)
 | 
			
		||||
				close(event)
 | 
			
		||||
				stream.SessionList.Delete(user)
 | 
			
		||||
				stream.Count.Add(-1)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (stream *event) HandlerFunc(auth func(c *gin.Context) (string, error)) gin.HandlerFunc {
 | 
			
		||||
	return func(c *gin.Context) {
 | 
			
		||||
		user, err := auth(c)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			c.AbortWithStatus(http.StatusBadRequest)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		e := make(chan msgChan)
 | 
			
		||||
		client := clientChan{
 | 
			
		||||
			User: user,
 | 
			
		||||
			Chan: e,
 | 
			
		||||
		}
 | 
			
		||||
		stream.Register <- client
 | 
			
		||||
		defer func() {
 | 
			
		||||
			stream.Unregister <- user
 | 
			
		||||
		}()
 | 
			
		||||
 | 
			
		||||
		c.Writer.Header().Set("Content-Type", "text/event-stream")
 | 
			
		||||
		c.Writer.Header().Set("Cache-Control", "no-cache")
 | 
			
		||||
		c.Writer.Header().Set("Connection", "keep-alive")
 | 
			
		||||
		c.Writer.Header().Set("Transfer-Encoding", "chunked")
 | 
			
		||||
 | 
			
		||||
		c.Stream(func(w io.Writer) bool {
 | 
			
		||||
			if msg, ok := <-e; ok {
 | 
			
		||||
				c.SSEvent(msg.Name, msg.Message)
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			return false
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		c.Next()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (stream *event) Push(user, name, msg string) bool {
 | 
			
		||||
	value, ok := stream.SessionList.Load(user)
 | 
			
		||||
	if ok {
 | 
			
		||||
		val := value.(chan msgChan)
 | 
			
		||||
		val <- msgChan{Name: name, Message: msg}
 | 
			
		||||
	}
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (stream *event) Broadcast(name, msg string) {
 | 
			
		||||
	stream.SessionList.Range(func(user, value any) bool {
 | 
			
		||||
		val := value.(chan msgChan)
 | 
			
		||||
		val <- msgChan{Name: name, Message: msg}
 | 
			
		||||
		return true
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										52
									
								
								pkg/ticker/ticker.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										52
									
								
								pkg/ticker/ticker.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,52 @@
 | 
			
		||||
package ticker
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var _ Ticker = (*ticker)(nil)
 | 
			
		||||
 | 
			
		||||
type Ticker interface {
 | 
			
		||||
	worker()
 | 
			
		||||
 | 
			
		||||
	Process(fun func())
 | 
			
		||||
	Stop()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ticker struct {
 | 
			
		||||
	ticker *time.Ticker
 | 
			
		||||
	ctx    context.Context
 | 
			
		||||
	cancel context.CancelFunc
 | 
			
		||||
	f      func()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func New(d time.Duration) Ticker {
 | 
			
		||||
	ctx, cancelFunc := context.WithCancel(context.Background())
 | 
			
		||||
	return &ticker{
 | 
			
		||||
		ticker: time.NewTicker(d),
 | 
			
		||||
		ctx:    ctx,
 | 
			
		||||
		cancel: cancelFunc,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t *ticker) worker() {
 | 
			
		||||
	for {
 | 
			
		||||
		select {
 | 
			
		||||
		case <-t.ticker.C:
 | 
			
		||||
			t.f()
 | 
			
		||||
		case <-t.ctx.Done():
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t *ticker) Process(fun func()) {
 | 
			
		||||
	t.f = fun
 | 
			
		||||
	go t.worker()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t *ticker) Stop() {
 | 
			
		||||
	t.ticker.Stop()
 | 
			
		||||
	t.cancel()
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										124
									
								
								pkg/time_parse/time_parse.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										124
									
								
								pkg/time_parse/time_parse.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,124 @@
 | 
			
		||||
package time_parse
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/jinzhu/now"
 | 
			
		||||
	"math"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	cst = time.Local
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func SetPRCLocation() error {
 | 
			
		||||
	var err error
 | 
			
		||||
	if cst, err = time.LoadLocation("Asia/Shanghai"); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Now() *now.Now {
 | 
			
		||||
	timeFormats := append(append(now.TimeFormats, time.DateTime), time.DateOnly)
 | 
			
		||||
 | 
			
		||||
	c := &now.Config{
 | 
			
		||||
		WeekStartDay: time.Monday,
 | 
			
		||||
		TimeLocation: cst,
 | 
			
		||||
		TimeFormats:  timeFormats,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return c.With(time.Now().In(cst))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RFC3339ToCSTLayout convert rfc3339 value to China standard time layout
 | 
			
		||||
// 2020-11-08T08:18:46+08:00 => 2020-11-08 08:18:46
 | 
			
		||||
func RFC3339ToCSTLayout(value string) (string, error) {
 | 
			
		||||
	ts, err := time.Parse(time.RFC3339, value)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return ts.In(cst).Format(time.DateTime), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetMilliTimestamp 获取CST毫秒时间戳
 | 
			
		||||
func GetMilliTimestamp() int64 {
 | 
			
		||||
	return time.Now().In(cst).UnixMilli()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetTimestamp 获取CST时间戳
 | 
			
		||||
func GetTimestamp() int64 {
 | 
			
		||||
	return time.Now().In(cst).Unix()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CSTLayoutString 格式化时间
 | 
			
		||||
// 返回 "2006-01-02 15:04:05" 格式的时间
 | 
			
		||||
func CSTLayoutString() string {
 | 
			
		||||
	return time.Now().In(cst).Format(time.DateTime)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ParseCSTInLocation 格式化时间
 | 
			
		||||
func ParseCSTInLocation(date string) (time.Time, error) {
 | 
			
		||||
	return time.ParseInLocation(time.DateTime, date, cst)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CSTLayoutStringToUnix 返回 unix 时间戳
 | 
			
		||||
// 2020-01-24 21:11:11 => 1579871471
 | 
			
		||||
func CSTLayoutStringToUnix(cstLayoutString string) (int64, error) {
 | 
			
		||||
	stamp, err := time.ParseInLocation(time.DateTime, cstLayoutString, cst)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	return stamp.Unix(), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GMTLayoutString 格式化时间
 | 
			
		||||
// 返回 "Mon, 02 Jan 2006 15:04:05 GMT" 格式的时间
 | 
			
		||||
func GMTLayoutString() string {
 | 
			
		||||
	return time.Now().In(cst).Format(http.TimeFormat)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ParseGMTInLocation 格式化时间
 | 
			
		||||
func ParseGMTInLocation(date string) (time.Time, error) {
 | 
			
		||||
	return time.ParseInLocation(http.TimeFormat, date, cst)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SubInLocation 计算时间差
 | 
			
		||||
func SubInLocation(ts time.Time) float64 {
 | 
			
		||||
	return math.Abs(time.Now().In(cst).Sub(ts).Seconds())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetFirstDateOfMonth 获取传入的时间所在月份的第一天,即某月第一天的0点。如传入time.Now(), 返回当前月份的第一天0点时间。
 | 
			
		||||
func GetFirstDateOfMonth(d time.Time) time.Time {
 | 
			
		||||
	d = d.AddDate(0, 0, -d.Day()+1)
 | 
			
		||||
	return GetStartDayTime(d)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetLastDateOfMonth 获取传入的时间所在月份的最后一天,即某月最后一天的0点。如传入time.Now(), 返回当前月份的最后一天0点时间。
 | 
			
		||||
func GetLastDateOfMonth(d time.Time) time.Time {
 | 
			
		||||
	return GetEndDayTime(GetFirstDateOfMonth(d).AddDate(0, 1, -1))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetStartDayTime 获取某一天的开始时间
 | 
			
		||||
func GetStartDayTime(d time.Time) time.Time {
 | 
			
		||||
	return time.Date(d.Year(), d.Month(), d.Day(), 0, 0, 0, 0, cst)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetEndDayTime 获取某一天的结束时间
 | 
			
		||||
func GetEndDayTime(d time.Time) time.Time {
 | 
			
		||||
	return GetStartDayTime(d).Add(24 * time.Hour).Add(-time.Second)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetStartYesterdayTime 获取昨天的起始时间
 | 
			
		||||
func GetStartYesterdayTime() time.Time {
 | 
			
		||||
	d, _ := time.ParseDuration("-24h")
 | 
			
		||||
	return GetStartDayTime(time.Now().Add(d))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SubInLocationDays 计算两个日期相差的天数
 | 
			
		||||
func SubInLocationDays(t1, t2 time.Time) float64 {
 | 
			
		||||
	t1 = time.Date(t1.Year(), t1.Month(), t1.Day(), 0, 0, 0, 0, cst)
 | 
			
		||||
	t2 = time.Date(t2.Year(), t2.Month(), t2.Day(), 0, 0, 0, 0, cst)
 | 
			
		||||
	return math.Abs(t1.Sub(t2).Hours() / 24)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										20
									
								
								pkg/token/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								pkg/token/README.md
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,20 @@
 | 
			
		||||
## 与 UrlSign 对应的 PHP 加密算法
 | 
			
		||||
 | 
			
		||||
```php
 | 
			
		||||
// 对 params key 进行排序
 | 
			
		||||
ksort($params);
 | 
			
		||||
 | 
			
		||||
// 对 sortParams 进行 Encode
 | 
			
		||||
$sortParamsEncode = http_build_query($params);
 | 
			
		||||
 | 
			
		||||
// 加密字符串规则 path + method + sortParamsEncode + secret
 | 
			
		||||
$encryptStr = $path . $method . $sortParamsEncode . $secret
 | 
			
		||||
 | 
			
		||||
// 对加密字符串进行 md5
 | 
			
		||||
$md5Str = md5($encryptStr);
 | 
			
		||||
 | 
			
		||||
// 对 md5Str 进行 base64 encode
 | 
			
		||||
$tokenString = base64_encode($md5Str);
 | 
			
		||||
 | 
			
		||||
echo $tokenString;
 | 
			
		||||
```
 | 
			
		||||
							
								
								
									
										36
									
								
								pkg/token/token.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								pkg/token/token.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,36 @@
 | 
			
		||||
package token
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/golang-jwt/jwt/v4"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var _ Token = (*token)(nil)
 | 
			
		||||
 | 
			
		||||
type Token interface {
 | 
			
		||||
	// i 为了避免被其他包实现
 | 
			
		||||
	i()
 | 
			
		||||
 | 
			
		||||
	// JwtSign JWT 签名方式
 | 
			
		||||
	JwtSign(userId, subject string, expireDuration time.Duration) (tokenString string, err error)
 | 
			
		||||
	JwtParse(tokenString string) (*jwt.RegisteredClaims, error)
 | 
			
		||||
 | 
			
		||||
	// UrlSign URL 签名方式,不支持解密
 | 
			
		||||
	UrlSign(path string, method string, params url.Values) (tokenString string, err error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type token struct {
 | 
			
		||||
	secret string
 | 
			
		||||
	domain []string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func New(secret string, domain ...string) Token {
 | 
			
		||||
	return &token{
 | 
			
		||||
		secret: secret,
 | 
			
		||||
		domain: domain,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t *token) i() {}
 | 
			
		||||
							
								
								
									
										43
									
								
								pkg/token/token_jwt.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										43
									
								
								pkg/token/token_jwt.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,43 @@
 | 
			
		||||
package token
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/golang-jwt/jwt/v4"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func (t *token) JwtSign(userId, subject string, expireDuration time.Duration) (tokenString string, err error) {
 | 
			
		||||
	// The token content.
 | 
			
		||||
	// iss: (Issuer)签发者
 | 
			
		||||
	// iat: (Issued At)签发时间,用Unix时间戳表示
 | 
			
		||||
	// exp: (Expiration Time)过期时间,用Unix时间戳表示
 | 
			
		||||
	// aud: (Audience)接收该JWT的一方
 | 
			
		||||
	// sub: (Subject)该JWT的主题
 | 
			
		||||
	// nbf: (Not Before)不要早于这个时间
 | 
			
		||||
	// jti: (JWT ID)用于标识JWT的唯一ID
 | 
			
		||||
	c := &jwt.RegisteredClaims{
 | 
			
		||||
		Issuer:    "BvBeJ",
 | 
			
		||||
		Subject:   subject,
 | 
			
		||||
		Audience:  jwt.ClaimStrings(t.domain),
 | 
			
		||||
		ExpiresAt: jwt.NewNumericDate(time.Now().Add(expireDuration)),
 | 
			
		||||
		NotBefore: jwt.NewNumericDate(time.Now()),
 | 
			
		||||
		IssuedAt:  jwt.NewNumericDate(time.Now()),
 | 
			
		||||
		ID:        userId,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tokenString, err = jwt.NewWithClaims(jwt.SigningMethodHS256, c).SignedString([]byte(t.secret))
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t *token) JwtParse(tokenString string) (*jwt.RegisteredClaims, error) {
 | 
			
		||||
	tokenClaims, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (any, error) {
 | 
			
		||||
		return []byte(t.secret), nil
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	if tokenClaims != nil {
 | 
			
		||||
		if c, ok := tokenClaims.Claims.(*jwt.RegisteredClaims); ok && tokenClaims.Valid {
 | 
			
		||||
			return c, nil
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil, err
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										48
									
								
								pkg/token/token_url.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										48
									
								
								pkg/token/token_url.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,48 @@
 | 
			
		||||
package token
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto/md5"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"encoding/hex"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// UrlSign
 | 
			
		||||
// path 请求的路径 (不附带 querystring)
 | 
			
		||||
func (t *token) UrlSign(path string, method string, params url.Values) (tokenString string, err error) {
 | 
			
		||||
	// 合法的 Methods
 | 
			
		||||
	methods := map[string]bool{
 | 
			
		||||
		"get":     true,
 | 
			
		||||
		"post":    true,
 | 
			
		||||
		"put":     true,
 | 
			
		||||
		"path":    true,
 | 
			
		||||
		"delete":  true,
 | 
			
		||||
		"head":    true,
 | 
			
		||||
		"options": true,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	methodName := strings.ToLower(method)
 | 
			
		||||
	if !methods[methodName] {
 | 
			
		||||
		err = errors.New("method param error")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Encode() 方法中自带 sorted by key
 | 
			
		||||
	sortParamsEncode := params.Encode()
 | 
			
		||||
 | 
			
		||||
	// 加密字符串规则 path + method + sortParamsEncode + secret
 | 
			
		||||
	encryptStr := fmt.Sprintf("%s%s%s%s", path, methodName, sortParamsEncode, t.secret)
 | 
			
		||||
 | 
			
		||||
	// 对加密字符串进行 md5
 | 
			
		||||
	s := md5.New()
 | 
			
		||||
	s.Write([]byte(encryptStr))
 | 
			
		||||
	md5Str := hex.EncodeToString(s.Sum(nil))
 | 
			
		||||
 | 
			
		||||
	// 对 md5Str 进行 base64 encode
 | 
			
		||||
	tokenString = base64.StdEncoding.EncodeToString([]byte(md5Str))
 | 
			
		||||
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										81
									
								
								pkg/trace/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										81
									
								
								pkg/trace/README.md
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,81 @@
 | 
			
		||||
## trace
 | 
			
		||||
 | 
			
		||||
一个用于开发调试的辅助工具。
 | 
			
		||||
 | 
			
		||||
可以实时显示当前页面的操作的请求信息、运行情况、SQL执行、错误提示等。
 | 
			
		||||
 | 
			
		||||
- `trace.go` 主入口文件;
 | 
			
		||||
- `dialog.go` 处理 third_party_requests 记录;
 | 
			
		||||
- `debug.go` 处理 debug 记录;
 | 
			
		||||
 | 
			
		||||
#### 数据格式
 | 
			
		||||
 | 
			
		||||
##### trace_id
 | 
			
		||||
 | 
			
		||||
当前 trace 的 ID,例如:938ff86be98439c6c1a7,便于搜索使用。
 | 
			
		||||
 | 
			
		||||
##### request
 | 
			
		||||
 | 
			
		||||
请求信息,会包括:
 | 
			
		||||
 | 
			
		||||
- ttl 请求超时时间,例如:2s 或 un-limit
 | 
			
		||||
- method 请求方式,例如:GET 或 POST
 | 
			
		||||
- decoded_url 请求地址
 | 
			
		||||
- header 请求头信息
 | 
			
		||||
- body 请求体信息
 | 
			
		||||
 | 
			
		||||
##### response
 | 
			
		||||
 | 
			
		||||
- header 响应头信息
 | 
			
		||||
- body 响应提信息
 | 
			
		||||
- business_code 业务码,例如:10010
 | 
			
		||||
- business_code_msg 业务码信息,例如:签名错误
 | 
			
		||||
- http_code HTTP 状态码,例如:200
 | 
			
		||||
- http_code_msg HTTP 状态码信息,例如:OK
 | 
			
		||||
- cost_seconds 耗费时长:单位秒,比如 0.001105661
 | 
			
		||||
 | 
			
		||||
##### third_party_requests
 | 
			
		||||
 | 
			
		||||
每一个第三方 http 请求都会生成如下的一组数据,多个请求会生成多组数据。
 | 
			
		||||
 | 
			
		||||
- request,同上 request 结构一致
 | 
			
		||||
- response,同上 response 结构一致
 | 
			
		||||
- success,是否成功,true 或 false
 | 
			
		||||
- cost_seconds,耗费时长:单位秒
 | 
			
		||||
 | 
			
		||||
注意:response 中的 business_code、business_code_msg 为空,因为各个第三方返回结构不同,这两个字段为空。
 | 
			
		||||
 | 
			
		||||
##### sqls
 | 
			
		||||
 | 
			
		||||
执行的 SQL 信息,多个 SQL 会记录多组数据。
 | 
			
		||||
 | 
			
		||||
- timestamp,时间,格式:2006-01-02 15:04:05
 | 
			
		||||
- stack,文件地址和行号
 | 
			
		||||
- cost_seconds,执行时长,单位:秒
 | 
			
		||||
- sql,SQL 语句
 | 
			
		||||
- rows_affected,影响行数
 | 
			
		||||
 | 
			
		||||
##### debugs
 | 
			
		||||
 | 
			
		||||
- key 打印的标示
 | 
			
		||||
- value 打印的值
 | 
			
		||||
 | 
			
		||||
```cassandraql
 | 
			
		||||
// 调试时,使用这个方法:
 | 
			
		||||
p.Print("key", "value", p.WithTrace(c.Trace()))
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
只有参数中增加了 `p.WithTrace(c.Trace())`,才会记录到 `debugs` 中。
 | 
			
		||||
 | 
			
		||||
##### success
 | 
			
		||||
 | 
			
		||||
是否成功,true 或 false
 | 
			
		||||
 | 
			
		||||
```cassandraql
 | 
			
		||||
success = !ctx.IsAborted() && ctx.Writer.Status() == http.StatusOK
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
##### cost_seconds
 | 
			
		||||
 | 
			
		||||
耗费时长:单位秒,比如 0.001105661
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										7
									
								
								pkg/trace/debug.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								pkg/trace/debug.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,7 @@
 | 
			
		||||
package trace
 | 
			
		||||
 | 
			
		||||
type Debug struct {
 | 
			
		||||
	Key         string  `json:"key"`          // 标示
 | 
			
		||||
	Value       any     `json:"value"`        // 值
 | 
			
		||||
	CostSeconds float64 `json:"cost_seconds"` // 执行时间(单位秒)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										32
									
								
								pkg/trace/dialog.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										32
									
								
								pkg/trace/dialog.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,32 @@
 | 
			
		||||
package trace
 | 
			
		||||
 | 
			
		||||
import "sync"
 | 
			
		||||
 | 
			
		||||
var _ D = (*Dialog)(nil)
 | 
			
		||||
 | 
			
		||||
type D interface {
 | 
			
		||||
	i()
 | 
			
		||||
	AppendResponse(resp *Response)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Dialog 内部调用其它方接口的会话信息;失败时会有retry操作,所以 response 会有多次。
 | 
			
		||||
type Dialog struct {
 | 
			
		||||
	mux         sync.Mutex
 | 
			
		||||
	Request     *Request    `json:"request"`      // 请求信息
 | 
			
		||||
	Responses   []*Response `json:"responses"`    // 返回信息
 | 
			
		||||
	Success     bool        `json:"success"`      // 是否成功,true 或 false
 | 
			
		||||
	CostSeconds float64     `json:"cost_seconds"` // 执行时长(单位秒)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *Dialog) i() {}
 | 
			
		||||
 | 
			
		||||
// AppendResponse 按转的追加response信息
 | 
			
		||||
func (d *Dialog) AppendResponse(resp *Response) {
 | 
			
		||||
	if resp == nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	d.mux.Lock()
 | 
			
		||||
	d.Responses = append(d.Responses, resp)
 | 
			
		||||
	d.mux.Unlock()
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										17
									
								
								pkg/trace/grpc.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								pkg/trace/grpc.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,17 @@
 | 
			
		||||
package trace
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"google.golang.org/grpc/metadata"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Grpc struct {
 | 
			
		||||
	Timestamp   string         `json:"timestamp"`             // 时间,格式:2006-01-02 15:04:05
 | 
			
		||||
	Addr        string         `json:"addr"`                  // 地址
 | 
			
		||||
	Method      string         `json:"method"`                // 操作方法
 | 
			
		||||
	Meta        metadata.MD    `json:"meta"`                  // Mate
 | 
			
		||||
	Request     map[string]any `json:"request"`               // 请求信息
 | 
			
		||||
	Response    map[string]any `json:"response"`              // 返回信息
 | 
			
		||||
	CostSeconds float64        `json:"cost_seconds"`          // 执行时间(单位秒)
 | 
			
		||||
	Code        string         `json:"err_code,omitempty"`    // 错误码
 | 
			
		||||
	Message     string         `json:"err_message,omitempty"` // 错误信息
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										10
									
								
								pkg/trace/redis.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								pkg/trace/redis.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,10 @@
 | 
			
		||||
package trace
 | 
			
		||||
 | 
			
		||||
type Redis struct {
 | 
			
		||||
	Timestamp   string  `json:"timestamp"`       // 时间,格式:2006-01-02 15:04:05
 | 
			
		||||
	Handle      string  `json:"handle"`          // 操作,SET/GET 等
 | 
			
		||||
	Key         string  `json:"key"`             // Key
 | 
			
		||||
	Value       string  `json:"value,omitempty"` // Value
 | 
			
		||||
	TTL         float64 `json:"ttl,omitempty"`   // 超时时长(单位分)
 | 
			
		||||
	CostSeconds float64 `json:"cost_seconds"`    // 执行时间(单位秒)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										9
									
								
								pkg/trace/sql.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								pkg/trace/sql.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,9 @@
 | 
			
		||||
package trace
 | 
			
		||||
 | 
			
		||||
type SQL struct {
 | 
			
		||||
	Timestamp   string  `json:"timestamp"`     // 时间,格式:2006-01-02 15:04:05
 | 
			
		||||
	Stack       string  `json:"stack"`         // 文件地址和行号
 | 
			
		||||
	SQL         string  `json:"sql"`           // SQL 语句
 | 
			
		||||
	Rows        int64   `json:"rows_affected"` // 影响行数
 | 
			
		||||
	CostSeconds float64 `json:"cost_seconds"`  // 执行时长(单位秒)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										154
									
								
								pkg/trace/trace.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										154
									
								
								pkg/trace/trace.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,154 @@
 | 
			
		||||
package trace
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto/rand"
 | 
			
		||||
	"encoding/hex"
 | 
			
		||||
	"sync"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const Header = "X-TRACE-ID"
 | 
			
		||||
 | 
			
		||||
var _ T = (*Trace)(nil)
 | 
			
		||||
 | 
			
		||||
type T interface {
 | 
			
		||||
	i()
 | 
			
		||||
	ID() string
 | 
			
		||||
	WithRequest(req *Request) *Trace
 | 
			
		||||
	WithResponse(resp *Response) *Trace
 | 
			
		||||
	AppendDialog(dialog *Dialog) *Trace
 | 
			
		||||
	AppendDebug(debug *Debug) *Trace
 | 
			
		||||
	AppendSQL(sql *SQL) *Trace
 | 
			
		||||
	AppendRedis(redis *Redis) *Trace
 | 
			
		||||
	AppendGRPC(grpc *Grpc) *Trace
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Trace 记录的参数
 | 
			
		||||
type Trace struct {
 | 
			
		||||
	mux                sync.Mutex
 | 
			
		||||
	Identifier         string    `json:"trace_id"`                       // 链路ID
 | 
			
		||||
	Request            *Request  `json:"request"`                        // 请求信息
 | 
			
		||||
	Response           *Response `json:"response"`                       // 返回信息
 | 
			
		||||
	ThirdPartyRequests []*Dialog `json:"third_party_requests,omitempty"` // 调用第三方接口的信息
 | 
			
		||||
	Debugs             []*Debug  `json:"debugs,omitempty"`               // 调试信息
 | 
			
		||||
	SQLs               []*SQL    `json:"sqls,omitempty"`                 // 执行的 SQL 信息
 | 
			
		||||
	Redis              []*Redis  `json:"redis,omitempty"`                // 执行的 Redis 信息
 | 
			
		||||
	GRPCs              []*Grpc   `json:"grpc,omitempty"`                 // 执行的 gRPC 信息
 | 
			
		||||
	Success            bool      `json:"success"`                        // 请求结果 true or false
 | 
			
		||||
	CostSeconds        float64   `json:"cost_seconds"`                   // 执行时长(单位秒)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Request 请求信息
 | 
			
		||||
type Request struct {
 | 
			
		||||
	TTL        string `json:"ttl,omitempty"` // 请求超时时间
 | 
			
		||||
	Method     string `json:"method"`        // 请求方式
 | 
			
		||||
	DecodedURL string `json:"decoded_url"`   // 请求地址
 | 
			
		||||
	Header     any    `json:"header"`        // 请求 Header 信息
 | 
			
		||||
	Body       any    `json:"body"`          // 请求 Body 信息
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Response 响应信息
 | 
			
		||||
type Response struct {
 | 
			
		||||
	Header          any     `json:"header"`                      // Header 信息
 | 
			
		||||
	Body            any     `json:"body"`                        // Body 信息
 | 
			
		||||
	BusinessCode    int     `json:"business_code,omitempty"`     // 业务码
 | 
			
		||||
	BusinessCodeMsg string  `json:"business_code_msg,omitempty"` // 提示信息
 | 
			
		||||
	HttpCode        int     `json:"http_code"`                   // HTTP 状态码
 | 
			
		||||
	HttpCodeMsg     string  `json:"http_code_msg"`               // HTTP 状态码信息
 | 
			
		||||
	CostSeconds     float64 `json:"cost_seconds"`                // 执行时间(单位秒)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func New(id string) *Trace {
 | 
			
		||||
	if id == "" {
 | 
			
		||||
		buf := make([]byte, 10)
 | 
			
		||||
		_, _ = rand.Read(buf)
 | 
			
		||||
		id = hex.EncodeToString(buf)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &Trace{
 | 
			
		||||
		Identifier: id,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t *Trace) i() {}
 | 
			
		||||
 | 
			
		||||
// ID 唯一标识符
 | 
			
		||||
func (t *Trace) ID() string {
 | 
			
		||||
	return t.Identifier
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithRequest 设置request
 | 
			
		||||
func (t *Trace) WithRequest(req *Request) *Trace {
 | 
			
		||||
	t.Request = req
 | 
			
		||||
	return t
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithResponse 设置response
 | 
			
		||||
func (t *Trace) WithResponse(resp *Response) *Trace {
 | 
			
		||||
	t.Response = resp
 | 
			
		||||
	return t
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// AppendDialog 安全的追加内部调用过程dialog
 | 
			
		||||
func (t *Trace) AppendDialog(dialog *Dialog) *Trace {
 | 
			
		||||
	if dialog == nil {
 | 
			
		||||
		return t
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	t.mux.Lock()
 | 
			
		||||
	defer t.mux.Unlock()
 | 
			
		||||
 | 
			
		||||
	t.ThirdPartyRequests = append(t.ThirdPartyRequests, dialog)
 | 
			
		||||
	return t
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// AppendDebug 追加 debug
 | 
			
		||||
func (t *Trace) AppendDebug(debug *Debug) *Trace {
 | 
			
		||||
	if debug == nil {
 | 
			
		||||
		return t
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	t.mux.Lock()
 | 
			
		||||
	defer t.mux.Unlock()
 | 
			
		||||
 | 
			
		||||
	t.Debugs = append(t.Debugs, debug)
 | 
			
		||||
	return t
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// AppendSQL 追加 SQL
 | 
			
		||||
func (t *Trace) AppendSQL(sql *SQL) *Trace {
 | 
			
		||||
	if sql == nil {
 | 
			
		||||
		return t
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	t.mux.Lock()
 | 
			
		||||
	defer t.mux.Unlock()
 | 
			
		||||
 | 
			
		||||
	t.SQLs = append(t.SQLs, sql)
 | 
			
		||||
	return t
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// AppendRedis 追加 Redis
 | 
			
		||||
func (t *Trace) AppendRedis(redis *Redis) *Trace {
 | 
			
		||||
	if redis == nil {
 | 
			
		||||
		return t
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	t.mux.Lock()
 | 
			
		||||
	defer t.mux.Unlock()
 | 
			
		||||
 | 
			
		||||
	t.Redis = append(t.Redis, redis)
 | 
			
		||||
	return t
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// AppendGRPC 追加 gRPC 调用信息
 | 
			
		||||
func (t *Trace) AppendGRPC(grpc *Grpc) *Trace {
 | 
			
		||||
	if grpc == nil {
 | 
			
		||||
		return t
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	t.mux.Lock()
 | 
			
		||||
	defer t.mux.Unlock()
 | 
			
		||||
 | 
			
		||||
	t.GRPCs = append(t.GRPCs, grpc)
 | 
			
		||||
	return t
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										228
									
								
								pkg/upload/tus.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										228
									
								
								pkg/upload/tus.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,228 @@
 | 
			
		||||
package upload
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"crypto/sha256"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"git.bvbej.com/bvbej/base-golang/pkg/color"
 | 
			
		||||
	"git.bvbej.com/bvbej/base-golang/pkg/ticker"
 | 
			
		||||
	"git.bvbej.com/bvbej/base-golang/pkg/token"
 | 
			
		||||
	"github.com/rs/cors"
 | 
			
		||||
	"github.com/tus/tusd/pkg/filestore"
 | 
			
		||||
	tus "github.com/tus/tusd/pkg/handler"
 | 
			
		||||
	"go.uber.org/zap"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"os"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var _ Server = (*server)(nil)
 | 
			
		||||
 | 
			
		||||
type Server interface {
 | 
			
		||||
	GetUploadToken(string, string, time.Duration) string
 | 
			
		||||
	GetFileInfo(string) (*tus.FileInfo, error)
 | 
			
		||||
 | 
			
		||||
	Start(func(string, string, tus.FileInfo)) error
 | 
			
		||||
	Stop() error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type server struct {
 | 
			
		||||
	headerTokenKey string
 | 
			
		||||
	uploading      sync.Map
 | 
			
		||||
	config         Config
 | 
			
		||||
	token          token.Token
 | 
			
		||||
	store          filestore.FileStore
 | 
			
		||||
	logger         *zap.Logger
 | 
			
		||||
	httpServer     *http.Server
 | 
			
		||||
	ctx            context.Context
 | 
			
		||||
	done           context.CancelFunc
 | 
			
		||||
	checker        ticker.Ticker
 | 
			
		||||
	completedEvent func(sha256, param string, info tus.FileInfo)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Config struct {
 | 
			
		||||
	ListenAddr      string
 | 
			
		||||
	Path            string
 | 
			
		||||
	Dir             string
 | 
			
		||||
	Secret          string
 | 
			
		||||
	DisableDownload bool
 | 
			
		||||
	Debug           bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func New(conf Config, logger *zap.Logger) Server {
 | 
			
		||||
	ctx, cancelFunc := context.WithCancel(context.Background())
 | 
			
		||||
	return &server{
 | 
			
		||||
		config:         conf,
 | 
			
		||||
		uploading:      sync.Map{},
 | 
			
		||||
		headerTokenKey: "Authorization",
 | 
			
		||||
		logger:         logger,
 | 
			
		||||
		token:          token.New(conf.Secret),
 | 
			
		||||
		ctx:            ctx,
 | 
			
		||||
		done:           cancelFunc,
 | 
			
		||||
		checker:        ticker.New(time.Minute),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *server) GetUploadToken(sha256, param string, ttl time.Duration) string {
 | 
			
		||||
	sign, _ := s.token.JwtSign(sha256, param, ttl)
 | 
			
		||||
	return sign
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *server) GetFileInfo(id string) (*tus.FileInfo, error) {
 | 
			
		||||
	upload, err := s.store.GetUpload(context.Background(), id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	info, err := upload.GetInfo(context.Background())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return &info, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *server) Start(completedEvent func(sha256, param string, info tus.FileInfo)) error {
 | 
			
		||||
	s.completedEvent = completedEvent
 | 
			
		||||
 | 
			
		||||
	composer := tus.NewStoreComposer()
 | 
			
		||||
	if err := os.MkdirAll(s.config.Dir, os.ModePerm); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	s.store = filestore.New(s.config.Dir)
 | 
			
		||||
	s.store.UseIn(composer)
 | 
			
		||||
 | 
			
		||||
	handler, err := tus.NewHandler(tus.Config{
 | 
			
		||||
		StoreComposer:           composer,
 | 
			
		||||
		BasePath:                s.config.Path,
 | 
			
		||||
		Logger:                  zap.NewStdLog(s.logger),
 | 
			
		||||
		NotifyCompleteUploads:   true,
 | 
			
		||||
		NotifyTerminatedUploads: true,
 | 
			
		||||
		DisableTermination:      true,
 | 
			
		||||
		DisableDownload:         s.config.DisableDownload,
 | 
			
		||||
		RespectForwardedHeaders: strings.Contains(s.config.ListenAddr, "127.0.0.1"),
 | 
			
		||||
		PreUploadCreateCallback: func(hook tus.HookEvent) error {
 | 
			
		||||
			authStr := hook.HTTPRequest.Header.Get(s.headerTokenKey)
 | 
			
		||||
			jwtClaims, err := s.token.JwtParse(authStr)
 | 
			
		||||
			if err == nil {
 | 
			
		||||
				_, ok := s.uploading.Load(authStr)
 | 
			
		||||
				if !ok {
 | 
			
		||||
					s.uploading.Store(authStr, jwtClaims.ExpiresAt.Time)
 | 
			
		||||
					return nil
 | 
			
		||||
				}
 | 
			
		||||
				return errors.New("repeated")
 | 
			
		||||
			}
 | 
			
		||||
			return errors.New("unauthorized")
 | 
			
		||||
		},
 | 
			
		||||
		PreFinishResponseCallback: func(hook tus.HookEvent) error {
 | 
			
		||||
			authStr := hook.HTTPRequest.Header.Get(s.headerTokenKey)
 | 
			
		||||
			jwtParse, err := s.token.JwtParse(authStr)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return errors.New("token expired")
 | 
			
		||||
			}
 | 
			
		||||
			_, ok := s.uploading.Load(authStr)
 | 
			
		||||
			if ok {
 | 
			
		||||
				s.uploading.Delete(authStr)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			upload, err := s.store.GetUpload(context.Background(), hook.Upload.ID)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			info, err := upload.GetInfo(context.Background())
 | 
			
		||||
			path, exist := info.Storage["Path"]
 | 
			
		||||
			if err != nil || !exist {
 | 
			
		||||
				return errors.New("file not found")
 | 
			
		||||
			}
 | 
			
		||||
			content, err := os.ReadFile(path)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			hash := sha256.New()
 | 
			
		||||
			hash.Write(content)
 | 
			
		||||
			sha256Byte := hash.Sum(nil)
 | 
			
		||||
			sha256String := fmt.Sprintf("%x", sha256Byte)
 | 
			
		||||
			if !s.config.Debug && sha256String != strings.ToLower(jwtParse.ID) {
 | 
			
		||||
				_ = os.Remove(path)
 | 
			
		||||
				_ = os.Remove(path + ".info")
 | 
			
		||||
				return errors.New("file check error")
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			return nil
 | 
			
		||||
		},
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	go func() {
 | 
			
		||||
		for {
 | 
			
		||||
			select {
 | 
			
		||||
			case event := <-handler.CompleteUploads:
 | 
			
		||||
				authStr := event.HTTPRequest.Header.Get(s.headerTokenKey)
 | 
			
		||||
				jwtParse, _ := s.token.JwtParse(authStr)
 | 
			
		||||
				if s.completedEvent != nil {
 | 
			
		||||
					go func() {
 | 
			
		||||
						s.completedEvent(jwtParse.ID, jwtParse.Subject, event.Upload)
 | 
			
		||||
					}()
 | 
			
		||||
				}
 | 
			
		||||
			case <-s.ctx.Done():
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	go func() {
 | 
			
		||||
		for {
 | 
			
		||||
			select {
 | 
			
		||||
			case event := <-handler.TerminatedUploads:
 | 
			
		||||
				upload, _ := s.store.GetUpload(context.Background(), event.Upload.ID)
 | 
			
		||||
				if upload != nil {
 | 
			
		||||
					info, _ := upload.GetInfo(context.Background())
 | 
			
		||||
					path, exist := info.Storage["Path"]
 | 
			
		||||
					if exist {
 | 
			
		||||
						_ = os.Remove(path)
 | 
			
		||||
						_ = os.Remove(path + ".info")
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			case <-s.ctx.Done():
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	s.checker.Process(func() {
 | 
			
		||||
		s.uploading.Range(func(key, value any) bool {
 | 
			
		||||
			t := value.(time.Time)
 | 
			
		||||
			if t.Before(time.Now()) {
 | 
			
		||||
				s.uploading.Delete(key)
 | 
			
		||||
			}
 | 
			
		||||
			return true
 | 
			
		||||
		})
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	//监听服务
 | 
			
		||||
	addr := s.config.ListenAddr
 | 
			
		||||
	mux := http.NewServeMux()
 | 
			
		||||
	mux.Handle(s.config.Path, http.StripPrefix(s.config.Path, handler))
 | 
			
		||||
	s.httpServer = &http.Server{
 | 
			
		||||
		Addr:    addr,
 | 
			
		||||
		Handler: cors.AllowAll().Handler(mux),
 | 
			
		||||
	}
 | 
			
		||||
	go func() {
 | 
			
		||||
		if err = s.httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
 | 
			
		||||
			s.logger.Sugar().Fatal("upload server startup err", zap.Error(err))
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
	fmt.Println(color.Green(fmt.Sprintf("* [register tusd listen %s]", addr)))
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *server) Stop() error {
 | 
			
		||||
	s.done()
 | 
			
		||||
	return s.httpServer.Close()
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										169
									
								
								pkg/urltable/urltable.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										169
									
								
								pkg/urltable/urltable.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,169 @@
 | 
			
		||||
package urltable
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	empty      = ""
 | 
			
		||||
	fuzzy      = "*"
 | 
			
		||||
	omitted    = "**"
 | 
			
		||||
	delimiter  = "/"
 | 
			
		||||
	methodView = "VIEW"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// parse and validate pattern
 | 
			
		||||
func parse(pattern string) ([]string, error) {
 | 
			
		||||
	const format = "[get, post, put, patch, delete, view]/{a-Z}+/{*}+/{**}"
 | 
			
		||||
 | 
			
		||||
	if pattern = strings.TrimLeft(strings.TrimSpace(pattern), delimiter); pattern == "" {
 | 
			
		||||
		return nil, fmt.Errorf("pattern illegal, should in format of %s", format)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	paths := strings.Split(pattern, delimiter)
 | 
			
		||||
	if len(paths) < 2 {
 | 
			
		||||
		return nil, fmt.Errorf("pattern illegal, should in format of %s", format)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for i := range paths {
 | 
			
		||||
		paths[i] = strings.TrimSpace(paths[i])
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// likes get/ get/* get/**
 | 
			
		||||
	if len(paths) == 2 && (paths[1] == empty || paths[1] == fuzzy || paths[1] == omitted) {
 | 
			
		||||
		return nil, errors.New("illegal wildcard")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	switch paths[0] = strings.ToUpper(paths[0]); paths[0] {
 | 
			
		||||
	case http.MethodGet,
 | 
			
		||||
		http.MethodPost,
 | 
			
		||||
		http.MethodPut,
 | 
			
		||||
		http.MethodPatch,
 | 
			
		||||
		http.MethodDelete,
 | 
			
		||||
		methodView:
 | 
			
		||||
	default:
 | 
			
		||||
		return nil, fmt.Errorf("only supports [%s %s %s %s %s %s]",
 | 
			
		||||
			http.MethodGet, http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete, methodView)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for k := 1; k < len(paths); k++ {
 | 
			
		||||
		if paths[k] == empty && k+1 != len(paths) {
 | 
			
		||||
			return nil, errors.New("pattern contains illegal empty path")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if paths[k] == omitted && k+1 != len(paths) {
 | 
			
		||||
			return nil, errors.New("pattern contains illegal omitted path")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return paths, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Format pattern
 | 
			
		||||
func Format(pattern string) (string, error) {
 | 
			
		||||
	paths, err := parse(pattern)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return strings.Join(paths, delimiter), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type section struct {
 | 
			
		||||
	leaf    bool
 | 
			
		||||
	mapping map[string]*section
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newSection() *section {
 | 
			
		||||
	return §ion{mapping: make(map[string]*section)}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Table a (thread unsafe) table to store urls
 | 
			
		||||
type Table struct {
 | 
			
		||||
	size int
 | 
			
		||||
	root *section
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewTable create a table
 | 
			
		||||
func NewTable() *Table {
 | 
			
		||||
	return &Table{root: newSection()}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Size contains how many urls
 | 
			
		||||
func (t *Table) Size() int {
 | 
			
		||||
	return t.size
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Append pattern
 | 
			
		||||
func (t *Table) Append(pattern string) error {
 | 
			
		||||
	paths, err := parse(pattern)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	insert := false
 | 
			
		||||
	root := t.root
 | 
			
		||||
	for i, path := range paths {
 | 
			
		||||
		if (path == fuzzy && root.mapping[omitted] != nil) ||
 | 
			
		||||
			(path == omitted && root.mapping[fuzzy] != nil) ||
 | 
			
		||||
			(path != omitted && root.mapping[omitted] != nil) {
 | 
			
		||||
			return fmt.Errorf("conflict at %s", strings.Join(paths[:i], delimiter))
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		next := root.mapping[path]
 | 
			
		||||
		if next == nil {
 | 
			
		||||
			next = newSection()
 | 
			
		||||
			root.mapping[path] = next
 | 
			
		||||
			insert = true
 | 
			
		||||
		}
 | 
			
		||||
		root = next
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if insert {
 | 
			
		||||
		t.size++
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	root.leaf = true
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Mapping url to pattern
 | 
			
		||||
func (t *Table) Mapping(url string) (string, error) {
 | 
			
		||||
	paths, err := parse(url)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	pattern := make([]string, 0, len(paths))
 | 
			
		||||
 | 
			
		||||
	root := t.root
 | 
			
		||||
	for _, path := range paths {
 | 
			
		||||
		next := root.mapping[path]
 | 
			
		||||
		if next == nil {
 | 
			
		||||
			nextFuzzy := root.mapping[fuzzy]
 | 
			
		||||
			nextOmitted := root.mapping[omitted]
 | 
			
		||||
			if nextFuzzy == nil && nextOmitted == nil {
 | 
			
		||||
				return "", nil
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if nextOmitted != nil {
 | 
			
		||||
				pattern = append(pattern, omitted)
 | 
			
		||||
				return strings.Join(pattern, delimiter), nil
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			next = nextFuzzy
 | 
			
		||||
			path = fuzzy
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		root = next
 | 
			
		||||
		pattern = append(pattern, path)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if root.leaf {
 | 
			
		||||
		return strings.Join(pattern, delimiter), nil
 | 
			
		||||
	}
 | 
			
		||||
	return "", nil
 | 
			
		||||
}
 | 
			
		||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user