base-golang/pkg/mux/core.go
2024-07-31 16:49:14 +08:00

467 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package mux
import (
"errors"
"fmt"
"net/http"
"net/url"
"runtime/debug"
"time"
"gitea.bvbej.com/bvbej/base-golang/pkg/color"
"gitea.bvbej.com/bvbej/base-golang/pkg/env"
"gitea.bvbej.com/bvbej/base-golang/pkg/errno"
"gitea.bvbej.com/bvbej/base-golang/pkg/limiter"
"gitea.bvbej.com/bvbej/base-golang/pkg/trace"
"gitea.bvbej.com/bvbej/base-golang/pkg/validator"
"github.com/gin-contrib/pprof"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
"github.com/prometheus/client_golang/prometheus/promhttp"
cors "github.com/rs/cors/wrapper/gin"
"go.uber.org/multierr"
"go.uber.org/zap"
"golang.org/x/time/rate"
)
// Option mutates the option set that New assembles before it builds the mux.
type Option func(*option)
// option collects the settings written by Option functions and read by New.
type option struct {
	// enableCors makes New install the allow-all CORS middleware.
	enableCors bool
	// enablePProf makes New register the pprof handlers on the engine.
	enablePProf bool
	// enablePrometheus makes New serve the Prometheus handler at /metrics.
	enablePrometheus bool
	// NOTE(review): enableOpenBrowser is not read anywhere in the visible
	// part of this file; confirm its meaning against its setter.
	enableOpenBrowser string
	// staticDirs lists directories New serves with StaticFS; each entry is
	// used both as the URL prefix and as the filesystem root.
	staticDirs []string
	// panicNotify, when set, is called by New's interceptor after it has
	// recovered from a panic.
	panicNotify OnPanicNotify
	// recordMetrics, when set, is called by New's interceptor for each
	// request that did not end with a 404 status.
	recordMetrics RecordMetrics
	// rateLimiter, when non-nil, is consulted per client IP; rejected
	// requests are aborted with 429 Too Many Requests.
	rateLimiter limiter.RateLimiter
}
// SuccessCode is presumably the result_code value that marks a successful
// response; it is not referenced in the visible part of this file, so
// confirm against its callers.
const SuccessCode = 0
// Failure is the JSON body written by New's interceptor when a request was
// aborted with a business error.
type Failure struct {
	ResultCode int `json:"result_code"` // business code
	ResultInfo string `json:"result_info"` // description of the failure
}
// Success is the JSON envelope for a successful result: a business code plus
// the returned data.
type Success struct {
	ResultCode int `json:"result_code"` // business code
	ResultData any `json:"result_data"` // returned data
}
/******************************************************************************/
// OnPanicNotify is the callback used to send a notification when a panic
// occurs. It receives the request context, the recovered panic value and the
// captured stack trace.
type OnPanicNotify func(ctx Context, err any, stackInfo string)
// RecordMetrics is the callback used to record Prometheus metrics.
// If an alias was configured with AliasForRecordMetrics, uri is that alias
// instead of the request path.
type RecordMetrics func(method, uri string, success bool, costSeconds float64)
// DisableTrace disables the trace chain for the request by delegating to
// ctx.disableTrace. New installs it as the NoMethod and NoRoute handler.
func DisableTrace(ctx Context) {
	ctx.disableTrace()
}
// WithPanicNotify returns an Option that installs notify as the callback
// invoked after a panic has been recovered. Applying the option also prints a
// registration banner to stdout.
func WithPanicNotify(notify OnPanicNotify) Option {
	apply := func(cfg *option) {
		cfg.panicNotify = notify
		banner := color.Green("* [register panic notify]")
		fmt.Println(banner)
	}
	return apply
}
// WithRecordMetrics returns an Option that installs record as the callback
// used to report per-request Prometheus metrics.
func WithRecordMetrics(record RecordMetrics) Option {
	apply := func(cfg *option) {
		cfg.recordMetrics = record
	}
	return apply
}
// WithEnableCors returns an Option that switches CORS handling on. Applying
// the option also prints a registration banner to stdout.
func WithEnableCors() Option {
	apply := func(cfg *option) {
		cfg.enableCors = true
		banner := color.Green("* [register cors]")
		fmt.Println(banner)
	}
	return apply
}
// WithEnableRate returns an Option that enables rate limiting with the given
// limit and burst. The limiter is created each time the option is applied,
// and a registration banner is printed to stdout.
func WithEnableRate(limit rate.Limit, burst int) Option {
	apply := func(cfg *option) {
		rl := limiter.NewRateLimiter(limit, burst)
		cfg.rateLimiter = rl
		banner := color.Green("* [register rate]")
		fmt.Println(banner)
	}
	return apply
}
// WithStaticDir returns an Option that sets the directories New serves as
// static files. The slice is stored as passed in, not copied. Applying the
// option also prints a registration banner to stdout.
func WithStaticDir(dirs []string) Option {
	return func(opt *option) {
		opt.staticDirs = dirs
		// The banner previously read "* [register rate]", copied from
		// WithEnableRate, which made the startup output misleading.
		fmt.Println(color.Green("* [register static dir]"))
	}
}
// AliasForRecordMetrics returns a handler that gives the request URI an alias
// for Prometheus metric recording.
// For a route such as GET /user/:username the username takes a great many
// values, so recording metrics under the raw path is very unfriendly.
func AliasForRecordMetrics(path string) HandlerFunc {
	setAlias := func(c Context) {
		c.setAlias(path)
	}
	return setAlias
}
/******************************************************************************/
// RouterGroup wraps gin's RouterGroup: it can open nested groups and register
// routes through the embedded IRoutes.
type RouterGroup interface {
	// Group opens a sub-group under the given relative path with the given
	// handlers attached.
	Group(string, ...HandlerFunc) RouterGroup
	IRoutes
}
// Compile-time check that *router implements IRoutes.
var _ IRoutes = (*router)(nil)

