first commit

This commit is contained in:
2024-07-23 10:23:43 +08:00
commit 7b4c2521a3
126 changed files with 15931 additions and 0 deletions

View File

@ -0,0 +1,242 @@
package client
import (
"errors"
"fmt"
"git.bvbej.com/bvbej/base-golang/pkg/ticker"
"git.bvbej.com/bvbej/base-golang/pkg/websocket/client/service"
"git.bvbej.com/bvbej/base-golang/pkg/websocket/codec"
_ "git.bvbej.com/bvbej/base-golang/pkg/websocket/codec/json"
_ "git.bvbej.com/bvbej/base-golang/pkg/websocket/codec/protobuf"
"git.bvbej.com/bvbej/base-golang/tool"
"github.com/gorilla/websocket"
"go.uber.org/zap"
"net/http"
"net/url"
"reflect"
"sync/atomic"
"time"
)
const (
writeWait = 20 * time.Second
pongWait = 60 * time.Second
reconnectWait = 3 * time.Second
pingPeriod = (pongWait * 9) / 10
maxFrameMessageLen = 16 * 1024
maxSendBuffer = 32
)
var (
_ Client = (*client)(nil)
ErrBrokenPipe = errors.New("send to broken pipe")
ErrBufferPoolExceed = errors.New("send buffer exceed")
)
type Client interface {
readLoop()
writeLoop()
ping()
reconnect()
onReceive(msg []byte) error
onSend(msg []byte) error
connect() error
Send(router string, data any) error
Connect(requestHeader http.Header) error
OnReceiveError(f func(error))
OnReconnected(f func(error))
Close()
}
type client struct {
url url.URL
requestHeader http.Header
logger *zap.Logger
session *service.Session
isConnected atomic.Bool
routerCodec codec.Codec
send chan []byte
handlers map[string]*service.Handler // 注册的方法列表
onReceiveErr func(error)
pingTicker ticker.Ticker
checkConnTicker ticker.Ticker
onReconnect func(error)
}
func New(logger *zap.Logger, url url.URL, decoder string, handlers any) (Client, error) {
if !tool.InArray(url.Scheme, []string{"ws", "wss"}) {
return nil, errors.New(`param: scheme not supported`)
}
routerCodec := codec.GetCodec(decoder)
if routerCodec == nil {
return nil, errors.New(`param: codec not supported`)
}
components := service.RegisterHandler(handlers)
if len(components) == 0 {
return nil, errors.New(`param: handlers unqualified`)
}
c := &client{
logger: logger,
isConnected: atomic.Bool{},
routerCodec: routerCodec,
url: url,
send: make(chan []byte, maxSendBuffer),
handlers: components,
pingTicker: ticker.New(pingPeriod),
checkConnTicker: ticker.New(reconnectWait),
}
return c, nil
}
func (c *client) readLoop() {
_ = c.session.Conn.SetReadDeadline(time.Now().Add(pongWait))
c.session.Conn.SetPongHandler(func(string) error {
_ = c.session.Conn.SetReadDeadline(time.Now().Add(pongWait))
return nil
})
for {
_, data, err := c.session.Conn.ReadMessage()
if err != nil {
c.isConnected.Store(false)
break
}
err = c.onReceive(data)
if err != nil && c.onReceiveErr != nil {
c.onReceiveErr(err)
}
}
}
func (c *client) writeLoop() {
for msg := range c.send {
_ = c.session.Conn.SetWriteDeadline(time.Now().Add(writeWait))
err := c.session.Conn.WriteMessage(websocket.BinaryMessage, msg)
if err != nil {
c.logger.Sugar().Errorf("writeLoop err: %s", err)
}
}
}
func (c *client) ping() {
c.pingTicker.Process(func() {
_ = c.session.Conn.SetWriteDeadline(time.Now().Add(writeWait))
_ = c.session.Conn.WriteMessage(websocket.PingMessage, nil)
})
}
func (c *client) reconnect() {
c.checkConnTicker.Process(func() {
if c.isConnected.Load() {
return
}
err := c.connect()
if c.onReconnect != nil {
c.onReconnect(err)
}
})
}
func (c *client) connect() error {
conn, _, err := websocket.DefaultDialer.Dial(c.url.String(), c.requestHeader)
if err != nil {
return fmt.Errorf("dial: %s", err)
}
c.session = service.NewSession(conn)
c.isConnected.Store(true)
go c.readLoop()
return nil
}
func (c *client) onReceive(msg []byte) error {
_, msgPack, err := c.routerCodec.Unmarshal(msg)
if err != nil {
return fmt.Errorf("onreceive: %v", err)
}
router, ok := msgPack.Router.(string)
if !ok {
return fmt.Errorf("onreceive: invalid router:%v", msgPack.Router)
}
s, ok := c.handlers[router]
if !ok {
return fmt.Errorf("onreceive: function not registed router:%s err:%v", msgPack.Router, err)
}
if msgPack.Err != nil {
return fmt.Errorf("%s:%s", router, msgPack.Err)
}
var args = []reflect.Value{s.Receiver, reflect.ValueOf(c.session), reflect.ValueOf(msgPack.DataPtr)}
s.Method.Func.Call(args)
return nil
}
func (c *client) onSend(msg []byte) (err error) {
defer func() {
if e := recover(); e != nil {
err = ErrBrokenPipe
}
}()
if !c.isConnected.Load() {
return ErrBrokenPipe
}
if len(c.send) >= maxSendBuffer {
return ErrBufferPoolExceed
}
if len(msg) > maxFrameMessageLen {
return
}
c.send <- msg
return nil
}
func (c *client) Connect(requestHeader http.Header) error {
c.requestHeader = requestHeader
err := c.connect()
if err != nil {
return err
}
go c.ping()
go c.writeLoop()
go c.reconnect()
return nil
}
func (c *client) Send(router string, data any) error {
rb, err := c.routerCodec.Marshal(router, data, nil)
if err != nil {
return fmt.Errorf("service: %v", err)
}
return c.onSend(rb)
}
func (c *client) OnReceiveError(f func(error)) {
c.onReceiveErr = f
}
func (c *client) OnReconnected(f func(error)) {
c.onReconnect = f
}
func (c *client) Close() {
close(c.send)
c.pingTicker.Stop()
c.checkConnTicker.Stop()
_ = c.session.Conn.Close()
}

