base-golang/pkg/mux/core.go

467 lines
12 KiB
Go
Raw Normal View History

2024-07-23 10:23:43 +08:00
package mux
import (
"errors"
"fmt"
"net/http"
"net/url"
"runtime/debug"
"time"
"git.bvbej.com/bvbej/base-golang/pkg/color"
"git.bvbej.com/bvbej/base-golang/pkg/env"
"git.bvbej.com/bvbej/base-golang/pkg/errno"
"git.bvbej.com/bvbej/base-golang/pkg/limiter"
"git.bvbej.com/bvbej/base-golang/pkg/trace"
"git.bvbej.com/bvbej/base-golang/pkg/validator"
"github.com/gin-contrib/pprof"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
"github.com/prometheus/client_golang/prometheus/promhttp"
cors "github.com/rs/cors/wrapper/gin"
"go.uber.org/multierr"
"go.uber.org/zap"
"golang.org/x/time/rate"
)
type (
	// Option configures optional behavior of New. Options are applied in the
	// order they are passed, so a later Option overwrites an earlier one that
	// sets the same field.
	Option func(*option)

	// option collects the settings produced by Options before New builds the
	// gin engine.
	option struct {
		// enableCors installs cors.AllowAll() middleware.
		enableCors bool
		// enablePProf registers gin's pprof routes.
		// NOTE(review): no Option in this file sets this field.
		enablePProf bool
		// enablePrometheus exposes GET /metrics through promhttp.Handler().
		// NOTE(review): no Option in this file sets this field.
		enablePrometheus bool
		// NOTE(review): neither set nor read anywhere in this file.
		enableOpenBrowser string
		// staticDirs are mounted with gin's StaticFS, each at its own path.
		staticDirs []string
		// panicNotify is called after a handler panic has been recovered.
		panicNotify OnPanicNotify
		// recordMetrics receives one observation per request.
		recordMetrics RecordMetrics
		// rateLimiter is keyed by client IP; nil disables rate limiting.
		rateLimiter limiter.RateLimiter
	}
)
// SuccessCode is presumably the result_code value for a successful response.
const SuccessCode = 0

type (
	// Failure is the JSON body written when a request is aborted with a
	// business error.
	Failure struct {
		ResultCode int    `json:"result_code"` // business code
		ResultInfo string `json:"result_info"` // human-readable description
	}

	// Success is a JSON envelope carrying a business code and result data.
	Success struct {
		ResultCode int `json:"result_code"` // business code
		ResultData any `json:"result_data"` // returned data
	}
)
/******************************************************************************/
type (
	// OnPanicNotify is called after a panic raised while handling a request
	// has been recovered and logged. err is the recovered panic value and
	// stackInfo is the goroutine stack captured at recovery time.
	OnPanicNotify func(ctx Context, err any, stackInfo string)

	// RecordMetrics receives one observation per request for Prometheus
	// metrics: the HTTP method, the request URI, whether the request
	// succeeded, and its duration in seconds. When an alias was configured
	// with AliasForRecordMetrics, uri is that alias instead of the real URI.
	RecordMetrics func(method, uri string, success bool, costSeconds float64)
)
// DisableTrace disables the trace chain for the current request. Its
// signature matches HandlerFunc, so it can be registered directly as a
// handler; New uses it for NoRoute and NoMethod requests.
func DisableTrace(ctx Context) {
	ctx.disableTrace()
}

// WithPanicNotify sets the callback invoked after a panic in request
// handling has been recovered.
func WithPanicNotify(notify OnPanicNotify) Option {
	return func(o *option) {
		o.panicNotify = notify
		fmt.Println(color.Green("* [register panic notify]"))
	}
}

// WithRecordMetrics sets the callback used to record Prometheus metrics for
// each request.
func WithRecordMetrics(record RecordMetrics) Option {
	return func(o *option) {
		o.recordMetrics = record
	}
}

// WithEnableCors enables permissive CORS handling (cors.AllowAll()).
func WithEnableCors() Option {
	return func(o *option) {
		o.enableCors = true
		fmt.Println(color.Green("* [register cors]"))
	}
}