// IRoutes wraps gin's IRoutes. Each method registers the handlers for the
// given relative path under the HTTP method it is named after; Any is
// forwarded to gin's Any.
type IRoutes interface {
	Any(string, ...HandlerFunc)
	GET(string, ...HandlerFunc)
	POST(string, ...HandlerFunc)
	DELETE(string, ...HandlerFunc)
	PATCH(string, ...HandlerFunc)
	PUT(string, ...HandlerFunc)
	OPTIONS(string, ...HandlerFunc)
	HEAD(string, ...HandlerFunc)
}
// router implements RouterGroup on top of a gin.RouterGroup.
type router struct {
	// group is the gin group that every registration is forwarded to.
	group *gin.RouterGroup
}
// Group opens a nested gin group at relativePath, attaching the handlers
// (adapted by wrapHandlers), and returns it wrapped as a RouterGroup.
func (r *router) Group(relativePath string, handlers ...HandlerFunc) RouterGroup {
	chain := wrapHandlers(handlers...)
	return &router{group: r.group.Group(relativePath, chain...)}
}
// Any forwards the handlers, adapted by wrapHandlers, to the gin group's Any
// registration for relativePath.
func (r *router) Any(relativePath string, handlers ...HandlerFunc) {
	chain := wrapHandlers(handlers...)
	r.group.Any(relativePath, chain...)
}
// GET registers the handlers, adapted by wrapHandlers, for GET requests to
// relativePath.
func (r *router) GET(relativePath string, handlers ...HandlerFunc) {
	chain := wrapHandlers(handlers...)
	r.group.GET(relativePath, chain...)
}
// POST registers the handlers, adapted by wrapHandlers, for POST requests to
// relativePath.
func (r *router) POST(relativePath string, handlers ...HandlerFunc) {
	chain := wrapHandlers(handlers...)
	r.group.POST(relativePath, chain...)
}
// DELETE registers the handlers, adapted by wrapHandlers, for DELETE requests
// to relativePath.
func (r *router) DELETE(relativePath string, handlers ...HandlerFunc) {
	chain := wrapHandlers(handlers...)
	r.group.DELETE(relativePath, chain...)
}
// PATCH registers the handlers, adapted by wrapHandlers, for PATCH requests
// to relativePath.
func (r *router) PATCH(relativePath string, handlers ...HandlerFunc) {
	chain := wrapHandlers(handlers...)
	r.group.PATCH(relativePath, chain...)
}
// PUT registers the handlers, adapted by wrapHandlers, for PUT requests to
// relativePath.
func (r *router) PUT(relativePath string, handlers ...HandlerFunc) {
	chain := wrapHandlers(handlers...)
	r.group.PUT(relativePath, chain...)
}
// OPTIONS registers the handlers, adapted by wrapHandlers, for OPTIONS
// requests to relativePath.
func (r *router) OPTIONS(relativePath string, handlers ...HandlerFunc) {
	chain := wrapHandlers(handlers...)
	r.group.OPTIONS(relativePath, chain...)
}
// HEAD registers the handlers, adapted by wrapHandlers, for HEAD requests to
// relativePath.
func (r *router) HEAD(relativePath string, handlers ...HandlerFunc) {
	chain := wrapHandlers(handlers...)
	r.group.HEAD(relativePath, chain...)
}
// wrapHandlers adapts this package's handlers to gin handlers. Each adapter
// obtains a Context for the gin context through newContext, runs the handler,
// and hands the Context back through releaseContext when the handler returns.
// The result always has exactly len(handlers) entries.
func wrapHandlers(handlers ...HandlerFunc) []gin.HandlerFunc {
	wrapped := make([]gin.HandlerFunc, 0, len(handlers))
	for idx := range handlers {
		// Bind the handler to a per-iteration variable so every closure
		// calls its own handler.
		h := handlers[idx]
		wrapped = append(wrapped, func(c *gin.Context) {
			muxCtx := newContext(c)
			defer releaseContext(muxCtx)
			h(muxCtx)
		})
	}
	return wrapped
}
/******************************************************************************/
// Compile-time check that *mux implements Mux.
var _ Mux = (*mux)(nil)