View File

@ -0,0 +1,28 @@
package service
import (
"reflect"
)
var (
typeOfBytes = reflect.TypeOf(([]byte)(nil))
typeOfSession = reflect.TypeOf(NewSession(nil))
)
// 方法检测
func isHandlerMethod(method reflect.Method) bool {
mt := method.Type
if method.PkgPath != "" {
return false
}
if mt.NumIn() != 3 {
return false
}
if t1 := mt.In(1); t1.Kind() != reflect.Ptr || t1 != typeOfSession {
return false
}
if mt.In(2).Kind() != reflect.Ptr && mt.In(2) != typeOfBytes {
return false
}
return true
}

View File

@ -0,0 +1,19 @@
package service
import (
"github.com/gorilla/websocket"
"time"
)
type Session struct {
Conn *websocket.Conn
Time time.Time
}
func NewSession(conn *websocket.Conn) *Session {
s := &Session{
Conn: conn,
Time: time.Now(),
}
return s
}

View File

@ -0,0 +1,57 @@
package service
import (
"fmt"
"git.bvbej.com/bvbej/base-golang/pkg/websocket/codec"
"git.bvbej.com/bvbej/base-golang/pkg/websocket/util"
"reflect"
"strings"
)
type Handler struct {
Receiver reflect.Value // 值
Method reflect.Method // 方法
Type reflect.Type // 类型
IsRawArg bool // 数据是否需要序列化
}
func RegisterHandler(components ...any) map[string]*Handler {
methods := make(map[string]*Handler)
for _, component := range components {
rt := reflect.TypeOf(component)
rv := reflect.ValueOf(component)
typeName := reflect.Indirect(rv).Type().Name()
if typeName == "" {
continue
}
if !util.IsExported(typeName) {
continue
}
for m := 0; m < rt.NumMethod(); m++ {
method := rt.Method(m)
mt := method.Type
mn := method.Name
if isHandlerMethod(method) {
raw := false
if mt.In(2) == typeOfBytes {
raw = true
}
router := fmt.Sprintf("%s.%s", strings.ToLower(typeName), strings.ToLower(mn))
methods[router] = &Handler{
Receiver: rv,
Method: method,
Type: mt.In(2),
IsRawArg: raw,
}
}
}
}
for router, handler := range methods {
codec.RegisterMessage(router, handler.Type)
}
return methods
}

View File

@ -0,0 +1,26 @@
package codec
type Codec interface {
Marshal(router string, dataPtr any, err error) ([]byte, error)
Unmarshal([]byte) (int, *MsgPack, error)
ToString(any) string
}
var codecsList = make(map[string]Codec)
func RegisterCodec(name string, codec Codec) {
if codec == nil {
panic("codec: Register provide is nil")
}
if _, dup := codecsList[name]; dup {
panic("codec: Register called twice for provide " + name)
}
codecsList[name] = codec
}
func GetCodec(name string) Codec {
if v, ok := codecsList[name]; ok {
return v
}
return nil
}

View File