// WithEnableRate enables rate limiting keyed by client IP. The limiter is
// built with limiter.NewRateLimiter(limit, burst) each time the Option is
// applied. Rejected requests are aborted with HTTP 429.
func WithEnableRate(limit rate.Limit, burst int) Option {
	return func(o *option) {
		o.rateLimiter = limiter.NewRateLimiter(limit, burst)
		fmt.Println(color.Green("* [register rate]"))
	}
}
// WithStaticDir serves each directory in dirs as static files. New mounts
// every entry with gin's StaticFS at a URL path equal to the directory path
// itself, with directory listing disabled.
//
// dirs is copied when WithStaticDir is called, so later modifications of the
// caller's slice do not change which directories get registered.
func WithStaticDir(dirs []string) Option {
	dirs = append([]string(nil), dirs...)
	return func(opt *option) {
		opt.staticDirs = dirs
		// Previously printed "* [register rate]" (copy-paste from WithEnableRate).
		fmt.Println(color.Green("* [register static dirs]"))
	}
}
// AliasForRecordMetrics returns a handler that sets an alias for the request
// URI. When metrics are recorded, the alias replaces the real URI.
//
// Use it for routes with path parameters such as GET /user/:username. Without
// an alias, every distinct username would be recorded as its own URI, which
// makes the Prometheus metrics very unwieldy.
func AliasForRecordMetrics(path string) HandlerFunc {
	return func(ctx Context) { ctx.setAlias(path) }
}
/******************************************************************************/
// RouterGroup wraps gin's RouterGroup so that groups and routes are declared
// with this package's HandlerFunc instead of gin.HandlerFunc.
type RouterGroup interface {
	// Group creates a child group under relativePath; handlers act as
	// middleware for every route later registered on the child.
	Group(string, ...HandlerFunc) RouterGroup
	IRoutes
}

// Compile-time check: *router implements RouterGroup, and therefore IRoutes.
var _ RouterGroup = (*router)(nil)

// IRoutes wraps gin's IRoutes: one registration method per HTTP verb, plus
// Any for all verbs gin's Any covers.
type IRoutes interface {
	Any(string, ...HandlerFunc)
	GET(string, ...HandlerFunc)
	POST(string, ...HandlerFunc)
	DELETE(string, ...HandlerFunc)
	PATCH(string, ...HandlerFunc)
	PUT(string, ...HandlerFunc)
	OPTIONS(string, ...HandlerFunc)
	HEAD(string, ...HandlerFunc)
}

// router adapts a *gin.RouterGroup to RouterGroup. Every HandlerFunc is
// converted with wrapHandlers before being handed to gin.
type router struct {
	group *gin.RouterGroup
}

// Group implements RouterGroup.
func (r *router) Group(relativePath string, handlers ...HandlerFunc) RouterGroup {
	return &router{group: r.group.Group(relativePath, wrapHandlers(handlers...)...)}
}

// Any registers handlers for every method supported by gin's Any.
func (r *router) Any(relativePath string, handlers ...HandlerFunc) {
	r.group.Any(relativePath, wrapHandlers(handlers...)...)
}

// GET registers handlers for GET requests on relativePath.
func (r *router) GET(relativePath string, handlers ...HandlerFunc) {
	r.group.Handle(http.MethodGet, relativePath, wrapHandlers(handlers...)...)
}

// POST registers handlers for POST requests on relativePath.
func (r *router) POST(relativePath string, handlers ...HandlerFunc) {
	r.group.Handle(http.MethodPost, relativePath, wrapHandlers(handlers...)...)
}

// DELETE registers handlers for DELETE requests on relativePath.
func (r *router) DELETE(relativePath string, handlers ...HandlerFunc) {
	r.group.Handle(http.MethodDelete, relativePath, wrapHandlers(handlers...)...)
}

// PATCH registers handlers for PATCH requests on relativePath.
func (r *router) PATCH(relativePath string, handlers ...HandlerFunc) {
	r.group.Handle(http.MethodPatch, relativePath, wrapHandlers(handlers...)...)
}

