base-golang/pkg/mux/context.go
2024-07-31 16:49:14 +08:00

407 lines
8.5 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package mux
import (
"bytes"
stdCtx "context"
"io"
"net/http"
"net/url"
"strings"
"sync"
"gitea.bvbej.com/bvbej/base-golang/pkg/errno"
"gitea.bvbej.com/bvbej/base-golang/pkg/trace"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
"go.uber.org/zap"
)
type (
	// HandlerFunc is the handler signature used by this package; it receives
	// the request-scoped Context.
	HandlerFunc func(Context)

	// Trace aliases trace.T for use in this package's API.
	Trace = trace.T
)
// Keys under which per-request values are stored in the gin.Context.
const (
	// Request metadata.
	_Alias    = "_alias_" // route alias used as the metrics URI label
	_BodyName = "_body_"  // cached raw request body

	// Observability.
	_TraceName  = "_trace_"
	_LoggerName = "_logger_"

	// Response results.
	_PayloadName      = "_payload_"
	_GraphPayloadName = "_graph_payload_"
	_AbortErrorName   = "_abort_error_"

	// Authentication value attached via SetAuth.
	_Auth = "_auth_"
)
// contextPool recycles *context wrappers between requests; values are taken
// by newContext and returned by releaseContext.
var contextPool = &sync.Pool{
	New: func() any { return &context{} },
}
// newContext takes a *context from contextPool and binds it to ctx.
// Pair each call with releaseContext so the wrapper can be reused.
func newContext(ctx *gin.Context) Context {
	wrapped := contextPool.Get().(*context)
	wrapped.ctx = ctx
	return wrapped
}
// releaseContext detaches ctx from its *gin.Context and puts it back into
// contextPool. ctx must be a *context (as produced by newContext); any other
// implementation makes the type assertion panic. After release the wrapper
// holds a nil *gin.Context, so it must not be used again.
func releaseContext(ctx Context) {
	pooled := ctx.(*context)
	pooled.ctx = nil // don't let the pool keep the finished request alive
	contextPool.Put(pooled)
}
// Compile-time assertion that *context satisfies Context.
var _ Context = (*context)(nil)

// Context is the request-scoped handle this package passes to a HandlerFunc.
// It wraps the current *gin.Context and adds helpers for binding, tracing,
// logging and building responses.
type Context interface {
	init()

	// Context returns the underlying *gin.Context.
	Context() *gin.Context

	// ShouldBindQuery decodes the URL query string into obj.
	// Struct tag: `form:"xxx"` (note: not `query`).
	ShouldBindQuery(obj any) error

	// ShouldBindPostForm decodes an application/x-www-form-urlencoded body
	// into obj; the query string is ignored.
	// Struct tag: `form:"xxx"`.
	ShouldBindPostForm(obj any) error

	// ShouldBindForm decodes the query string together with the form body
	// into obj; when both contain the same field, the form body wins.
	// Struct tag: `form:"xxx"`.
	ShouldBindForm(obj any) error

	// ShouldBindJSON decodes a JSON request body into obj.
	// Struct tag: `json:"xxx"`.
	ShouldBindJSON(obj any) error

	// ShouldBindURI decodes path parameters (e.g. for the route /user/:name)
	// into obj.
	// Struct tag: `uri:"xxx"`.
	ShouldBindURI(obj any) error

	// Redirect responds with the given status code, redirecting to location.
	Redirect(code int, location string)

	// Trace returns the request's Trace (nil when tracing is off).
	Trace() Trace
	setTrace(trace Trace)
	disableTrace()

	// Logger returns the request's logger.
	Logger() *zap.Logger
	setLogger(logger *zap.Logger)

	// Payload stores the payload of a successful response.
	Payload(payload any)
	getPayload() any

	// GraphPayload stores a GraphQL response payload; its shape differs from
	// regular API responses, so it is kept separately from Payload.
	GraphPayload(payload any)
	getGraphPayload() any

	// HTML renders the HTML template for name with obj as its data.
	HTML(name string, obj any)

	// AbortWithError aborts the request and records err as its error result.
	AbortWithError(err errno.Error)
	abortError() errno.Error

	// Header returns the request headers.
	Header() http.Header
	// GetHeader returns the value of request header key.
	GetHeader(key string) string
	// SetHeader sets response header key to value.
	SetHeader(key, value string)

	// Auth returns the value stored by SetAuth.
	Auth() any
	// SetAuth attaches auth to the request.
	SetAuth(auth any)
	// Authorization returns the request's Authorization header.
	Authorization() string
	// Platform returns the platform identifier (X-Platform request header).
	Platform() string

	// Alias returns the route alias used as the URI label for metrics.
	Alias() string
	setAlias(path string)

	// RequestInputParams returns all request parameters (query + form).
	RequestInputParams() url.Values
	// RequestQueryParams returns the query-string parameters.
	RequestQueryParams() url.Values
	// RequestPostFormParams returns the form-body parameters.
	RequestPostFormParams() url.Values

	// Request returns the underlying *http.Request.
	Request() *http.Request
	// RawData returns the raw request body (Request.Body).
	RawData() []byte
	// Method returns Request.Method.
	Method() string
	// Host returns Request.Host.
	Host() string
	// Path returns Request.URL.Path (without the query string).
	Path() string
	// URI returns Request.URL.RequestURI() after unescaping.
	URI() string
	// RequestContext returns a StdContext bundling a context.Context with the
	// request's Trace and Logger. NOTE: the *context implementation embeds
	// context.Background(), so it is NOT canceled when the client disconnects.
	RequestContext() StdContext
	// ResponseWriter returns the gin ResponseWriter.
	ResponseWriter() gin.ResponseWriter
}
type (
	// context is the concrete Context implementation: a thin wrapper around
	// the current *gin.Context.
	context struct {
		ctx *gin.Context
	}

	// StdContext embeds a standard context.Context together with the
	// request's Trace and *zap.Logger.
	StdContext struct {
		stdCtx.Context
		Trace
		*zap.Logger
	}
)
// init reads the entire request body, caches it under _BodyName (so tracing
// can record it) and re-installs an in-memory reader over the same bytes so
// later consumers still see the full body. It panics if the body cannot be read.
func (c *context) init() {
	raw, readErr := c.ctx.GetRawData()
	if readErr != nil {
		panic(readErr)
	}
	// The original stream was consumed above; give downstream readers a fresh copy.
	c.ctx.Request.Body = io.NopCloser(bytes.NewBuffer(raw))
	c.ctx.Set(_BodyName, raw)
}
// Context exposes the wrapped *gin.Context.
func (c *context) Context() *gin.Context { return c.ctx }