@ -0,0 +1,97 @@
package json
import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"git.bvbej.com/bvbej/base-golang/pkg/websocket/codec"
)
type jsonCodec struct{}
type jsonReq struct {
Router string `json:"router"`
Data []byte `json:"data"`
Error string `json:"error,omitempty"`
}
type jsonAck struct {
Router string `json:"router"`
Data string `json:"data"`
Error string `json:"error,omitempty"`
}
func init() {
codec.RegisterCodec("json_codec", new(jsonCodec))
}
func (*jsonCodec) Marshal(router string, dataPtr any, retErr error) ([]byte, error) {
if router == "" {
return nil, fmt.Errorf("marshal: router is empty")
}
if dataPtr == nil && retErr == nil {
return nil, fmt.Errorf("marshal data in package is nil. router:%s dt:%T", router, dataPtr)
}
ack := &jsonAck{
Router: router,
}
if dataPtr != nil {
data, err := json.Marshal(dataPtr)
if err != nil {
return nil, fmt.Errorf("marshal json marshal failed. routerr:%s dt:%T err:%v", router, dataPtr, err)
}
ack.Data = base64.StdEncoding.EncodeToString(data)
}
if retErr != nil {
ack.Error = retErr.Error()
}
ackByte, err := json.Marshal(ack)
if err != nil {
return nil, fmt.Errorf("marshal json marshal failed. routerr:%s dt:%T err:%v", router, dataPtr, err)
}
return ackByte, nil
}
func (*jsonCodec) Unmarshal(msg []byte) (int, *codec.MsgPack, error) {
var i = len(msg)
req := &jsonReq{}
err := json.Unmarshal(msg, req)
if err != nil {
return i, nil, errors.New("unmarshal split message id failed")
}
var router = req.Router
msgPack := &codec.MsgPack{Router: router}
dt := codec.GetMessage(router)
if dt == nil {
return i, nil, fmt.Errorf("unmarshal message not registed. router:%s", router)
}
if req.Data != nil {
err = json.Unmarshal(req.Data, dt)
if err != nil {
return i, nil, fmt.Errorf("unmarshal json unmarshal failed. dt:%T msg:%s err:%v", dt, string(msg), err)
}
}
msgPack.DataPtr = dt
if req.Error != "" {
msgPack.Err = errors.New(req.Error)
}
return i, msgPack, nil
}
func (*jsonCodec) ToString(data any) string {
ab, err := json.Marshal(data)
if err != nil {
return fmt.Sprintf("invalid type %T", data)
}
return string(ab)
}

View File

@ -0,0 +1,43 @@
package codec
import (
"fmt"
"reflect"
"sync"
)
type MsgPack struct {
Router any
DataPtr any
Err error
}
var modelMap = make(map[any]reflect.Type)
var modelMapLock sync.RWMutex
func RegisterMessage(router any, datePtr any) {
modelMapLock.Lock()
defer modelMapLock.Unlock()
if _, ok := modelMap[router]; ok {
fmt.Println(fmt.Sprintf("codec: repeat registration. router:%s ", router))
return
}
if t, ok := datePtr.(reflect.Type); ok {
modelMap[router] = t.Elem()
} else {
t := reflect.TypeOf(datePtr)
if t.Kind() != reflect.Ptr {
panic(fmt.Errorf("codec: cannot use non-ptr message struct `%s`", t))
}
modelMap[router] = t.Elem()
}
}
func GetMessage(router any) any {
modelMapLock.RLock()
defer modelMapLock.RUnlock()
if ptr, ok := modelMap[router]; ok {
return reflect.New(ptr).Interface()
}
return nil
}

View File

@ -0,0 +1,85 @@
package protobuf
import (
"errors"
"fmt"
"git.bvbej.com/bvbej/base-golang/pkg/websocket/codec"
"git.bvbej.com/bvbej/base-golang/pkg/websocket/codec/protobuf/protocol"
"google.golang.org/protobuf/proto"
)
type protobufCodec struct{}
func init() {
codec.RegisterCodec("protobuf_codec", new(protobufCodec))
}
func (*protobufCodec) Marshal(router string, dataPtr any, retErr error) ([]byte, error) {
if router == "" {
return nil, fmt.Errorf("marshal: empty router")
}
if dataPtr == nil && retErr == nil {
return nil, fmt.Errorf("marshal: empty data")
}
ack := &protocol.TransPack{
Router: router,
}
if dataPtr != nil {
pbMsg, ok := dataPtr.(proto.Message)
if !ok {
return nil, fmt.Errorf("marshal: dataptr only support proto.Message type. router:%s dt:%T ",
router, dataPtr)
}
data, err := proto.Marshal(pbMsg)
if err != nil {
return nil, fmt.Errorf("marshal:protocol buffer marshal failed. router:%s dt:%T err:%v",
router, dataPtr, err)
}
ack.Data = data
} else {
ack.Error = retErr.Error()
}
ackByte, err := proto.Marshal(ack)
if err != nil {
return nil, fmt.Errorf("marshal:protocol buffer marshal failed. router:%s dt:%T err:%v",
router, ack, err)
}
return ackByte, nil
}
func (*protobufCodec) Unmarshal(msg []byte) (int, *codec.MsgPack, error) {
var l = len(msg)
req := &protocol.TransPack{}
err := proto.Unmarshal(msg, req)
if err != nil {
return l, nil, errors.New("unmarshal split message id failed.")
}
var router = req.Router
msgPack := &codec.MsgPack{Router: router}
dt := codec.GetMessage(router)
if dt == nil {
return l, nil, fmt.Errorf("unmarshal message not registed. router:%s",
router)
}
if req.Data != nil {
err = proto.Unmarshal(req.Data, dt.(proto.Message))
if err != nil {
return l, nil, fmt.Errorf("unmarshal failed. router:%s", router)
}
}
msgPack.DataPtr = dt
if req.Error != "" {
msgPack.Err = errors.New(req.Error)
}
return l, msgPack, nil
}
func (*protobufCodec) ToString(data any) string {
pbMsg, ok := data.(proto.Message)
if !ok {
return fmt.Sprintf("invalid type %T", data)
}
marshal, _ := proto.Marshal(pbMsg)
return string(marshal)
}

