base-golang/pkg/mux/context.go

407 lines
8.5 KiB
Go
Raw Normal View History

2024-07-23 10:23:43 +08:00
package mux
import (
"bytes"
stdCtx "context"
"io"
"net/http"
"net/url"
"strings"
"sync"
"git.bvbej.com/bvbej/base-golang/pkg/errno"
"git.bvbej.com/bvbej/base-golang/pkg/trace"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
"go.uber.org/zap"
)
type (
	// HandlerFunc is the handler signature used by this package: it receives
	// the package's Context wrapper rather than a raw *gin.Context.
	HandlerFunc func(c Context)

	// Trace is an alias for trace.T, the per-request trace object exposed
	// through Context.Trace.
	Trace = trace.T
)
// Keys under which this package stores per-request values in the underlying
// gin.Context (via Set/Get).
const (
	// _AbortErrorName holds the errno.Error recorded by AbortWithError.
	_AbortErrorName = "_abort_error_"
	// _Alias holds the trimmed route alias recorded by setAlias.
	_Alias = "_alias_"
	// _Auth holds the value recorded by SetAuth.
	_Auth = "_auth_"
	// _BodyName holds the raw request body cached by init.
	_BodyName = "_body_"
	// _GraphPayloadName holds the value recorded by GraphPayload.
	_GraphPayloadName = "_graph_payload_"
	// _LoggerName holds the *zap.Logger recorded by setLogger.
	_LoggerName = "_logger_"
	// _PayloadName holds the value recorded by Payload.
	_PayloadName = "_payload_"
	// _TraceName holds the Trace recorded by setTrace.
	_TraceName = "_trace_"
)
// contextPool recycles *context wrappers across requests. newContext takes
// one out and releaseContext puts it back.
var contextPool = &sync.Pool{
	New: func() any { return &context{} },
}
// newContext borrows a *context from contextPool and binds it to ctx.
// The result must be handed back with releaseContext once the request is done.
func newContext(ctx *gin.Context) Context {
	wrapped := contextPool.Get().(*context)
	wrapped.ctx = ctx
	return wrapped
}
// releaseContext returns a Context obtained from newContext to contextPool.
// It panics if ctx is not a *context.
func releaseContext(ctx Context) {
	pooled := ctx.(*context)
	// Reset every field so the pooled wrapper no longer references the
	// finished request's *gin.Context.
	*pooled = context{}
	contextPool.Put(pooled)
}
// Compile-time check that *context implements Context.
var _ Context = (*context)(nil)