// ShouldBindQuery decodes the URL query string into obj.
// Struct tag: `form:"xxx"` (not `query`).
func (c *context) ShouldBindQuery(obj any) error {
	// gin's shortcut for ShouldBindWith(obj, binding.Query).
	return c.ctx.ShouldBindQuery(obj)
}

// ShouldBindPostForm decodes the post form into obj; the query string is ignored.
// Struct tag: `form:"xxx"`.
func (c *context) ShouldBindPostForm(obj any) error {
	return c.ctx.ShouldBindWith(obj, binding.FormPost)
}

// ShouldBindForm decodes both the query string and the post form into obj.
// When a field appears in both, the post-form value takes precedence.
// Struct tag: `form:"xxx"`.
func (c *context) ShouldBindForm(obj any) error {
	return c.ctx.ShouldBindWith(obj, binding.Form)
}

// ShouldBindJSON decodes a JSON request body into obj.
// Struct tag: `json:"xxx"`.
func (c *context) ShouldBindJSON(obj any) error {
	// gin's shortcut for ShouldBindWith(obj, binding.JSON).
	return c.ctx.ShouldBindJSON(obj)
}

// ShouldBindURI decodes path parameters (e.g. route /user/:name) into obj.
// Struct tag: `uri:"xxx"`.
func (c *context) ShouldBindURI(obj any) error { return c.ctx.ShouldBindUri(obj) }

// Redirect sends a redirect to location using status code.
func (c *context) Redirect(code int, location string) { c.ctx.Redirect(code, location) }
// Trace returns the Trace attached to this request, or nil when none was
// stored or the stored value is nil (see disableTrace).
func (c *context) Trace() Trace {
	if v, found := c.ctx.Get(_TraceName); found && v != nil {
		return v.(Trace)
	}
	return nil
}

// setTrace attaches trace to the request.
func (c *context) setTrace(trace Trace) { c.ctx.Set(_TraceName, trace) }

// disableTrace clears the request's Trace by storing nil.
func (c *context) disableTrace() { c.setTrace(nil) }
// Logger returns the request's *zap.Logger, or nil if setLogger was never called.
func (c *context) Logger() *zap.Logger {
	v, found := c.ctx.Get(_LoggerName)
	if found {
		return v.(*zap.Logger)
	}
	return nil
}

// setLogger attaches logger to the request.
func (c *context) setLogger(logger *zap.Logger) { c.ctx.Set(_LoggerName, logger) }
// getPayload returns the value stored by Payload, or nil if none was stored.
func (c *context) getPayload() any {
	if payload, ok := c.ctx.Get(_PayloadName); ok {
		return payload
	}
	return nil
}

// Payload stores payload as the successful-response payload for this request.
func (c *context) Payload(payload any) {
	c.ctx.Set(_PayloadName, payload)
}

// getGraphPayload returns the value stored by GraphPayload, or nil if none was stored.
func (c *context) getGraphPayload() any {
	if payload, ok := c.ctx.Get(_GraphPayloadName); ok {
		return payload
	}
	return nil
}

// GraphPayload stores a GraphQL response payload. It is kept apart from
// Payload because GraphQL responses have a different shape.
func (c *context) GraphPayload(payload any) {
	c.ctx.Set(_GraphPayloadName, payload)
}
// HTML renders the template named name+".html" with obj as its data,
// responding with 200 OK.
func (c *context) HTML(name string, obj any) {
	c.ctx.HTML(http.StatusOK, name+".html", obj)
}
// Header returns a deep copy of the request headers, so callers may modify
// the result without affecting the underlying request. The result is never nil.
func (c *context) Header() http.Header {
	clone := c.ctx.Request.Header.Clone()
	if clone == nil {
		// Header.Clone returns nil for a nil header; keep returning a usable map
		// as the previous hand-rolled copy did.
		return make(http.Header)
	}
	return clone
}
// GetHeader returns the value of request header key.
func (c *context) GetHeader(key string) string { return c.ctx.GetHeader(key) }