View File

@ -0,0 +1,226 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.26.0
// protoc v3.21.1
// source: base.proto
package protocol
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
// 通信包装
type TransPack struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Router string `protobuf:"bytes,1,opt,name=router,proto3" json:"router,omitempty"`
Data []byte `protobuf:"bytes,2,opt,name=data,proto3" json:"data,omitempty"`
Error string `protobuf:"bytes,3,opt,name=error,proto3" json:"error,omitempty"`
}
func (x *TransPack) Reset() {
*x = TransPack{}
if protoimpl.UnsafeEnabled {
mi := &file_base_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *TransPack) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*TransPack) ProtoMessage() {}
func (x *TransPack) ProtoReflect() protoreflect.Message {
mi := &file_base_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use TransPack.ProtoReflect.Descriptor instead.
func (*TransPack) Descriptor() ([]byte, []int) {
return file_base_proto_rawDescGZIP(), []int{0}
}
func (x *TransPack) GetRouter() string {
if x != nil {
return x.Router
}
return ""
}
func (x *TransPack) GetData() []byte {
if x != nil {
return x.Data
}
return nil
}
func (x *TransPack) GetError() string {
if x != nil {
return x.Error
}
return ""
}
// 连接检测
type PingPang struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Timestamp int64 `protobuf:"varint,1,opt,name=timestamp,proto3" json:"timestamp,omitempty"` //时间戳
}
func (x *PingPang) Reset() {
*x = PingPang{}
if protoimpl.UnsafeEnabled {
mi := &file_base_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *PingPang) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*PingPang) ProtoMessage() {}
func (x *PingPang) ProtoReflect() protoreflect.Message {
mi := &file_base_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use PingPang.ProtoReflect.Descriptor instead.
func (*PingPang) Descriptor() ([]byte, []int) {
return file_base_proto_rawDescGZIP(), []int{1}
}
func (x *PingPang) GetTimestamp() int64 {
if x != nil {
return x.Timestamp
}
return 0
}
var File_base_proto protoreflect.FileDescriptor
var file_base_proto_rawDesc = []byte{
0x0a, 0x0a, 0x62, 0x61, 0x73, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x08, 0x70, 0x72,
0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x22, 0x4d, 0x0a, 0x09, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x50,
0x61, 0x63, 0x6b, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x72, 0x18, 0x01, 0x20,
0x01, 0x28, 0x09, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x72, 0x12, 0x12, 0x0a, 0x04, 0x64,
0x61, 0x74, 0x61, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x64, 0x61, 0x74, 0x61, 0x12,
0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05,
0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x28, 0x0a, 0x08, 0x50, 0x69, 0x6e, 0x67, 0x50, 0x61, 0x6e,
0x67, 0x12, 0x1c, 0x0a, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x01,
0x20, 0x01, 0x28, 0x03, 0x52, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x42,
0x27, 0x5a, 0x25, 0x70, 0x6b, 0x67, 0x2f, 0x77, 0x65, 0x62, 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74,
0x2f, 0x63, 0x6f, 0x64, 0x65, 0x63, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f,
0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
file_base_proto_rawDescOnce sync.Once
file_base_proto_rawDescData = file_base_proto_rawDesc
)
func file_base_proto_rawDescGZIP() []byte {
file_base_proto_rawDescOnce.Do(func() {
file_base_proto_rawDescData = protoimpl.X.CompressGZIP(file_base_proto_rawDescData)
})
return file_base_proto_rawDescData
}
var file_base_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_base_proto_goTypes = []any{
(*TransPack)(nil), // 0: protocol.TransPack
(*PingPang)(nil), // 1: protocol.PingPang
}
var file_base_proto_depIdxs = []int32{
0, // [0:0] is the sub-list for method output_type
0, // [0:0] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_base_proto_init() }
func file_base_proto_init() {
if File_base_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_base_proto_msgTypes[0].Exporter = func(v any, i int) any {
switch v := v.(*TransPack); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_base_proto_msgTypes[1].Exporter = func(v any, i int) any {
switch v := v.(*PingPang); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_base_proto_rawDesc,
NumEnums: 0,
NumMessages: 2,
NumExtensions: 0,
NumServices: 0,
},
GoTypes: file_base_proto_goTypes,
DependencyIndexes: file_base_proto_depIdxs,
MessageInfos: file_base_proto_msgTypes,
}.Build()
File_base_proto = out.File
file_base_proto_rawDesc = nil
file_base_proto_goTypes = nil
file_base_proto_depIdxs = nil
}

View File

@ -0,0 +1,15 @@
syntax = "proto3";
package protocol;
option go_package = "pkg/websocket/codec/protobuf/protocol";
//通信包装
message TransPack {
string router = 1;
bytes data = 2;
string error = 3;
}
//连接检测
message PingPang {
int64 timestamp = 1; //时间戳
}

View File

@ -0,0 +1,6 @@
package peer
type ConnectionCallBack interface {
OnClosed(*Session)
OnReceive(*Session, []byte) error
}

View File

@ -0,0 +1,98 @@
package connect
import (
"context"
"errors"
"fmt"
"git.bvbej.com/bvbej/base-golang/pkg/mux"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"net/http"
"net/url"
"time"
"git.bvbej.com/bvbej/base-golang/pkg/websocket/peer"
"github.com/gorilla/websocket"
)
var _ WsAcceptor = (*wsAcceptor)(nil)
var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true
},
}
type WsAcceptor interface {
Start(addr string) error
Stop()
GinHandle(ctx *gin.Context)
HandlerFunc() mux.HandlerFunc
}
type wsAcceptor struct {
server *http.Server
sessMgr *peer.SessionManager
logger *zap.Logger
}
func NewWsAcceptor(sessMgr *peer.SessionManager, loggers *zap.Logger) WsAcceptor {
return &wsAcceptor{
sessMgr: sessMgr,
logger: loggers,
}
}
func (ws *wsAcceptor) Start(addr string) error {
urlObj, err := url.Parse(addr)
if err != nil {
return fmt.Errorf("websocket urlparse failed. url(%s) %v", addr, err)
}
if urlObj.Path == "" {
return fmt.Errorf("websocket start failed. expect path in url to listen addr:%s", addr)
}
http.HandleFunc(urlObj.Path, func(w http.ResponseWriter, r *http.Request) {
c, upgradeErr := upgrader.Upgrade(w, r, nil)
if upgradeErr != nil {
ws.logger.Sugar().Errorf("upgrade http failed: %s", upgradeErr)
return
}
ws.sessMgr.Register <- peer.NewSession(NewConnection(c, ws.sessMgr))
})
ws.server = &http.Server{Addr: urlObj.Host}
err = ws.server.ListenAndServe()
if err != nil && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("websocket ListenAndServe addr:%s failed:%v", addr, err)
}
return nil
}
func (ws *wsAcceptor) Stop() {
ws.sessMgr.CloseAllSession()
if ws.server != nil {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
if err := ws.server.Shutdown(ctx); err != nil {
ws.logger.Sugar().Errorf("server shutdown err:[%s]", err)
}
}
}
func (ws *wsAcceptor) GinHandle(ctx *gin.Context) {
c, upgradeErr := upgrader.Upgrade(ctx.Writer, ctx.Request, nil)
if upgradeErr != nil {
ws.logger.Sugar().Errorf("upgrade http failed: %s", upgradeErr)
return
}
ws.sessMgr.Register <- peer.NewSession(NewConnection(c, ws.sessMgr))
}
func (ws *wsAcceptor) HandlerFunc() mux.HandlerFunc {
return func(c mux.Context) {
ws.GinHandle(c.Context())
}
}