// PUT registers handlers for PUT requests on relativePath.
func (r *router) PUT(relativePath string, handlers ...HandlerFunc) {
	r.group.Handle(http.MethodPut, relativePath, wrapHandlers(handlers...)...)
}

// OPTIONS registers handlers for OPTIONS requests on relativePath.
func (r *router) OPTIONS(relativePath string, handlers ...HandlerFunc) {
	r.group.Handle(http.MethodOptions, relativePath, wrapHandlers(handlers...)...)
}

// HEAD registers handlers for HEAD requests on relativePath.
func (r *router) HEAD(relativePath string, handlers ...HandlerFunc) {
	r.group.Handle(http.MethodHead, relativePath, wrapHandlers(handlers...)...)
}
// wrapHandlers adapts package HandlerFuncs to gin.HandlerFuncs, preserving
// their order. Each adapted handler obtains a Context for the gin request
// via newContext and runs the original handler with it. It then passes the
// Context to releaseContext; this is deferred, so it also runs if the
// handler panics.
func wrapHandlers(handlers ...HandlerFunc) []gin.HandlerFunc {
	wrapped := make([]gin.HandlerFunc, 0, len(handlers))
	for _, h := range handlers {
		h := h // per-iteration copy for the closure (required before Go 1.22)
		wrapped = append(wrapped, func(c *gin.Context) {
			ctx := newContext(c)
			defer releaseContext(ctx)
			h(ctx)
		})
	}
	return wrapped
}
/******************************************************************************/
// Compile-time check that *mux implements Mux.
var _ Mux = (*mux)(nil)

// Mux is the router returned by New. It serves HTTP requests and lets
// callers add route groups, list registered routes, and mount raw gin
// handlers.
type Mux interface {
	http.Handler
	// Group creates a route group under relativePath with handlers as its
	// middleware.
	Group(relativePath string, handlers ...HandlerFunc) RouterGroup
	// Routes lists the routes registered on the underlying gin engine.
	Routes() gin.RoutesInfo
	// HandlerFunc mounts a plain gin handler at relativePath, for GET only.
	HandlerFunc(relativePath string, handlerFunc gin.HandlerFunc)
}

// mux implements Mux on top of a *gin.Engine.
type mux struct {
	engine *gin.Engine
}

// ServeHTTP implements http.Handler by delegating to the gin engine.
func (m *mux) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	m.engine.ServeHTTP(w, req)
}

// Group implements Mux.
func (m *mux) Group(relativePath string, handlers ...HandlerFunc) RouterGroup {
	g := m.engine.Group(relativePath, wrapHandlers(handlers...)...)
	return &router{group: g}
}

// Routes implements Mux.
func (m *mux) Routes() gin.RoutesInfo {
	routes := m.engine.Routes()
	return routes
}