// Mux is the HTTP entry point returned by New.
type Mux interface {
	http.Handler
	// Group opens a route group at relativePath with the given handlers
	// attached.
	Group(relativePath string, handlers ...HandlerFunc) RouterGroup
	// Routes returns gin's listing of the registered routes.
	Routes() gin.RoutesInfo
	// HandlerFunc registers a raw gin handler at relativePath. The
	// implementation in this file registers it for GET requests only.
	HandlerFunc(relativePath string, handlerFunc gin.HandlerFunc)
}
// mux implements Mux on top of a gin.Engine.
type mux struct {
	// engine is the gin engine that serves requests and holds the routes.
	engine *gin.Engine
}
// ServeHTTP implements http.Handler by delegating to the gin engine.
func (m *mux) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	m.engine.ServeHTTP(w, req)
}
// Group opens a gin group at relativePath on the engine, attaching the
// handlers (adapted by wrapHandlers), and returns it wrapped as a RouterGroup.
func (m *mux) Group(relativePath string, handlers ...HandlerFunc) RouterGroup {
	chain := wrapHandlers(handlers...)
	sub := m.engine.Group(relativePath, chain...)
	return &router{group: sub}
}
// Routes returns the gin engine's listing of registered routes.
func (m *mux) Routes() gin.RoutesInfo {
	return m.engine.Routes()
}
// HandlerFunc registers a raw gin handler at relativePath directly on the
// engine. Despite the generic name it is registered for GET requests only.
func (m *mux) HandlerFunc(relativePath string, handlerFunc gin.HandlerFunc) {
	m.engine.GET(relativePath, handlerFunc)
}
// New builds a Mux backed by a fresh gin engine.
//
// It returns an error when logger is nil. As process-wide side effects it
// switches gin to release mode and replaces gin's binding.Validator with this
// project's validator.
//
// After applying the options it sets things up in this order:
//
//  1. pprof handlers, the Prometheus /metrics endpoint, the allow-all CORS
//     middleware and the static directories, each only if enabled;
//  2. an outer recover middleware that only logs;
//  3. the core interceptor, which attaches a trace, recovers from handler
//     panics, writes the JSON response, reports metrics and emits the
//     "core-interceptor" log entry;
//  4. the optional per-client-IP rate limiter;
//  5. NoMethod/NoRoute handlers and the GET /system/health endpoint.
//
// Routes registered in step 1 are added before the middleware of steps 2-4 is
// installed; gin attaches middleware only to routes registered after Use, so
// those routes are not wrapped by the interceptors.
func New(logger *zap.Logger, options ...Option) (Mux, error) {
	if logger == nil {
		return nil, errors.New("logger required")
	}

	gin.SetMode(gin.ReleaseMode)
	binding.Validator = validator.Validator

	newMux := &mux{
		engine: gin.New(),
	}

	fmt.Println(color.Green(fmt.Sprintf("* [register env %s]", env.Active().Value())))

	// Requests to these paths get no trace attached, so the interceptor
	// below skips trace collection and the access log for them.
	withoutTracePaths := map[string]bool{
		"/metrics":       true,
		"/favicon.ico":   true,
		"/system/health": true,
	}

	opt := new(option)
	for _, f := range options {
		f(opt)
	}

	if opt.enablePProf {
		pprof.Register(newMux.engine)
		fmt.Println(color.Green("* [register pprof]"))
	}

	if opt.enablePrometheus {
		newMux.engine.GET("/metrics", gin.WrapH(promhttp.Handler()))
		fmt.Println(color.Green("* [register prometheus]"))
	}

	if opt.enableCors {
		newMux.engine.Use(cors.AllowAll())
	}

	// Ranging over a nil slice is a no-op, so no nil check is needed.
	for _, dir := range opt.staticDirs {
		newMux.engine.StaticFS(dir, gin.Dir(dir, false))
	}

	// Recovery is installed twice: this outer layer catches panics raised
	// while the interceptor below is itself handling a panic, in particular
	// inside OnPanicNotify.
	newMux.engine.Use(func(ctx *gin.Context) {
		defer func() {
			if err := recover(); err != nil {
				logger.Error("got panic", zap.String("panic", fmt.Sprintf("%+v", err)), zap.String("stack", string(debug.Stack())))
			}
		}()

		ctx.Next()
	})

	newMux.engine.Use(func(ctx *gin.Context) {
		ts := time.Now()

		newCtx := newContext(ctx)
		// Deferred first so it runs last, after the response/logging defer.
		defer releaseContext(newCtx)

		newCtx.init()
		newCtx.setLogger(logger)

		if !withoutTracePaths[ctx.Request.URL.Path] {
			// The incoming trace header value is passed through as-is; an
			// absent header yields "", exactly what was passed before.
			newCtx.setTrace(trace.New(newCtx.GetHeader(trace.Header)))
		}

		defer func() {
			if err := recover(); err != nil {
				stackInfo := string(debug.Stack())
				logger.Error("got panic", zap.String("panic", fmt.Sprintf("%+v", err)), zap.String("stack", stackInfo))
				newCtx.AbortWithError(errno.NewError(
					http.StatusInternalServerError,
					http.StatusInternalServerError,
					http.StatusText(http.StatusInternalServerError)),
				)

				if notify := opt.panicNotify; notify != nil {
					notify(newCtx, err, stackInfo)
				}
			}

			// 404 responses are neither rewritten, measured nor logged.
			if ctx.Writer.Status() == http.StatusNotFound {
				return
			}

			var (
				response        any
				businessCode    int
				businessCodeMsg string
				abortErr        error
				graphResponse   any
			)

			if ctx.IsAborted() {
				for i := range ctx.Errors { // errors recorded by gin
					multierr.AppendInto(&abortErr, ctx.Errors[i])
				}

				if err := newCtx.abortError(); err != nil { // business error set by the handler
					multierr.AppendInto(&abortErr, err.GetErr())
					response = err
					businessCode = err.GetBusinessCode()
					businessCodeMsg = err.GetMsg()

					if x := newCtx.Trace(); x != nil {
						newCtx.SetHeader(trace.Header, x.ID())
					}

					ctx.JSON(err.GetHttpCode(), &Failure{
						ResultCode: businessCode,
						ResultInfo: businessCodeMsg,
					})
				}
			} else {
				response = newCtx.getPayload()
				if response != nil {
					if x := newCtx.Trace(); x != nil {
						newCtx.SetHeader(trace.Header, x.ID())
					}
					ctx.JSON(http.StatusOK, response)
				}
			}

			graphResponse = newCtx.getGraphPayload()

			if opt.recordMetrics != nil {
				uri := newCtx.Path()
				if alias := newCtx.Alias(); alias != "" {
					uri = alias
				}

				opt.recordMetrics(
					newCtx.Method(),
					uri,
					!ctx.IsAborted() && ctx.Writer.Status() == http.StatusOK,
					time.Since(ts).Seconds(),
				)
			}

			// Without a trace there is nothing to enrich or log.
			x := newCtx.Trace()
			if x == nil {
				return
			}
			// Checked assertion: an unexpected implementation must not panic
			// inside this deferred function.
			t, ok := x.(*trace.Trace)
			if !ok || t == nil {
				return
			}

			// Fall back to the raw URI when it contains a malformed escape,
			// rather than logging an empty path.
			rawURI := ctx.Request.URL.RequestURI()
			decodedURL, err := url.QueryUnescape(rawURI)
			if err != nil {
				decodedURL = rawURI
			}

			t.WithRequest(&trace.Request{
				TTL:        "un-limit",
				Method:     ctx.Request.Method,
				DecodedURL: decodedURL,
				Header:     ctx.Request.Header,
				Body:       string(newCtx.RawData()),
			})

			// A graph payload, when present, takes precedence over the
			// regular response in the trace.
			responseBody := response
			if graphResponse != nil {
				responseBody = graphResponse
			}

			t.WithResponse(&trace.Response{
				Header:          ctx.Writer.Header(),
				HttpCode:        ctx.Writer.Status(),
				HttpCodeMsg:     http.StatusText(ctx.Writer.Status()),
				BusinessCode:    businessCode,
				BusinessCodeMsg: businessCodeMsg,
				Body:            responseBody,
				CostSeconds:     time.Since(ts).Seconds(),
			})

			t.Success = !ctx.IsAborted() && ctx.Writer.Status() == http.StatusOK
			t.CostSeconds = time.Since(ts).Seconds()

			logger.Info("core-interceptor",
				zap.Any("method", ctx.Request.Method),
				zap.Any("path", decodedURL),
				zap.Any("http_code", ctx.Writer.Status()),
				zap.Any("business_code", businessCode),
				zap.Any("success", t.Success),
				zap.Any("cost_seconds", t.CostSeconds),
				zap.Any("trace_id", t.Identifier),
				zap.Any("trace_info", t),
				zap.Error(abortErr),
			)
		}()

		ctx.Next()
	})

	if opt.rateLimiter != nil {
		// Installed after the interceptor, so rejected requests still go
		// through its response writing and logging.
		newMux.engine.Use(func(ctx *gin.Context) {
			newCtx := newContext(ctx)
			defer releaseContext(newCtx)

			if !opt.rateLimiter.Allow(ctx.ClientIP()) {
				newCtx.AbortWithError(errno.NewError(
					http.StatusTooManyRequests,
					http.StatusTooManyRequests,
					http.StatusText(http.StatusTooManyRequests)),
				)
				return
			}

			ctx.Next()
		})
	}

	newMux.engine.NoMethod(wrapHandlers(DisableTrace)...)
	newMux.engine.NoRoute(wrapHandlers(DisableTrace)...)

	system := newMux.Group("/system")
	{
		// Health check.
		system.GET("/health", func(ctx Context) {
			resp := &struct {
				Timestamp   time.Time `json:"timestamp"`
				Environment string    `json:"environment"`
				Host        string    `json:"host"`
				Status      string    `json:"status"`
			}{
				Timestamp:   time.Now(),
				Environment: env.Active().Value(),
				Host:        ctx.Host(),
				Status:      "ok",
			}
			ctx.Payload(resp)
		})
	}

	return newMux, nil
}