View File

@ -0,0 +1,134 @@
package connect
import (
"errors"
"time"
"git.bvbej.com/bvbej/base-golang/pkg/websocket/peer"
"github.com/gorilla/websocket"
)
const (
writeWait = 20 * time.Second
pongWait = 60 * time.Second
pingPeriod = (pongWait * 9) / 10
maxFrameMessageLen = 16 * 1024 //4 * 4096
maxSendBuffer = 16
)
var (
ErrBrokenPipe = errors.New("send to broken pipe")
ErrBufferPoolExceed = errors.New("send buffer exceed")
)
type wsConnection struct {
peer.ConnectionIdentify
pm *peer.SessionManager
conn *websocket.Conn
send chan []byte
running bool
}
func NewConnection(conn *websocket.Conn, p *peer.SessionManager) *wsConnection {
wsc := &wsConnection{
conn: conn,
pm: p,
send: make(chan []byte, maxSendBuffer),
running: true,
}
go wsc.acceptLoop()
go wsc.sendLoop()
return wsc
}
func (ws *wsConnection) Peer() *peer.SessionManager {
return ws.pm
}
func (ws *wsConnection) Raw() any {
if ws.conn == nil {
return nil
}
return ws.conn
}
func (ws *wsConnection) RemoteAddr() string {
if ws.conn == nil {
return ""
}
return ws.conn.RemoteAddr().String()
}
func (ws *wsConnection) Close() {
_ = ws.conn.Close()
ws.running = false
}
func (ws *wsConnection) Send(msg []byte) (err error) {
defer func() {
if e := recover(); e != nil {
err = ErrBrokenPipe
}
}()
if !ws.running {
return ErrBrokenPipe
}
if len(ws.send) >= maxSendBuffer {
return ErrBufferPoolExceed
}
if len(msg) > maxFrameMessageLen {
return
}
ws.send <- msg
return nil
}
func (ws *wsConnection) acceptLoop() {
defer func() {
ws.pm.Unregister <- ws.ID()
_ = ws.conn.Close()
ws.running = false
}()
_ = ws.conn.SetReadDeadline(time.Now().Add(pongWait))
ws.conn.SetPongHandler(func(string) error {
_ = ws.conn.SetReadDeadline(time.Now().Add(pongWait))
return nil
})
for ws.conn != nil {
_, data, err := ws.conn.ReadMessage()
if err != nil {
break
}
ws.pm.ProcessMessage(ws.ID(), data)
}
}
func (ws *wsConnection) sendLoop() {
ticker := time.NewTicker(pingPeriod)
defer func() {
ticker.Stop()
_ = ws.conn.Close()
ws.running = false
close(ws.send)
}()
for {
select {
case msg := <-ws.send:
_ = ws.conn.SetWriteDeadline(time.Now().Add(writeWait))
if err := ws.conn.WriteMessage(websocket.BinaryMessage, msg); err != nil {
return
}
case <-ticker.C:
_ = ws.conn.SetWriteDeadline(time.Now().Add(writeWait))
if err := ws.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
return
}
}
}
}
func (ws *wsConnection) IsClosed() bool {
return !ws.running
}