// HandlerFunc implements Mux; the handler is registered for GET requests only.
func (m *mux) HandlerFunc(relativePath string, handlerFunc gin.HandlerFunc) {
	m.engine.Handle(http.MethodGet, relativePath, handlerFunc)
}
// New creates a Mux backed by a new gin engine in release mode, and installs
// validator.Validator as gin's binding validator.
//
// logger is required; a nil logger yields an error. Options are applied in
// order (nil Options are skipped) before any route or middleware is set up.
//
// Registration order matters because gin attaches the middleware chain that
// exists at the time a route is registered:
//   - pprof and /metrics routes, when enabled, come first and therefore
//     bypass every middleware below, including CORS;
//   - CORS, when enabled, is installed next, followed by static directories,
//     which get CORS but none of the middleware below;
//   - a log-only recover, the core interceptor and the optional rate limiter
//     then apply to everything registered afterwards. That covers
//     /system/health, groups created through the returned Mux, and the
//     NoRoute/NoMethod handlers.
//
// The core interceptor does the following for each request:
//   - creates a Context and attaches a trace built from the incoming
//     trace.Header value (empty when absent);
//   - recovers handler panics into a 500 error;
//   - renders either the abort error as a Failure or the payload as JSON;
//   - reports metrics and writes one access-log line per traced request.
func New(logger *zap.Logger, options ...Option) (Mux, error) {
	if logger == nil {
		return nil, errors.New("logger required")
	}

	gin.SetMode(gin.ReleaseMode)
	binding.Validator = validator.Validator

	newMux := &mux{
		engine: gin.New(),
	}

	fmt.Println(color.Green(fmt.Sprintf("* [register env %s]", env.Active().Value())))

	// withoutTracePaths get no trace. Because the access log below is only
	// written for traced requests, they are not logged either.
	withoutTracePaths := map[string]bool{
		"/metrics":       true,
		"/favicon.ico":   true,
		"/system/health": true,
	}

	opt := new(option)
	for _, apply := range options {
		if apply != nil { // a nil Option would otherwise panic here
			apply(opt)
		}
	}

	if opt.enablePProf {
		pprof.Register(newMux.engine)
		fmt.Println(color.Green("* [register pprof]"))
	}

	if opt.enablePrometheus {
		newMux.engine.GET("/metrics", gin.WrapH(promhttp.Handler()))
		fmt.Println(color.Green("* [register prometheus]"))
	}

	if opt.enableCors {
		newMux.engine.Use(cors.AllowAll())
	}

	// Ranging over a nil slice is a no-op, so no separate nil check is needed.
	for _, dir := range opt.staticDirs {
		newMux.engine.StaticFS(dir, gin.Dir(dir, false))
	}

	// Outermost, log-only recover. The core interceptor below has its own
	// recover; this second layer catches panics raised while that one is
	// handling a request, in particular inside OnPanicNotify.
	newMux.engine.Use(func(ctx *gin.Context) {
		defer func() {
			if err := recover(); err != nil {
				logger.Error("got panic", zap.String("panic", fmt.Sprintf("%+v", err)), zap.String("stack", string(debug.Stack())))
			}
		}()
		ctx.Next()
	})

	// Core interceptor.
	newMux.engine.Use(func(ctx *gin.Context) {
		ts := time.Now()

		newCtx := newContext(ctx)
		defer releaseContext(newCtx)

		newCtx.init()
		newCtx.setLogger(logger)

		if !withoutTracePaths[ctx.Request.URL.Path] {
			// An empty header value is passed through as trace.New("").
			newCtx.setTrace(trace.New(newCtx.GetHeader(trace.Header)))
		}

		defer func() {
			if err := recover(); err != nil {
				stackInfo := string(debug.Stack())
				logger.Error("got panic", zap.String("panic", fmt.Sprintf("%+v", err)), zap.String("stack", stackInfo))
				newCtx.AbortWithError(errno.NewError(
					http.StatusInternalServerError,
					http.StatusInternalServerError,
					http.StatusText(http.StatusInternalServerError)),
				)
				if notify := opt.panicNotify; notify != nil {
					notify(newCtx, err, stackInfo)
				}
			}

			// Requests that reach this point with status 404 (gin pre-sets it
			// for unmatched routes) are neither rendered, measured nor logged.
			if ctx.Writer.Status() == http.StatusNotFound {
				return
			}

			var (
				response        any
				businessCode    int
				businessCodeMsg string
				abortErr        error
			)

			if ctx.IsAborted() {
				// Collect the errors recorded on the gin context...
				for i := range ctx.Errors {
					multierr.AppendInto(&abortErr, ctx.Errors[i])
				}
				// ...plus the business error set via AbortWithError, which is
				// also what gets rendered to the client.
				if err := newCtx.abortError(); err != nil {
					multierr.AppendInto(&abortErr, err.GetErr())
					response = err
					businessCode = err.GetBusinessCode()
					businessCodeMsg = err.GetMsg()
					if x := newCtx.Trace(); x != nil {
						newCtx.SetHeader(trace.Header, x.ID())
					}
					ctx.JSON(err.GetHttpCode(), &Failure{
						ResultCode: businessCode,
						ResultInfo: businessCodeMsg,
					})
				}
			} else {
				response = newCtx.getPayload()
				if response != nil {
					if x := newCtx.Trace(); x != nil {
						newCtx.SetHeader(trace.Header, x.ID())
					}
					ctx.JSON(http.StatusOK, response)
				}
			}

			graphResponse := newCtx.getGraphPayload()

			// Measure once so metrics, trace and access log agree.
			success := !ctx.IsAborted() && ctx.Writer.Status() == http.StatusOK
			costSeconds := time.Since(ts).Seconds()

			if opt.recordMetrics != nil {
				uri := newCtx.Path()
				if alias := newCtx.Alias(); alias != "" {
					uri = alias
				}
				opt.recordMetrics(newCtx.Method(), uri, success, costSeconds)
			}

			x := newCtx.Trace()
			if x == nil {
				return // no trace attached to this request: skip trace + access log
			}
			// Comma-ok instead of a bare assertion: a panic here would abort the
			// access log and surface only as a generic "got panic".
			t, ok := x.(*trace.Trace)
			if !ok || t == nil {
				logger.Warn("core-interceptor: unexpected trace value", zap.String("type", fmt.Sprintf("%T", x)))
				return
			}

			// Log the unescaped URI for readability. Fall back to the raw form
			// when it holds malformed escapes, instead of logging "".
			rawURI := ctx.Request.URL.RequestURI()
			decodedURL, unescapeErr := url.QueryUnescape(rawURI)
			if unescapeErr != nil {
				decodedURL = rawURI
			}

			t.WithRequest(&trace.Request{
				TTL:        "un-limit",
				Method:     ctx.Request.Method,
				DecodedURL: decodedURL,
				Header:     ctx.Request.Header,
				Body:       string(newCtx.RawData()),
			})

			// The graph payload, when set, takes precedence in the trace body.
			responseBody := response
			if graphResponse != nil {
				responseBody = graphResponse
			}

			httpCode := ctx.Writer.Status()
			t.WithResponse(&trace.Response{
				Header:          ctx.Writer.Header(),
				HttpCode:        httpCode,
				HttpCodeMsg:     http.StatusText(httpCode),
				BusinessCode:    businessCode,
				BusinessCodeMsg: businessCodeMsg,
				Body:            responseBody,
				CostSeconds:     costSeconds,
			})
			t.Success = success
			t.CostSeconds = costSeconds

			logger.Info("core-interceptor",
				zap.String("method", ctx.Request.Method),
				zap.String("path", decodedURL),
				zap.Int("http_code", httpCode),
				zap.Int("business_code", businessCode),
				zap.Bool("success", success),
				zap.Float64("cost_seconds", costSeconds),
				zap.Any("trace_id", t.Identifier),
				zap.Any("trace_info", t),
				zap.Error(abortErr),
			)
		}()

		ctx.Next()
	})

	if opt.rateLimiter != nil {
		// Installed after the core interceptor, so the interceptor's deferred
		// handler sees the 429 abort error and renders and logs it.
		newMux.engine.Use(func(ctx *gin.Context) {
			newCtx := newContext(ctx)
			defer releaseContext(newCtx)
			if !opt.rateLimiter.Allow(ctx.ClientIP()) {
				newCtx.AbortWithError(errno.NewError(
					http.StatusTooManyRequests,
					http.StatusTooManyRequests,
					http.StatusText(http.StatusTooManyRequests)),
				)
				return
			}
			ctx.Next()
		})
	}

	newMux.engine.NoMethod(wrapHandlers(DisableTrace)...)
	newMux.engine.NoRoute(wrapHandlers(DisableTrace)...)

	system := newMux.Group("/system")
	{
		// Health check.
		system.GET("/health", func(ctx Context) {
			resp := &struct {
				Timestamp   time.Time `json:"timestamp"`
				Environment string    `json:"environment"`
				Host        string    `json:"host"`
				Status      string    `json:"status"`
			}{
				Timestamp:   time.Now(),
				Environment: env.Active().Value(),
				Host:        ctx.Host(),
				Status:      "ok",
			}
			ctx.Payload(resp)
		})
	}

	return newMux, nil
}