407 lines
8.5 KiB
Go
407 lines
8.5 KiB
Go
package mux
|
||
|
||
import (
|
||
"bytes"
|
||
stdCtx "context"
|
||
"io"
|
||
"net/http"
|
||
"net/url"
|
||
"strings"
|
||
"sync"
|
||
|
||
"gitea.bvbej.com/bvbej/base-golang/pkg/errno"
|
||
"gitea.bvbej.com/bvbej/base-golang/pkg/trace"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/gin-gonic/gin/binding"
|
||
"go.uber.org/zap"
|
||
)
|
||
|
||
type HandlerFunc func(c Context)
|
||
|
||
type Trace = trace.T
|
||
|
||
// Keys under which per-request values are stored in the underlying
// gin.Context (via Set/Get).
const (
	// _Alias holds the trimmed route alias recorded by setAlias (metrics URI).
	_Alias = "_alias_"

	// _TraceName holds the request's Trace; a nil value means tracing is off.
	_TraceName = "_trace_"

	// _LoggerName holds the request's *zap.Logger.
	_LoggerName = "_logger_"

	// _BodyName holds the raw request body cached by init (used for tracing).
	_BodyName = "_body_"

	// _PayloadName holds the successful response payload set via Payload.
	_PayloadName = "_payload_"

	// _GraphPayloadName holds the GraphQL response payload set via GraphPayload.
	_GraphPayloadName = "_graph_payload_"

	// _AbortErrorName holds the errno.Error recorded by AbortWithError.
	_AbortErrorName = "_abort_error_"

	// _Auth holds the arbitrary auth value set via SetAuth.
	_Auth = "_auth_"
)
|
||
|
||
var contextPool = &sync.Pool{
|
||
New: func() any {
|
||
return new(context)
|
||
},
|
||
}
|
||
|
||
func newContext(ctx *gin.Context) Context {
|
||
getContext := contextPool.Get().(*context)
|
||
getContext.ctx = ctx
|
||
return getContext
|
||
}
|
||
|
||
func releaseContext(ctx Context) {
|
||
c := ctx.(*context)
|
||
c.ctx = nil
|
||
contextPool.Put(c)
|
||
}
|
||
|
||
var _ Context = (*context)(nil)
|
||
|
||
type Context interface {
|
||
init()
|
||
|
||
Context() *gin.Context
|
||
|
||
// ShouldBindQuery 反序列化 query
|
||
// tag: `form:"xxx"` (注:不要写成 query)
|
||
ShouldBindQuery(obj any) error
|
||
|
||
// ShouldBindPostForm 反序列化 x-www-from-urlencoded
|
||
// tag: `form:"xxx"`
|
||
ShouldBindPostForm(obj any) error
|
||
|
||
// ShouldBindForm 同时反序列化 form-data;
|
||
// tag: `form:"xxx"`
|
||
ShouldBindForm(obj any) error
|
||
|
||
// ShouldBindJSON 反序列化 post-json
|
||
// tag: `json:"xxx"`
|
||
ShouldBindJSON(obj any) error
|
||
|
||
// ShouldBindURI 反序列化 path 参数(如路由路径为 /user/:name)
|
||
// tag: `uri:"xxx"`
|
||
ShouldBindURI(obj any) error
|
||
|
||
// Redirect 重定向
|
||
Redirect(code int, location string)
|
||
|
||
// Trace 获取 Trace 对象
|
||
Trace() Trace
|
||
setTrace(trace Trace)
|
||
disableTrace()
|
||
|
||
// Logger 获取 Logger 对象
|
||
Logger() *zap.Logger
|
||
setLogger(logger *zap.Logger)
|
||
|
||
// Payload 正确返回
|
||
Payload(payload any)
|
||
getPayload() any
|
||
|
||
// GraphPayload GraphQL返回值 与 api 返回结构不同
|
||
GraphPayload(payload any)
|
||
getGraphPayload() any
|
||
|
||
// HTML 返回界面
|
||
HTML(name string, obj any)
|
||
|
||
// AbortWithError 错误返回
|
||
AbortWithError(err errno.Error)
|
||
abortError() errno.Error
|
||
|
||
// Header 获取 Header 对象
|
||
Header() http.Header
|
||
// GetHeader 获取 Header
|
||
GetHeader(key string) string
|
||
// SetHeader 设置 Header
|
||
SetHeader(key, value string)
|
||
|
||
Auth() any
|
||
SetAuth(auth any)
|
||
|
||
// Authorization 获取请求认证信息
|
||
Authorization() string
|
||
|
||
// Platform 平台标识
|
||
Platform() string
|
||
|
||
// Alias 设置路由别名 for metrics uri
|
||
Alias() string
|
||
setAlias(path string)
|
||
|
||
// RequestInputParams 获取所有参数
|
||
RequestInputParams() url.Values
|
||
// RequestQueryParams 获取 Query 参数
|
||
RequestQueryParams() url.Values
|
||
// RequestPostFormParams 获取 PostForm 参数
|
||
RequestPostFormParams() url.Values
|
||
// Request 获取 Request 对象
|
||
Request() *http.Request
|
||
// RawData 获取 Request.Body
|
||
RawData() []byte
|
||
// Method 获取 Request.Method
|
||
Method() string
|
||
// Host 获取 Request.Host
|
||
Host() string
|
||
// Path 获取 请求的路径 Request.URL.Path (不附带 querystring)
|
||
Path() string
|
||
// URI 获取 unescape 后的 Request.URL.RequestURI()
|
||
URI() string
|
||
// RequestContext 获取请求的 context (当 client 关闭后,会自动 canceled)
|
||
RequestContext() StdContext
|
||
// ResponseWriter 获取 ResponseWriter 对象
|
||
ResponseWriter() gin.ResponseWriter
|
||
}
|
||
|
||
type context struct {
|
||
ctx *gin.Context
|
||
}
|
||
|
||
// StdContext is the value returned by Context.RequestContext: a standard
// context.Context bundled with the request's Trace and *zap.Logger, so a
// single value can be passed to lower layers.
//
// NOTE: callers may build this with unkeyed (positional) composite literals,
// so reordering these embedded fields is a breaking change.
type StdContext struct {
	stdCtx.Context
	Trace
	*zap.Logger
}
|
||
|
||
func (c *context) init() {
|
||
body, err := c.ctx.GetRawData()
|
||
if err != nil {
|
||
panic(err)
|
||
}
|
||
|
||
c.ctx.Set(_BodyName, body) // cache body是为了trace使用
|
||
c.ctx.Request.Body = io.NopCloser(bytes.NewBuffer(body)) // re-construct req body
|
||
}
|
||
|
||
// Context returns the underlying *gin.Context, for callers that need gin APIs
// this wrapper does not expose.
func (c *context) Context() *gin.Context {
	return c.ctx
}
|
||
|
||
// ShouldBindQuery 反序列化querystring
|
||
// tag: `form:"xxx"` (注:不要写成query)
|
||
func (c *context) ShouldBindQuery(obj any) error {
|
||
return c.ctx.ShouldBindWith(obj, binding.Query)
|
||
}
|
||
|
||
// ShouldBindPostForm 反序列化 postform (querystring 会被忽略)
|
||
// tag: `form:"xxx"`
|
||
func (c *context) ShouldBindPostForm(obj any) error {
|
||
return c.ctx.ShouldBindWith(obj, binding.FormPost)
|
||
}
|
||
|
||
// ShouldBindForm 同时反序列化querystring和postform;
|
||
// 当querystring和postform存在相同字段时,postform优先使用。
|
||
// tag: `form:"xxx"`
|
||
func (c *context) ShouldBindForm(obj any) error {
|
||
return c.ctx.ShouldBindWith(obj, binding.Form)
|
||
}
|
||
|
||
// ShouldBindJSON 反序列化postjson
|
||
// tag: `json:"xxx"`
|
||
func (c *context) ShouldBindJSON(obj any) error {
|
||
return c.ctx.ShouldBindWith(obj, binding.JSON)
|
||
}
|
||
|
||
// ShouldBindURI decodes route path parameters (e.g. :name in /user/:name)
// into obj.
// Struct tag: `uri:"xxx"`.
func (c *context) ShouldBindURI(obj any) error {
	return c.ctx.ShouldBindUri(obj)
}
|
||
|
||
// Redirect sends an HTTP redirect to location with the given status code,
// delegating to gin.Context.Redirect.
func (c *context) Redirect(code int, location string) {
	c.ctx.Redirect(code, location)
}
|
||
|
||
func (c *context) Trace() Trace {
|
||
t, ok := c.ctx.Get(_TraceName)
|
||
if !ok || t == nil {
|
||
return nil
|
||
}
|
||
|
||
return t.(Trace)
|
||
}
|
||
|
||
// setTrace stores trace under _TraceName; storing nil makes Trace return nil.
func (c *context) setTrace(trace Trace) {
	c.ctx.Set(_TraceName, trace)
}
|
||
|
||
// disableTrace clears the stored trace, so Trace returns nil afterwards.
func (c *context) disableTrace() {
	c.setTrace(nil)
}
|
||
|
||
func (c *context) Logger() *zap.Logger {
|
||
logger, ok := c.ctx.Get(_LoggerName)
|
||
if !ok {
|
||
return nil
|
||
}
|
||
|
||
return logger.(*zap.Logger)
|
||
}
|
||
|
||
// setLogger stores the request-scoped logger that Logger returns.
func (c *context) setLogger(logger *zap.Logger) {
	c.ctx.Set(_LoggerName, logger)
}
|
||
|
||
func (c *context) getPayload() any {
|
||
if payload, ok := c.ctx.Get(_PayloadName); ok != false {
|
||
return payload
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// Payload records payload as the successful response data under _PayloadName.
// Nothing is written to the client here.
// NOTE(review): presumably rendered by response middleware elsewhere; confirm.
func (c *context) Payload(payload any) {
	c.ctx.Set(_PayloadName, payload)
}
|
||
|
||
func (c *context) getGraphPayload() any {
|
||
if payload, ok := c.ctx.Get(_GraphPayloadName); ok != false {
|
||
return payload
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// GraphPayload records a GraphQL response payload under _GraphPayloadName.
// It is kept separate from Payload because its shape differs from the regular
// API response. Nothing is written to the client here.
func (c *context) GraphPayload(payload any) {
	c.ctx.Set(_GraphPayloadName, payload)
}
|
||
|
||
func (c *context) HTML(name string, obj any) {
|
||
c.ctx.HTML(200, name+".html", obj)
|
||
}
|
||
|
||
func (c *context) Header() http.Header {
|
||
header := c.ctx.Request.Header
|
||
|
||
clone := make(http.Header, len(header))
|
||
for k, v := range header {
|
||
value := make([]string, len(v))
|
||
copy(value, v)
|
||
|
||
clone[k] = value
|
||
}
|
||
return clone
|
||
}
|
||
|
||
func (c *context) GetHeader(key string) string {
|
||
return c.ctx.GetHeader(key)
|
||
}
|
||
|
||
// SetHeader sets a response header via gin.Context.Header. An empty value
// removes the header instead.
func (c *context) SetHeader(key, value string) {
	c.ctx.Header(key, value)
}
|
||
|
||
func (c *context) Auth() any {
|
||
val, ok := c.ctx.Get(_Auth)
|
||
if !ok {
|
||
return nil
|
||
}
|
||
return val
|
||
}
|
||
|
||
// SetAuth stores an arbitrary auth value for this request; Auth returns it.
func (c *context) SetAuth(auth any) {
	c.ctx.Set(_Auth, auth)
}
|
||
|
||
func (c *context) Authorization() string {
|
||
return c.ctx.GetHeader("Authorization")
|
||
}
|
||
|
||
func (c *context) Platform() string {
|
||
return c.ctx.GetHeader("X-Platform")
|
||
}
|
||
|
||
func (c *context) AbortWithError(err errno.Error) {
|
||
if err != nil {
|
||
httpCode := err.GetHttpCode()
|
||
if httpCode == 0 {
|
||
httpCode = http.StatusInternalServerError
|
||
}
|
||
|
||
c.ctx.AbortWithStatus(httpCode)
|
||
c.ctx.Set(_AbortErrorName, err)
|
||
}
|
||
}
|
||
|
||
func (c *context) abortError() errno.Error {
|
||
err, _ := c.ctx.Get(_AbortErrorName)
|
||
return err.(errno.Error)
|
||
}
|
||
|
||
func (c *context) Alias() string {
|
||
path, ok := c.ctx.Get(_Alias)
|
||
if !ok {
|
||
return ""
|
||
}
|
||
|
||
return path.(string)
|
||
}
|
||
|
||
func (c *context) setAlias(path string) {
|
||
if path = strings.TrimSpace(path); path != "" {
|
||
c.ctx.Set(_Alias, path)
|
||
}
|
||
}
|
||
|
||
// RequestInputParams 获取所有参数
|
||
func (c *context) RequestInputParams() url.Values {
|
||
_ = c.ctx.Request.ParseForm()
|
||
return c.ctx.Request.Form
|
||
}
|
||
|
||
// RequestQueryParams 获取Query参数
|
||
func (c *context) RequestQueryParams() url.Values {
|
||
query, _ := url.ParseQuery(c.ctx.Request.URL.RawQuery)
|
||
return query
|
||
}
|
||
|
||
// RequestPostFormParams 获取 PostForm 参数
|
||
func (c *context) RequestPostFormParams() url.Values {
|
||
_ = c.ctx.Request.ParseForm()
|
||
return c.ctx.Request.PostForm
|
||
}
|
||
|
||
// Request returns the underlying *http.Request.
func (c *context) Request() *http.Request {
	return c.ctx.Request
}
|
||
|
||
func (c *context) RawData() []byte {
|
||
body, ok := c.ctx.Get(_BodyName)
|
||
if !ok {
|
||
return nil
|
||
}
|
||
|
||
return body.([]byte)
|
||
}
|
||
|
||
// Method 请求的method
|
||
func (c *context) Method() string {
|
||
return c.ctx.Request.Method
|
||
}
|
||
|
||
// Host 请求的host
|
||
func (c *context) Host() string {
|
||
return c.ctx.Request.Host
|
||
}
|
||
|
||
// Path 请求的路径(不附带querystring)
|
||
func (c *context) Path() string {
|
||
return c.ctx.Request.URL.Path
|
||
}
|
||
|
||
// URI unescape后的uri
|
||
func (c *context) URI() string {
|
||
uri, _ := url.QueryUnescape(c.ctx.Request.URL.RequestURI())
|
||
return uri
|
||
}
|
||
|
||
// RequestContext (包装 Trace + Logger) 获取请求的 context (当client关闭后,会自动canceled)
|
||
func (c *context) RequestContext() StdContext {
|
||
return StdContext{
|
||
//c.ctx.Request.Context(),
|
||
stdCtx.Background(),
|
||
c.Trace(),
|
||
c.Logger(),
|
||
}
|
||
}
|
||
|
||
// ResponseWriter returns gin's response writer for the current request.
func (c *context) ResponseWriter() gin.ResponseWriter {
	return c.ctx.Writer
}
|