View File

@ -0,0 +1,23 @@
package peer
type Connection interface {
Raw() any
Peer() *SessionManager
Send(msg []byte) error
Close()
ID() int64
RemoteAddr() string
IsClosed() bool
}
type ConnectionIdentify struct {
id int64
}
func (ci *ConnectionIdentify) ID() int64 {
return ci.id
}
func (ci *ConnectionIdentify) SetID(id int64) {
ci.id = id
}

View File

@ -0,0 +1,16 @@
package peer
import "time"
type Session struct {
Conn Connection
Time time.Time
}
func NewSession(conn Connection) *Session {
s := &Session{
Conn: conn,
Time: time.Now(),
}
return s
}

View File

@ -0,0 +1,119 @@
package peer
import (
"sync"
"sync/atomic"
)
type SessionManager struct {
sessionList sync.Map // 使用Id关联会话
connIDGen int64 // 记录已经生成的会话ID流水号
count int64 // 记录当前在使用的会话数量
callback ConnectionCallBack
Register chan *Session
Unregister chan int64
}
func (mgr *SessionManager) SetIDBase(base int64) {
atomic.StoreInt64(&mgr.connIDGen, base)
}
func (mgr *SessionManager) Count() int {
return int(atomic.LoadInt64(&mgr.count))
}
func (mgr *SessionManager) Add(sess *Session) {
id := atomic.AddInt64(&mgr.connIDGen, 1)
atomic.AddInt64(&mgr.count, 1)
sess.Conn.(interface {
SetID(int64)
}).SetID(id)
mgr.sessionList.Store(id, sess)
}
func (mgr *SessionManager) Close(id int64) {
if v, ok := mgr.sessionList.Load(id); ok {
if mgr.callback != nil {
go mgr.callback.OnClosed(v.(*Session))
}
}
mgr.sessionList.Delete(id)
atomic.AddInt64(&mgr.count, -1)
}
func (mgr *SessionManager) ProcessMessage(id int64, msg []byte) {
if v, ok := mgr.sessionList.Load(id); ok {
if mgr.callback != nil {
go func() {
err := mgr.callback.OnReceive(v.(*Session), msg)
if err != nil {
v.(*Session).Conn.Close()
}
}()
}
}
}
func (mgr *SessionManager) run() {
for {
select {
case client := <-mgr.Register:
mgr.connIDGen++
mgr.count++
client.Conn.(interface {
SetID(int64)
}).SetID(mgr.connIDGen)
mgr.sessionList.Store(mgr.connIDGen, client)
case clientID := <-mgr.Unregister:
if v, ok := mgr.sessionList.Load(clientID); ok {
if mgr.callback != nil {
go mgr.callback.OnClosed(v.(*Session))
}
}
mgr.sessionList.Delete(clientID)
mgr.count--
}
}
}
func (mgr *SessionManager) GetSession(id int64) *Session {
if v, ok := mgr.sessionList.Load(id); ok {
return v.(*Session)
}
return nil
}
func (mgr *SessionManager) VisitSession(callback func(*Session) bool) {
mgr.sessionList.Range(func(key, value any) bool {
return callback(value.(*Session))
})
}
func (mgr *SessionManager) CloseAllSession() {
mgr.VisitSession(func(sess *Session) bool {
sess.Conn.Close()
return true
})
}
func (mgr *SessionManager) SessionCount() int64 {
return atomic.LoadInt64(&mgr.count)
}
func NewSessionMgr(callback ConnectionCallBack) *SessionManager {
s := &SessionManager{
callback: callback,
Register: make(chan *Session),
Unregister: make(chan int64),
}
go s.run()
return s
}

View File