// SetHeader sets response header key to value.
func (c *context) SetHeader(key, value string) { c.ctx.Header(key, value) }

// Auth returns the value stored by SetAuth, or nil if none was stored.
func (c *context) Auth() any {
	v, _ := c.ctx.Get(_Auth)
	return v
}

// SetAuth attaches auth to the request.
func (c *context) SetAuth(auth any) { c.ctx.Set(_Auth, auth) }

// Authorization returns the request's Authorization header.
func (c *context) Authorization() string { return c.GetHeader("Authorization") }

// Platform returns the platform identifier from the X-Platform request header.
func (c *context) Platform() string { return c.GetHeader("X-Platform") }
// AbortWithError stops the handler chain with err's HTTP status code
// (500 Internal Server Error when err reports 0) and records err under
// _AbortErrorName. A nil err is ignored.
func (c *context) AbortWithError(err errno.Error) {
	if err == nil {
		return
	}
	status := err.GetHttpCode()
	if status == 0 {
		status = http.StatusInternalServerError
	}
	c.ctx.AbortWithStatus(status)
	c.ctx.Set(_AbortErrorName, err)
}
// abortError returns the error recorded by AbortWithError, or nil when the
// request was not aborted through AbortWithError.
func (c *context) abortError() errno.Error {
	v, _ := c.ctx.Get(_AbortErrorName)
	// Use the comma-ok form: a plain assertion on a nil (unset) value panics,
	// e.g. when the request was aborted some other way or not at all.
	err, _ := v.(errno.Error)
	return err
}
// Alias returns the route alias recorded by setAlias, or "" if none.
func (c *context) Alias() string {
	if v, found := c.ctx.Get(_Alias); found {
		return v.(string)
	}
	return ""
}

// setAlias records path, trimmed of surrounding whitespace, as the route
// alias. Blank values are ignored, leaving any existing alias in place.
func (c *context) setAlias(path string) {
	trimmed := strings.TrimSpace(path)
	if trimmed == "" {
		return
	}
	c.ctx.Set(_Alias, trimmed)
}
// RequestInputParams returns all request parameters: http.Request.Form, which
// merges the query string with the form body. Parsing is best effort; a
// ParseForm error is deliberately ignored and whatever was parsed is returned.
func (c *context) RequestInputParams() url.Values {
	req := c.ctx.Request
	_ = req.ParseForm() // best effort: partial results are still useful
	return req.Form
}

// RequestQueryParams returns the parsed query string. URL.Query discards the
// parse error, so only the valid pairs are returned.
func (c *context) RequestQueryParams() url.Values {
	return c.ctx.Request.URL.Query()
}

// RequestPostFormParams returns only the form-body parameters
// (http.Request.PostForm). As above, a ParseForm error is ignored on purpose.
func (c *context) RequestPostFormParams() url.Values {
	req := c.ctx.Request
	_ = req.ParseForm() // best effort: partial results are still useful
	return req.PostForm
}
// Request returns the underlying *http.Request.
func (c *context) Request() *http.Request { return c.ctx.Request }

// RawData returns the request body cached under _BodyName, or nil if the
// body was never cached.
func (c *context) RawData() []byte {
	if v, found := c.ctx.Get(_BodyName); found {
		return v.([]byte)
	}
	return nil
}

// Method returns the request's HTTP method.
func (c *context) Method() string { return c.Request().Method }

// Host returns the request's Host.
func (c *context) Host() string { return c.Request().Host }

// Path returns the request path without the query string.
func (c *context) Path() string { return c.Request().URL.Path }
// URI returns Request.URL.RequestURI() (path plus query) decoded with
// url.QueryUnescape, which also turns '+' into a space. If the URI contains an
// invalid escape sequence, the raw, still-escaped URI is returned instead of
// an empty string, so logs and metrics keep a usable value.
func (c *context) URI() string {
	raw := c.ctx.Request.URL.RequestURI()
	uri, err := url.QueryUnescape(raw)
	if err != nil {
		return raw
	}
	return uri
}
// RequestContext returns a StdContext bundling a context.Context with the
// request's Trace and Logger.
//
// NOTE: the embedded context is context.Background(), not the request's own
// context. It is therefore never canceled when the client disconnects and
// carries no request deadline. This is presumably intentional, so downstream
// work is not aborted mid-flight; confirm before switching to
// c.ctx.Request.Context().
func (c *context) RequestContext() StdContext {
	return StdContext{
		Context: stdCtx.Background(),
		Trace:   c.Trace(),
		Logger:  c.Logger(),
	}
}
// ResponseWriter exposes gin's response writer for the current request.
func (c *context) ResponseWriter() gin.ResponseWriter { return c.ctx.Writer }