2024-07-23 10:23:43 +08:00
|
|
|
|
package mux
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"bytes"
|
|
|
|
|
stdCtx "context"
|
|
|
|
|
"io"
|
|
|
|
|
"net/http"
|
|
|
|
|
"net/url"
|
|
|
|
|
"strings"
|
|
|
|
|
"sync"
|
|
|
|
|
|
2024-07-31 16:49:14 +08:00
|
|
|
|
"gitea.bvbej.com/bvbej/base-golang/pkg/errno"
|
|
|
|
|
"gitea.bvbej.com/bvbej/base-golang/pkg/trace"
|
2024-07-23 10:23:43 +08:00
|
|
|
|
|
|
|
|
|
"github.com/gin-gonic/gin"
|
|
|
|
|
"github.com/gin-gonic/gin/binding"
|
|
|
|
|
"go.uber.org/zap"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
// HandlerFunc is the handler signature used by this package: it receives the
// package's Context wrapper rather than a raw *gin.Context.
type HandlerFunc func(c Context)
|
|
|
|
|
|
|
|
|
|
// Trace is an alias for trace.T, re-exported so callers can refer to it as mux.Trace.
type Trace = trace.T
|
|
|
|
|
|
|
|
|
|
// Keys under which request-scoped values are stored in the wrapped gin.Context.
const (
	_Alias            = "_alias_"            // route alias used for metrics (setAlias / Alias)
	_TraceName        = "_trace_"            // request Trace (setTrace / Trace)
	_LoggerName       = "_logger_"           // request *zap.Logger (setLogger / Logger)
	_BodyName         = "_body_"             // raw request body cached by init (RawData)
	_PayloadName      = "_payload_"          // success payload (Payload / getPayload)
	_GraphPayloadName = "_graph_payload_"    // GraphQL payload (GraphPayload / getGraphPayload)
	_AbortErrorName   = "_abort_error_"      // errno.Error recorded by AbortWithError
	_Auth             = "_auth_"             // auth value (SetAuth / Auth)
)
|
|
|
|
|
|
|
|
|
|
var contextPool = &sync.Pool{
|
|
|
|
|
New: func() any {
|
|
|
|
|
return new(context)
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func newContext(ctx *gin.Context) Context {
|
|
|
|
|
getContext := contextPool.Get().(*context)
|
|
|
|
|
getContext.ctx = ctx
|
|
|
|
|
return getContext
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func releaseContext(ctx Context) {
|
|
|
|
|
c := ctx.(*context)
|
|
|
|
|
c.ctx = nil
|
|
|
|
|
contextPool.Put(c)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Compile-time assertion that *context implements Context.
var _ Context = (*context)(nil)
|
|
|
|
|
|
|
|
|
|
type Context interface {
|
|
|
|
|
init()
|
|
|
|
|
|
|
|
|
|
Context() *gin.Context
|
|
|
|
|
|
|
|
|
|
// ShouldBindQuery 反序列化 query
|
|
|
|
|
// tag: `form:"xxx"` (注:不要写成 query)
|
|
|
|
|
ShouldBindQuery(obj any) error
|
|
|
|
|
|
|
|
|
|
// ShouldBindPostForm 反序列化 x-www-from-urlencoded
|
|
|
|
|
// tag: `form:"xxx"`
|
|
|
|
|
ShouldBindPostForm(obj any) error
|
|
|
|
|
|
|
|
|
|
// ShouldBindForm 同时反序列化 form-data;
|
|
|
|
|
// tag: `form:"xxx"`
|
|
|
|
|
ShouldBindForm(obj any) error
|
|
|
|
|
|
|
|
|
|
// ShouldBindJSON 反序列化 post-json
|
|
|
|
|
// tag: `json:"xxx"`
|
|
|
|
|
ShouldBindJSON(obj any) error
|
|
|
|
|
|
|
|
|
|
// ShouldBindURI 反序列化 path 参数(如路由路径为 /user/:name)
|
|
|
|
|
// tag: `uri:"xxx"`
|
|
|
|
|
ShouldBindURI(obj any) error
|
|
|
|
|
|
|
|
|
|
// Redirect 重定向
|
|
|
|
|
Redirect(code int, location string)
|
|
|
|
|
|
|
|
|
|
// Trace 获取 Trace 对象
|
|
|
|
|
Trace() Trace
|
|
|
|
|
setTrace(trace Trace)
|
|
|
|
|
disableTrace()
|
|
|
|
|
|
|
|
|
|
// Logger 获取 Logger 对象
|
|
|
|
|
Logger() *zap.Logger
|
|
|
|
|
setLogger(logger *zap.Logger)
|
|
|
|
|
|
|
|
|
|
// Payload 正确返回
|
|
|
|
|
Payload(payload any)
|
|
|
|
|
getPayload() any
|
|
|
|
|
|
|
|
|
|
// GraphPayload GraphQL返回值 与 api 返回结构不同
|
|
|
|
|
GraphPayload(payload any)
|
|
|
|
|
getGraphPayload() any
|
|
|
|
|
|
|
|
|
|
// HTML 返回界面
|
|
|
|
|
HTML(name string, obj any)
|
|
|
|
|
|
|
|
|
|
// AbortWithError 错误返回
|
|
|
|
|
AbortWithError(err errno.Error)
|
|
|
|
|
abortError() errno.Error
|
|
|
|
|
|
|
|
|
|
// Header 获取 Header 对象
|
|
|
|
|
Header() http.Header
|
|
|
|
|
// GetHeader 获取 Header
|
|
|
|
|
GetHeader(key string) string
|
|
|
|
|
// SetHeader 设置 Header
|
|
|
|
|
SetHeader(key, value string)
|
|
|
|
|
|
|
|
|
|
Auth() any
|
|
|
|
|
SetAuth(auth any)
|
|
|
|
|
|
|
|
|
|
// Authorization 获取请求认证信息
|
|
|
|
|
Authorization() string
|
|
|
|
|
|
|
|
|
|
// Platform 平台标识
|
|
|
|
|
Platform() string
|
|
|
|
|
|
|
|
|
|
// Alias 设置路由别名 for metrics uri
|
|
|
|
|
Alias() string
|
|
|
|
|
setAlias(path string)
|
|
|
|
|
|
|
|
|
|
// RequestInputParams 获取所有参数
|
|
|
|
|
RequestInputParams() url.Values
|
|
|
|
|
// RequestQueryParams 获取 Query 参数
|
|
|
|
|
RequestQueryParams() url.Values
|
|
|
|
|
// RequestPostFormParams 获取 PostForm 参数
|
|
|
|
|
RequestPostFormParams() url.Values
|
|
|
|
|
// Request 获取 Request 对象
|
|
|
|
|
Request() *http.Request
|
|
|
|
|
// RawData 获取 Request.Body
|
|
|
|
|
RawData() []byte
|
|
|
|
|
// Method 获取 Request.Method
|
|
|
|
|
Method() string
|
|
|
|
|
// Host 获取 Request.Host
|
|
|
|
|
Host() string
|
|
|
|
|
// Path 获取 请求的路径 Request.URL.Path (不附带 querystring)
|
|
|
|
|
Path() string
|
|
|
|
|
// URI 获取 unescape 后的 Request.URL.RequestURI()
|
|
|
|
|
URI() string
|
|
|
|
|
// RequestContext 获取请求的 context (当 client 关闭后,会自动 canceled)
|
|
|
|
|
RequestContext() StdContext
|
|
|
|
|
// ResponseWriter 获取 ResponseWriter 对象
|
|
|
|
|
ResponseWriter() gin.ResponseWriter
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
type context struct {
|
|
|
|
|
ctx *gin.Context
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
type StdContext struct {
|
|
|
|
|
stdCtx.Context
|
|
|
|
|
Trace
|
|
|
|
|
*zap.Logger
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// init reads the entire request body, caches it under _BodyName (so tracing and
// RawData can use it), and replaces Request.Body with an in-memory reader over the
// same bytes so later handlers can still read the body.
// It panics if the body cannot be read.
func (c *context) init() {
	raw, readErr := c.ctx.GetRawData()
	if readErr != nil {
		panic(readErr)
	}

	c.ctx.Set(_BodyName, raw)
	// The original body stream is now drained; give downstream readers a fresh copy.
	c.ctx.Request.Body = io.NopCloser(bytes.NewReader(raw))
}
|
|
|
|
|
|
|
|
|
|
func (c *context) Context() *gin.Context {
|
|
|
|
|
return c.ctx
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ShouldBindQuery 反序列化querystring
|
|
|
|
|
// tag: `form:"xxx"` (注:不要写成query)
|
|
|
|
|
func (c *context) ShouldBindQuery(obj any) error {
|
|
|
|
|
return c.ctx.ShouldBindWith(obj, binding.Query)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ShouldBindPostForm 反序列化 postform (querystring 会被忽略)
|
|
|
|
|
// tag: `form:"xxx"`
|
|
|
|
|
func (c *context) ShouldBindPostForm(obj any) error {
|
|
|
|
|
return c.ctx.ShouldBindWith(obj, binding.FormPost)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ShouldBindForm 同时反序列化querystring和postform;
|
|
|
|
|
// 当querystring和postform存在相同字段时,postform优先使用。
|
|
|
|
|
// tag: `form:"xxx"`
|
|
|
|
|
func (c *context) ShouldBindForm(obj any) error {
|
|
|
|
|
return c.ctx.ShouldBindWith(obj, binding.Form)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ShouldBindJSON 反序列化postjson
|
|
|
|
|
// tag: `json:"xxx"`
|
|
|
|
|
func (c *context) ShouldBindJSON(obj any) error {
|
|
|
|
|
return c.ctx.ShouldBindWith(obj, binding.JSON)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ShouldBindURI 反序列化path参数(如路由路径为 /user/:name)
|
|
|
|
|
// tag: `uri:"xxx"`
|
|
|
|
|
func (c *context) ShouldBindURI(obj any) error {
|
|
|
|
|
return c.ctx.ShouldBindUri(obj)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Redirect 重定向
|
|
|
|
|
func (c *context) Redirect(code int, location string) {
|
|
|
|
|
c.ctx.Redirect(code, location)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (c *context) Trace() Trace {
|
|
|
|
|
t, ok := c.ctx.Get(_TraceName)
|
|
|
|
|
if !ok || t == nil {
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return t.(Trace)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (c *context) setTrace(trace Trace) {
|
|
|
|
|
c.ctx.Set(_TraceName, trace)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (c *context) disableTrace() {
|
|
|
|
|
c.setTrace(nil)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (c *context) Logger() *zap.Logger {
|
|
|
|
|
logger, ok := c.ctx.Get(_LoggerName)
|
|
|
|
|
if !ok {
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return logger.(*zap.Logger)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (c *context) setLogger(logger *zap.Logger) {
|
|
|
|
|
c.ctx.Set(_LoggerName, logger)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (c *context) getPayload() any {
|
|
|
|
|
if payload, ok := c.ctx.Get(_PayloadName); ok != false {
|
|
|
|
|
return payload
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (c *context) Payload(payload any) {
|
|
|
|
|
c.ctx.Set(_PayloadName, payload)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (c *context) getGraphPayload() any {
|
|
|
|
|
if payload, ok := c.ctx.Get(_GraphPayloadName); ok != false {
|
|
|
|
|
return payload
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (c *context) GraphPayload(payload any) {
|
|
|
|
|
c.ctx.Set(_GraphPayloadName, payload)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// HTML renders the template "<name>.html" with obj as its data, using status 200.
func (c *context) HTML(name string, obj any) {
	templateName := name + ".html"
	c.ctx.HTML(http.StatusOK, templateName, obj)
}
|
|
|
|
|
|
|
|
|
|
// Header returns a deep copy of the request headers. Callers may modify the result
// without affecting the underlying *http.Request.
func (c *context) Header() http.Header {
	// http.Header.Clone copies all values into one shared backing array (capacity-capped
	// per key, so appends cannot overwrite neighbours) instead of one allocation per key.
	clone := c.ctx.Request.Header.Clone()
	if clone == nil {
		// Clone returns nil for a nil header; keep returning a usable empty map.
		return make(http.Header)
	}
	return clone
}
|
|
|
|
|
|
|
|
|
|
func (c *context) GetHeader(key string) string {
|
|
|
|
|
return c.ctx.GetHeader(key)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (c *context) SetHeader(key, value string) {
|
|
|
|
|
c.ctx.Header(key, value)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (c *context) Auth() any {
|
|
|
|
|
val, ok := c.ctx.Get(_Auth)
|
|
|
|
|
if !ok {
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
return val
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (c *context) SetAuth(auth any) {
|
|
|
|
|
c.ctx.Set(_Auth, auth)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (c *context) Authorization() string {
|
|
|
|
|
return c.ctx.GetHeader("Authorization")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (c *context) Platform() string {
|
|
|
|
|
return c.ctx.GetHeader("X-Platform")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// AbortWithError aborts the request with err's HTTP status code (500 when the code is 0)
// and records err under _AbortErrorName for later retrieval by abortError.
// A nil err is a no-op.
func (c *context) AbortWithError(err errno.Error) {
	if err == nil {
		return
	}

	status := err.GetHttpCode()
	if status == 0 {
		status = http.StatusInternalServerError
	}

	c.ctx.AbortWithStatus(status)
	c.ctx.Set(_AbortErrorName, err)
}
|
|
|
|
|
|
|
|
|
|
func (c *context) abortError() errno.Error {
|
|
|
|
|
err, _ := c.ctx.Get(_AbortErrorName)
|
|
|
|
|
return err.(errno.Error)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (c *context) Alias() string {
|
|
|
|
|
path, ok := c.ctx.Get(_Alias)
|
|
|
|
|
if !ok {
|
|
|
|
|
return ""
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return path.(string)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// setAlias stores the whitespace-trimmed path as the route alias.
// Blank or whitespace-only values are ignored, leaving any previous alias in place.
func (c *context) setAlias(path string) {
	trimmed := strings.TrimSpace(path)
	if trimmed == "" {
		return
	}
	c.ctx.Set(_Alias, trimmed)
}
|
|
|
|
|
|
|
|
|
|
// RequestInputParams returns all request parameters: Request.Form after ParseForm,
// which merges the URL query with a urlencoded POST/PUT/PATCH body.
// NOTE(review): ParseForm reads Request.Body, draining the copy installed by init;
// RawData still returns the cached bytes.
func (c *context) RequestInputParams() url.Values {
	req := c.ctx.Request
	// Best effort: the parse error is deliberately discarded and callers get whatever
	// Request.Form holds afterwards.
	_ = req.ParseForm()
	return req.Form
}
|
|
|
|
|
|
|
|
|
|
// RequestQueryParams parses the raw query string into a fresh url.Values.
// Malformed pairs are skipped silently; URL.Query discards the parse error just as
// the original url.ParseQuery call did.
func (c *context) RequestQueryParams() url.Values {
	return c.ctx.Request.URL.Query()
}
|
|
|
|
|
|
|
|
|
|
// RequestPostFormParams returns only the urlencoded body parameters (Request.PostForm
// after ParseForm); query-string values are not included.
// NOTE(review): ParseForm reads Request.Body, draining the copy installed by init.
func (c *context) RequestPostFormParams() url.Values {
	req := c.ctx.Request
	// Best effort: the parse error is deliberately discarded.
	_ = req.ParseForm()
	return req.PostForm
}
|
|
|
|
|
|
|
|
|
|
// Request returns the underlying *http.Request of the current gin.Context.
func (c *context) Request() *http.Request {
	req := c.ctx.Request
	return req
}
|
|
|
|
|
|
|
|
|
|
func (c *context) RawData() []byte {
|
|
|
|
|
body, ok := c.ctx.Get(_BodyName)
|
|
|
|
|
if !ok {
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return body.([]byte)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Method 请求的method
|
|
|
|
|
func (c *context) Method() string {
|
|
|
|
|
return c.ctx.Request.Method
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Host 请求的host
|
|
|
|
|
func (c *context) Host() string {
|
|
|
|
|
return c.ctx.Request.Host
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Path 请求的路径(不附带querystring)
|
|
|
|
|
func (c *context) Path() string {
|
|
|
|
|
return c.ctx.Request.URL.Path
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// URI returns Request.URL.RequestURI() passed through url.QueryUnescape.
// QueryUnescape also turns '+' into a space, including in the path portion.
// If the URI contains a malformed percent-escape, the raw (still-escaped) URI is
// returned rather than an empty string.
func (c *context) URI() string {
	raw := c.ctx.Request.URL.RequestURI()
	uri, err := url.QueryUnescape(raw)
	if err != nil {
		// QueryUnescape returns "" on error; fall back so logs/traces still show the request target.
		return raw
	}
	return uri
}
|
|
|
|
|
|
|
|
|
|
// RequestContext returns a StdContext that bundles a context.Context with the request's
// Trace and Logger.
// The context is context.Background(), so it is NOT cancelled when the client disconnects.
// NOTE(review): c.ctx.Request.Context() was deliberately swapped out for Background;
// confirm this is intended before relying on cancellation.
func (c *context) RequestContext() StdContext {
	return StdContext{
		Context: stdCtx.Background(),
		Trace:   c.Trace(),
		Logger:  c.Logger(),
	}
}
|
|
|
|
|
|
|
|
|
|
// ResponseWriter 获取 ResponseWriter
|
|
|
|
|
func (c *context) ResponseWriter() gin.ResponseWriter {
|
|
|
|
|
return c.ctx.Writer
|
|
|
|
|
}
|