@ -0,0 +1,27 @@
package service
import (
"time"
"git.bvbej.com/bvbej/base-golang/pkg/websocket/peer"
)
type Component interface {
Init()
OnSessionClose(*peer.Session) bool
OnRequestFinished(*peer.Session, string, any, string, time.Duration)
}
type ComponentBase struct{}
func (c *ComponentBase) Init() {
}
func (c *ComponentBase) OnSessionClose(session *peer.Session) bool {
return false
}
func (c *ComponentBase) OnRequestFinished(session *peer.Session, router string, req any, errMsg string, delta time.Duration) {
}

View File

@ -0,0 +1,33 @@
package service
import (
"git.bvbej.com/bvbej/base-golang/pkg/websocket/peer"
"reflect"
)
var (
typeOfError = reflect.TypeOf((*error)(nil)).Elem()
typeOfBytes = reflect.TypeOf(([]byte)(nil))
typeOfSession = reflect.TypeOf(peer.NewSession(nil))
)
// 方法检测
func isHandlerMethod(method reflect.Method) bool {
mt := method.Type
if method.PkgPath != "" {
return false
}
if mt.NumIn() != 3 {
return false
}
if mt.NumOut() != 2 {
return false
}
if t1 := mt.In(1); t1.Kind() != reflect.Ptr || t1 != typeOfSession {
return false
}
if (mt.In(2).Kind() != reflect.Ptr && mt.In(2) != typeOfBytes) || mt.Out(1) != typeOfError || mt.Out(0).Kind() != reflect.Ptr {
return false
}
return true
}

View File

@ -0,0 +1,98 @@
package service
import (
"errors"
"git.bvbej.com/bvbej/base-golang/pkg/websocket/util"
"reflect"
"strings"
"git.bvbej.com/bvbej/base-golang/pkg/websocket/peer"
)
type Handler struct {
Receiver reflect.Value // 值
Method reflect.Method // 方法
Type reflect.Type // 类型
IsRawArg bool // 数据是否需要序列化
}
type Service struct {
Name string // 服务名
Type reflect.Type // 服务类型
Receiver reflect.Value // 服务值
Handlers map[string]*Handler // 注册的方法列表
Component Component
}
func NewService(comp Component) *Service {
s := &Service{
Type: reflect.TypeOf(comp),
Receiver: reflect.ValueOf(comp),
Component: comp,
}
s.Name = strings.ToLower(reflect.Indirect(s.Receiver).Type().Name())
//调用初始化方法
s.Component.Init()
return s
}
func (s *Service) SuitableHandlerMethods(typ reflect.Type) map[string]*Handler {
methods := make(map[string]*Handler)
for m := 0; m < typ.NumMethod(); m++ {
method := typ.Method(m)
mt := method.Type
mn := method.Name
if isHandlerMethod(method) {
raw := false
if mt.In(2) == typeOfBytes {
raw = true
}
mn = strings.ToLower(mn)
methods[mn] = &Handler{Method: method, Type: mt.In(2), IsRawArg: raw}
}
}
return methods
}
func (s *Service) ExtractHandler() error {
typeName := reflect.Indirect(s.Receiver).Type().Name()
if typeName == "" {
return errors.New("no service name for type " + s.Type.String())
}
if !util.IsExported(typeName) {
return errors.New("type " + typeName + " is not exported")
}
s.Handlers = s.SuitableHandlerMethods(s.Type)
for i := range s.Handlers {
s.Handlers[i].Receiver = s.Receiver
}
if reflect.Indirect(s.Receiver).NumField() > 0 {
filedNum := reflect.Indirect(s.Receiver).NumField()
for i := 0; i < filedNum; i++ {
ty := reflect.Indirect(s.Receiver).Field(i).Type().Name()
if ty == ChildName {
h := s.SuitableHandlerMethods(reflect.Indirect(s.Receiver).Field(i).Elem().Type())
for ih, v := range h {
s.Handlers[ih] = v
s.Handlers[ih].Receiver = reflect.Indirect(s.Receiver).Field(i).Elem()
}
}
}
}
if len(s.Handlers) == 0 {
str := "service: "
method := s.SuitableHandlerMethods(reflect.PtrTo(s.Type))
if len(method) != 0 {
str = "type " + s.Name + " has no exported methods of suitable type (hint: pass a pointer to value of that type)"
} else {
str = "type " + s.Name + " has no exported methods of suitable type"
}
return errors.New(str)
}
return nil
}
func (s *Service) OnSessionClose(session *peer.Session) bool {
return s.Component.OnSessionClose(session)
}

View File