// Context is this package's per-request wrapper around *gin.Context.
type Context interface {
	// init caches the raw request body and re-installs a readable body.
	init()

	// Context returns the underlying *gin.Context.
	Context() *gin.Context

	// ShouldBindQuery decodes the URL query string into obj.
	// Tags: `form:"xxx"` (note: use form, NOT query).
	ShouldBindQuery(obj any) error

	// ShouldBindPostForm decodes an application/x-www-form-urlencoded body into obj
	// (the query string is ignored).
	// Tags: `form:"xxx"`
	ShouldBindPostForm(obj any) error

	// ShouldBindForm decodes both the query string and the post form into obj;
	// when a field appears in both, the post-form value wins.
	// Tags: `form:"xxx"`
	ShouldBindForm(obj any) error

	// ShouldBindJSON decodes a JSON request body into obj.
	// Tags: `json:"xxx"`
	ShouldBindJSON(obj any) error

	// ShouldBindURI decodes path parameters (e.g. for route /user/:name) into obj.
	// Tags: `uri:"xxx"`
	ShouldBindURI(obj any) error

	// Redirect redirects the client to location with the given status code.
	Redirect(code int, location string)

	// Trace returns the request's Trace object, or nil if none is set or tracing was disabled.
	Trace() Trace
	// setTrace stores the request's Trace object.
	setTrace(trace Trace)
	// disableTrace clears the stored Trace object.
	disableTrace()

	// Logger returns the request's *zap.Logger, or nil if none is set.
	Logger() *zap.Logger
	// setLogger stores the request's *zap.Logger.
	setLogger(logger *zap.Logger)

	// Payload records the payload for a successful response.
	Payload(payload any)
	// getPayload returns the value recorded by Payload, or nil.
	getPayload() any

	// GraphPayload records a GraphQL response payload, whose structure differs
	// from the regular API response.
	GraphPayload(payload any)
	// getGraphPayload returns the value recorded by GraphPayload, or nil.
	getGraphPayload() any

	// HTML renders the template name+".html" with obj.
	HTML(name string, obj any)

	// AbortWithError aborts the request with err's HTTP status and records err.
	AbortWithError(err errno.Error)
	// abortError returns the error recorded by AbortWithError.
	abortError() errno.Error

	// Header returns a copy of the request headers.
	Header() http.Header
	// GetHeader returns the first value of the request header key.
	GetHeader(key string) string
	// SetHeader sets a response header.
	SetHeader(key, value string)

	// Auth returns the value stored by SetAuth, or nil.
	Auth() any
	// SetAuth stores request-scoped authentication data.
	SetAuth(auth any)

	// Authorization returns the request's Authorization header.
	Authorization() string

	// Platform returns the platform identifier from the X-Platform request header.
	Platform() string

	// Alias returns the route alias used as the metrics URI ("" if unset).
	Alias() string
	// setAlias records the route alias (blank values are ignored).
	setAlias(path string)

	// RequestInputParams returns all parameters (query string plus form body).
	RequestInputParams() url.Values

	// RequestQueryParams returns the query-string parameters.
	RequestQueryParams() url.Values

	// RequestPostFormParams returns the form-body (PostForm) parameters.
	RequestPostFormParams() url.Values

	// Request returns the underlying *http.Request.
	Request() *http.Request

	// RawData returns the request body bytes cached by init.
	RawData() []byte

	// Method returns Request.Method.
	Method() string

	// Host returns Request.Host.
	Host() string

	// Path returns the request path Request.URL.Path (without the query string).
	Path() string

	// URI returns the unescaped Request.URL.RequestURI().
	URI() string

	// RequestContext returns a StdContext bundling a context.Context with this
	// request's Trace and Logger.
	// NOTE(review): the implementation embeds context.Background(), not the
	// request's own context, so the result is NOT canceled when the client
	// disconnects.
	RequestContext() StdContext

	// ResponseWriter returns the gin ResponseWriter for this request.
	ResponseWriter() gin.ResponseWriter
}
// context is the pooled implementation of Context. It wraps the current
// request's *gin.Context, which newContext sets and releaseContext clears.
type context struct {
	ctx *gin.Context
}
// StdContext bundles a standard library context.Context with the request's
// Trace and *zap.Logger; it is what Context.RequestContext returns.
// Field order matters: RequestContext builds this value with a positional
// composite literal.
type StdContext struct {
	stdCtx.Context
	Trace
	*zap.Logger
}
// init reads the whole request body, stores it under _BodyName (so RawData
// and trace recording can use it), and replaces Request.Body with a fresh
// reader over the same bytes so later binders can read it again.
// It panics if the body cannot be read.
func (c *context) init() {
	payload, readErr := c.ctx.GetRawData()
	if readErr != nil {
		panic(readErr)
	}

	// Cache the body for trace use.
	c.ctx.Set(_BodyName, payload)

	// GetRawData drained Request.Body; reinstall an equivalent reader.
	c.ctx.Request.Body = io.NopCloser(bytes.NewReader(payload))
}
// Context exposes the wrapped *gin.Context for callers that need gin
// features not surfaced by this interface.
func (c *context) Context() *gin.Context {
	underlying := c.ctx
	return underlying
}
// ShouldBindQuery decodes the URL query string into obj using gin's query
// binding. Struct fields are matched by `form:"xxx"` tags (not `query:"xxx"`).
func (c *context) ShouldBindQuery(obj any) error {
	return c.ctx.ShouldBindQuery(obj)
}
// ShouldBindPostForm decodes the url-encoded post form into obj; values from
// the query string are ignored. Struct fields are matched by `form:"xxx"` tags.
func (c *context) ShouldBindPostForm(obj any) error {
	return binding.FormPost.Bind(c.ctx.Request, obj)
}
// ShouldBindForm decodes both the query string and the post form into obj.
// When the same field appears in both, the post-form value takes precedence.
// Struct fields are matched by `form:"xxx"` tags.
func (c *context) ShouldBindForm(obj any) error {
	return binding.Form.Bind(c.ctx.Request, obj)
}
// ShouldBindJSON decodes a JSON request body into obj.
// Struct fields are matched by `json:"xxx"` tags.
func (c *context) ShouldBindJSON(obj any) error {
	return c.ctx.ShouldBindJSON(obj)
}
// ShouldBindURI decodes route path parameters (e.g. :name in /user/:name)
// into obj. Struct fields are matched by `uri:"xxx"` tags.
func (c *context) ShouldBindURI(obj any) error {
	g := c.ctx
	return g.ShouldBindUri(obj)
}
// Redirect sends the client to location with the given HTTP status code,
// delegating to gin.Context.Redirect.
func (c *context) Redirect(code int, location string) {
	g := c.ctx
	g.Redirect(code, location)
}
// Trace returns the Trace stored by setTrace. It returns nil when nothing
// was stored or when a nil value was stored (e.g. after disableTrace).
func (c *context) Trace() Trace {
	if stored, found := c.ctx.Get(_TraceName); found && stored != nil {
		return stored.(Trace)
	}
	return nil
}
// setTrace stores t as the request's Trace, retrievable via Trace().
func (c *context) setTrace(t Trace) {
	c.ctx.Set(_TraceName, t)
}
// disableTrace overwrites the stored Trace with nil, so Trace() returns nil.
func (c *context) disableTrace() {
	c.ctx.Set(_TraceName, Trace(nil))
}
// Logger returns the *zap.Logger stored by setLogger, or nil if none was
// stored. It panics if a value of another type was stored under the key.
func (c *context) Logger() *zap.Logger {
	if stored, found := c.ctx.Get(_LoggerName); found {
		return stored.(*zap.Logger)
	}
	return nil
}
// setLogger stores l as the request's logger, retrievable via Logger().
func (c *context) setLogger(l *zap.Logger) {
	c.ctx.Set(_LoggerName, l)
}
// getPayload returns the value recorded by Payload, or nil when none was
// recorded (gin's Get yields nil for a missing key).
func (c *context) getPayload() any {
	stored, _ := c.ctx.Get(_PayloadName)
	return stored
}
// Payload records the value to be returned as the successful response body;
// it is read back via getPayload.
func (c *context) Payload(payload any) {
	g := c.ctx
	g.Set(_PayloadName, payload)
}
// getGraphPayload returns the value recorded by GraphPayload, or nil when
// none was recorded (gin's Get yields nil for a missing key).
func (c *context) getGraphPayload() any {
	stored, _ := c.ctx.Get(_GraphPayloadName)
	return stored
}
// GraphPayload records a GraphQL response value, kept separate from Payload
// because its response structure differs; read back via getGraphPayload.
func (c *context) GraphPayload(payload any) {
	g := c.ctx
	g.Set(_GraphPayloadName, payload)
}
// HTML renders the template named name+".html" with obj and status 200.
func (c *context) HTML(name string, obj any) {
	template := name + ".html"
	c.ctx.HTML(http.StatusOK, template, obj)
}
// Header returns a deep copy of the request headers: both the map and every
// value slice are fresh, so callers may mutate the result freely.
func (c *context) Header() http.Header {
	src := c.ctx.Request.Header
	dst := make(http.Header, len(src))
	for name, values := range src {
		dst[name] = append(make([]string, 0, len(values)), values...)
	}
	return dst
}
// GetHeader returns the first value of the request header key
// (canonicalized lookup), or "" if it is absent.
func (c *context) GetHeader(key string) string {
	return c.ctx.Request.Header.Get(key)
}
// SetHeader sets a response header via gin.Context.Header
// (which deletes the header when value is empty).
func (c *context) SetHeader(key, value string) {
	g := c.ctx
	g.Header(key, value)
}
// Auth returns the value stored by SetAuth, or nil when nothing was stored
// (gin's Get yields nil for a missing key).
func (c *context) Auth() any {
	stored, _ := c.ctx.Get(_Auth)
	return stored
}
// SetAuth stores request-scoped authentication data, retrievable via Auth().
func (c *context) SetAuth(auth any) {
	g := c.ctx
	g.Set(_Auth, auth)
}
// Authorization returns the raw value of the request's Authorization header.
func (c *context) Authorization() string {
	return c.ctx.Request.Header.Get("Authorization")
}
// Platform returns the client platform identifier from the X-Platform header.
func (c *context) Platform() string {
	return c.ctx.Request.Header.Get("X-Platform")
}
// AbortWithError stops the handler chain with err's HTTP status and records
// err so abortError can retrieve it. A status of 0 is treated as 500.
// A nil err is a no-op.
func (c *context) AbortWithError(err errno.Error) {
	if err == nil {
		return
	}

	status := err.GetHttpCode()
	if status == 0 {
		status = http.StatusInternalServerError
	}

	c.ctx.AbortWithStatus(status)
	c.ctx.Set(_AbortErrorName, err)
}
// abortError returns the errno.Error recorded by AbortWithError.
//
// It returns nil when no error was recorded, e.g. the request was never
// aborted or was aborted through gin directly. The previous unchecked
// assertion err.(errno.Error) panicked at run time in that case, because
// asserting on a nil interface value fails. The comma-ok form yields the
// zero value instead.
func (c *context) abortError() errno.Error {
	stored, found := c.ctx.Get(_AbortErrorName)
	if !found {
		return nil
	}
	err, _ := stored.(errno.Error)
	return err
}
// Alias returns the route alias recorded by setAlias (used as the metrics
// URI), or "" if none was recorded.
func (c *context) Alias() string {
	if stored, found := c.ctx.Get(_Alias); found {
		return stored.(string)
	}
	return ""
}
// setAlias records path, trimmed of surrounding whitespace, as the route
// alias. Blank (all-whitespace) values leave any existing alias untouched.
func (c *context) setAlias(path string) {
	trimmed := strings.TrimSpace(path)
	if trimmed == "" {
		return
	}
	c.ctx.Set(_Alias, trimmed)
}
// RequestInputParams returns all request parameters: the URL query merged
// with the form body, as populated by http.Request.ParseForm.
func (c *context) RequestInputParams() url.Values {
	req := c.ctx.Request
	// Best effort: parse errors are ignored and the (possibly incomplete)
	// Form is returned.
	_ = req.ParseForm()
	return req.Form
}
// RequestQueryParams returns the query-string parameters parsed from
// Request.URL.RawQuery. Malformed pairs are silently dropped.
func (c *context) RequestQueryParams() url.Values {
	return c.ctx.Request.URL.Query()
}
// RequestPostFormParams returns only the form-body parameters
// (http.Request.PostForm), as populated by ParseForm.
func (c *context) RequestPostFormParams() url.Values {
	req := c.ctx.Request
	// Best effort: parse errors are ignored and the (possibly incomplete)
	// PostForm is returned.
	_ = req.ParseForm()
	return req.PostForm
}
// Request returns the underlying *http.Request of the wrapped gin context.
func (c *context) Request() *http.Request {
	g := c.ctx
	return g.Request
}
// RawData returns the request body bytes cached by init, or nil if they were
// never cached.
func (c *context) RawData() []byte {
	if stored, found := c.ctx.Get(_BodyName); found {
		return stored.([]byte)
	}
	return nil
}
// Method returns the HTTP method of the request (Request.Method).
func (c *context) Method() string {
	req := c.ctx.Request
	return req.Method
}
// Host returns Request.Host, the host (possibly with port) the request
// was addressed to.
func (c *context) Host() string {
	req := c.ctx.Request
	return req.Host
}
// Path returns the decoded request path Request.URL.Path, without the
// query string.
func (c *context) Path() string {
	u := c.ctx.Request.URL
	return u.Path
}
// URI returns Request.URL.RequestURI() (escaped path plus "?query") passed
// through url.QueryUnescape.
//
// Note that QueryUnescape also turns '+' into a space, including in the path
// portion.
//
// If the URI contains a malformed %-escape (clients control the raw query,
// e.g. "?a=%zz"), the escaped RequestURI is returned unchanged. Previously
// the error was ignored and "" was returned, which erased the URI from any
// consumer such as logs or metrics.
func (c *context) URI() string {
	raw := c.ctx.Request.URL.RequestURI()
	uri, err := url.QueryUnescape(raw)
	if err != nil {
		return raw
	}
	return uri
}
// RequestContext bundles a standard context.Context with this request's
// Trace and Logger.
//
// NOTE(review): the embedded context is context.Background(), not the
// request's own context. It is therefore NOT canceled when the client
// disconnects and carries no request deadline or values. Switching to
// c.ctx.Request.Context() would make downstream work cancel on client
// disconnect; confirm that is intended before changing it.
func (c *context) RequestContext() StdContext {
	return StdContext{
		Context: stdCtx.Background(),
		Trace:   c.Trace(),
		Logger:  c.Logger(),
	}
}
// ResponseWriter returns gin's ResponseWriter for the current request.
func (c *context) ResponseWriter() gin.ResponseWriter {
	g := c.ctx
	return g.Writer
}