@ -0,0 +1,59 @@
package service
import (
"fmt"
"git.bvbej.com/bvbej/base-golang/pkg/websocket/codec"
_ "git.bvbej.com/bvbej/base-golang/pkg/websocket/codec/json"
"git.bvbej.com/bvbej/base-golang/pkg/websocket/peer"
"go.uber.org/zap"
)
var (
serviceLogger *zap.Logger
RegisteredServiceList = make(map[string]*Service) // all registered service
RouterCodec codec.Codec
)
const ChildName = "Hbase"
type Hbase any
func RegisterService(logger *zap.Logger, comp ...Component) {
serviceLogger = logger
for _, v := range comp {
s := NewService(v)
if _, ok := RegisteredServiceList[s.Name]; ok {
serviceLogger.Sugar().Errorf("service: service already defined: %s", s.Name)
}
if err := s.ExtractHandler(); err != nil {
serviceLogger.Sugar().Errorf("service: extract handler function failed: %v", err)
}
RegisteredServiceList[s.Name] = s
for name, handler := range s.Handlers {
router := fmt.Sprintf("%s.%s", s.Name, name)
//注册消息 用于解码
codec.RegisterMessage(router, handler.Type)
serviceLogger.Sugar().Debugf("service: router %s param %s registed", router, handler.Type)
}
}
}
func SetCodec(name string) error {
RouterCodec = codec.GetCodec(name)
if RouterCodec == nil {
return fmt.Errorf("service: codec %s not registered", name)
}
return nil
}
func Send(session *peer.Session, router string, data any) error {
if RouterCodec == nil {
return fmt.Errorf("service: codec not set")
}
rb, err := RouterCodec.Marshal(router, data, nil)
if err != nil {
return fmt.Errorf("service: %v", err)
}
return session.Conn.Send(rb)
}

View File

@ -0,0 +1,102 @@
package service
import (
"fmt"
"reflect"
"strings"
"time"
"git.bvbej.com/bvbej/base-golang/pkg/websocket/peer"
)
type callBackEntity struct{}
func GetSessionManager() *peer.SessionManager {
return peer.NewSessionMgr(&callBackEntity{})
}
func (cb *callBackEntity) OnClosed(session *peer.Session) {
defer func() {
if err := recover(); err != nil {
fmt.Println(fmt.Sprintf("OnClosed: session:%d err:%v", session.Conn.ID(), err))
}
}()
for _, v := range RegisteredServiceList {
if ok := v.OnSessionClose(session); ok {
return
}
}
}
func (cb *callBackEntity) OnReceive(session *peer.Session, msg []byte) error {
_, msgPack, err := RouterCodec.Unmarshal(msg)
if err != nil {
return fmt.Errorf("onreceive: %v", err)
}
router, ok := msgPack.Router.(string)
if !ok {
return fmt.Errorf("onreceive: invalid router:%v", msgPack.Router)
}
routerArr := strings.Split(router, ".")
if len(routerArr) != 2 {
return fmt.Errorf("onreceive: invalid router:%s", msgPack.Router)
}
s, ok := RegisteredServiceList[routerArr[0]]
if !ok {
return fmt.Errorf("onreceive: function not registed router:%s err:%v", msgPack.Router, err)
}
h, ok := s.Handlers[routerArr[1]]
if !ok {
return fmt.Errorf("onreceive: function not registed router:%s err:%v", msgPack.Router, err)
}
t1 := time.Now()
var args = []reflect.Value{h.Receiver, reflect.ValueOf(session), reflect.ValueOf(msgPack.DataPtr)}
var res any
var rb []byte
res, err = CallHandlerFunc(h.Method, args)
if res != nil && !reflect.ValueOf(res).IsNil() {
rb, err = RouterCodec.Marshal(router, res, nil)
if err != nil {
return fmt.Errorf("service: %v", err)
}
err = session.Conn.Send(rb)
if err != nil {
serviceLogger.Sugar().Warnf("warn! service send msg failed router:%s err:%v", router, err)
}
} else {
rb, err = RouterCodec.Marshal(router, nil, err)
if err != nil {
return fmt.Errorf("service: %v", err)
}
err = session.Conn.Send(rb)
if err != nil {
serviceLogger.Sugar().Warnf("warn! service send msg failed router:%s err:%v", router, err)
}
}
var errs string
if err != nil {
errs = err.Error()
}
dt := time.Since(t1)
go s.Component.OnRequestFinished(session, router, RouterCodec.ToString(msgPack.DataPtr), errs, dt)
return nil
}
func CallHandlerFunc(foo reflect.Method, args []reflect.Value) (retValue any, retErr error) {
defer func() {
if err := recover(); err != nil {
fmt.Println(fmt.Sprintf("CallHandlerFunc: %v", err))
retValue = nil
retErr = fmt.Errorf("CallHandlerFunc: call method pkg:%s method:%s err:%v", foo.PkgPath, foo.Name, err)
}
}()
if ret := foo.Func.Call(args); len(ret) > 0 {
var err error = nil
if r1 := ret[1].Interface(); r1 != nil {
err = r1.(error)
}
return ret[0].Interface(), err
}
return nil, fmt.Errorf("CallHandlerFunc: call method pkg:%s method:%s", foo.PkgPath, foo.Name)
}

View File

@ -0,0 +1,11 @@
package util
import (
"unicode"
"unicode/utf8"
)
func IsExported(name string) bool {
w, _ := utf8.DecodeRuneInString(name)
return unicode.IsUpper(w)
}