first commit
This commit is contained in:
242
pkg/websocket/client/client.go
Normal file
242
pkg/websocket/client/client.go
Normal file
@ -0,0 +1,242 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"git.bvbej.com/bvbej/base-golang/pkg/ticker"
|
||||
"git.bvbej.com/bvbej/base-golang/pkg/websocket/client/service"
|
||||
"git.bvbej.com/bvbej/base-golang/pkg/websocket/codec"
|
||||
_ "git.bvbej.com/bvbej/base-golang/pkg/websocket/codec/json"
|
||||
_ "git.bvbej.com/bvbej/base-golang/pkg/websocket/codec/protobuf"
|
||||
"git.bvbej.com/bvbej/base-golang/tool"
|
||||
"github.com/gorilla/websocket"
|
||||
"go.uber.org/zap"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
writeWait = 20 * time.Second
|
||||
pongWait = 60 * time.Second
|
||||
reconnectWait = 3 * time.Second
|
||||
pingPeriod = (pongWait * 9) / 10
|
||||
maxFrameMessageLen = 16 * 1024
|
||||
maxSendBuffer = 32
|
||||
)
|
||||
|
||||
var (
|
||||
_ Client = (*client)(nil)
|
||||
ErrBrokenPipe = errors.New("send to broken pipe")
|
||||
ErrBufferPoolExceed = errors.New("send buffer exceed")
|
||||
)
|
||||
|
||||
type Client interface {
|
||||
readLoop()
|
||||
writeLoop()
|
||||
ping()
|
||||
reconnect()
|
||||
onReceive(msg []byte) error
|
||||
onSend(msg []byte) error
|
||||
connect() error
|
||||
|
||||
Send(router string, data any) error
|
||||
Connect(requestHeader http.Header) error
|
||||
OnReceiveError(f func(error))
|
||||
OnReconnected(f func(error))
|
||||
Close()
|
||||
}
|
||||
|
||||
type client struct {
|
||||
url url.URL
|
||||
requestHeader http.Header
|
||||
logger *zap.Logger
|
||||
session *service.Session
|
||||
isConnected atomic.Bool
|
||||
routerCodec codec.Codec
|
||||
send chan []byte
|
||||
handlers map[string]*service.Handler // 注册的方法列表
|
||||
onReceiveErr func(error)
|
||||
pingTicker ticker.Ticker
|
||||
checkConnTicker ticker.Ticker
|
||||
onReconnect func(error)
|
||||
}
|
||||
|
||||
func New(logger *zap.Logger, url url.URL, decoder string, handlers any) (Client, error) {
|
||||
if !tool.InArray(url.Scheme, []string{"ws", "wss"}) {
|
||||
return nil, errors.New(`param: scheme not supported`)
|
||||
}
|
||||
|
||||
routerCodec := codec.GetCodec(decoder)
|
||||
if routerCodec == nil {
|
||||
return nil, errors.New(`param: codec not supported`)
|
||||
}
|
||||
|
||||
components := service.RegisterHandler(handlers)
|
||||
if len(components) == 0 {
|
||||
return nil, errors.New(`param: handlers unqualified`)
|
||||
}
|
||||
|
||||
c := &client{
|
||||
logger: logger,
|
||||
isConnected: atomic.Bool{},
|
||||
routerCodec: routerCodec,
|
||||
url: url,
|
||||
send: make(chan []byte, maxSendBuffer),
|
||||
handlers: components,
|
||||
pingTicker: ticker.New(pingPeriod),
|
||||
checkConnTicker: ticker.New(reconnectWait),
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *client) readLoop() {
|
||||
_ = c.session.Conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
c.session.Conn.SetPongHandler(func(string) error {
|
||||
_ = c.session.Conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
return nil
|
||||
})
|
||||
|
||||
for {
|
||||
_, data, err := c.session.Conn.ReadMessage()
|
||||
if err != nil {
|
||||
c.isConnected.Store(false)
|
||||
break
|
||||
}
|
||||
err = c.onReceive(data)
|
||||
if err != nil && c.onReceiveErr != nil {
|
||||
c.onReceiveErr(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) writeLoop() {
|
||||
for msg := range c.send {
|
||||
_ = c.session.Conn.SetWriteDeadline(time.Now().Add(writeWait))
|
||||
err := c.session.Conn.WriteMessage(websocket.BinaryMessage, msg)
|
||||
if err != nil {
|
||||
c.logger.Sugar().Errorf("writeLoop err: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) ping() {
|
||||
c.pingTicker.Process(func() {
|
||||
_ = c.session.Conn.SetWriteDeadline(time.Now().Add(writeWait))
|
||||
_ = c.session.Conn.WriteMessage(websocket.PingMessage, nil)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *client) reconnect() {
|
||||
c.checkConnTicker.Process(func() {
|
||||
if c.isConnected.Load() {
|
||||
return
|
||||
}
|
||||
err := c.connect()
|
||||
if c.onReconnect != nil {
|
||||
c.onReconnect(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (c *client) connect() error {
|
||||
conn, _, err := websocket.DefaultDialer.Dial(c.url.String(), c.requestHeader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial: %s", err)
|
||||
}
|
||||
|
||||
c.session = service.NewSession(conn)
|
||||
c.isConnected.Store(true)
|
||||
|
||||
go c.readLoop()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) onReceive(msg []byte) error {
|
||||
_, msgPack, err := c.routerCodec.Unmarshal(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("onreceive: %v", err)
|
||||
}
|
||||
router, ok := msgPack.Router.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("onreceive: invalid router:%v", msgPack.Router)
|
||||
}
|
||||
s, ok := c.handlers[router]
|
||||
if !ok {
|
||||
return fmt.Errorf("onreceive: function not registed router:%s err:%v", msgPack.Router, err)
|
||||
}
|
||||
|
||||
if msgPack.Err != nil {
|
||||
return fmt.Errorf("%s:%s", router, msgPack.Err)
|
||||
}
|
||||
|
||||
var args = []reflect.Value{s.Receiver, reflect.ValueOf(c.session), reflect.ValueOf(msgPack.DataPtr)}
|
||||
s.Method.Func.Call(args)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) onSend(msg []byte) (err error) {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
err = ErrBrokenPipe
|
||||
}
|
||||
}()
|
||||
|
||||
if !c.isConnected.Load() {
|
||||
return ErrBrokenPipe
|
||||
}
|
||||
|
||||
if len(c.send) >= maxSendBuffer {
|
||||
return ErrBufferPoolExceed
|
||||
}
|
||||
|
||||
if len(msg) > maxFrameMessageLen {
|
||||
return
|
||||
}
|
||||
|
||||
c.send <- msg
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) Connect(requestHeader http.Header) error {
|
||||
c.requestHeader = requestHeader
|
||||
|
||||
err := c.connect()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go c.ping()
|
||||
go c.writeLoop()
|
||||
go c.reconnect()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) Send(router string, data any) error {
|
||||
rb, err := c.routerCodec.Marshal(router, data, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("service: %v", err)
|
||||
}
|
||||
return c.onSend(rb)
|
||||
}
|
||||
|
||||
func (c *client) OnReceiveError(f func(error)) {
|
||||
c.onReceiveErr = f
|
||||
}
|
||||
|
||||
func (c *client) OnReconnected(f func(error)) {
|
||||
c.onReconnect = f
|
||||
}
|
||||
|
||||
func (c *client) Close() {
|
||||
close(c.send)
|
||||
c.pingTicker.Stop()
|
||||
c.checkConnTicker.Stop()
|
||||
_ = c.session.Conn.Close()
|
||||
}
|
28
pkg/websocket/client/service/method.go
Normal file
28
pkg/websocket/client/service/method.go
Normal file
@ -0,0 +1,28 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
)
|
||||
|
||||
var (
|
||||
typeOfBytes = reflect.TypeOf(([]byte)(nil))
|
||||
typeOfSession = reflect.TypeOf(NewSession(nil))
|
||||
)
|
||||
|
||||
// 方法检测
|
||||
func isHandlerMethod(method reflect.Method) bool {
|
||||
mt := method.Type
|
||||
if method.PkgPath != "" {
|
||||
return false
|
||||
}
|
||||
if mt.NumIn() != 3 {
|
||||
return false
|
||||
}
|
||||
if t1 := mt.In(1); t1.Kind() != reflect.Ptr || t1 != typeOfSession {
|
||||
return false
|
||||
}
|
||||
if mt.In(2).Kind() != reflect.Ptr && mt.In(2) != typeOfBytes {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
19
pkg/websocket/client/service/peer.go
Normal file
19
pkg/websocket/client/service/peer.go
Normal file
@ -0,0 +1,19 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"github.com/gorilla/websocket"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
Conn *websocket.Conn
|
||||
Time time.Time
|
||||
}
|
||||
|
||||
func NewSession(conn *websocket.Conn) *Session {
|
||||
s := &Session{
|
||||
Conn: conn,
|
||||
Time: time.Now(),
|
||||
}
|
||||
return s
|
||||
}
|
57
pkg/websocket/client/service/receiver.go
Normal file
57
pkg/websocket/client/service/receiver.go
Normal file
@ -0,0 +1,57 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"git.bvbej.com/bvbej/base-golang/pkg/websocket/codec"
|
||||
"git.bvbej.com/bvbej/base-golang/pkg/websocket/util"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
Receiver reflect.Value // 值
|
||||
Method reflect.Method // 方法
|
||||
Type reflect.Type // 类型
|
||||
IsRawArg bool // 数据是否需要序列化
|
||||
}
|
||||
|
||||
func RegisterHandler(components ...any) map[string]*Handler {
|
||||
methods := make(map[string]*Handler)
|
||||
for _, component := range components {
|
||||
rt := reflect.TypeOf(component)
|
||||
rv := reflect.ValueOf(component)
|
||||
|
||||
typeName := reflect.Indirect(rv).Type().Name()
|
||||
if typeName == "" {
|
||||
continue
|
||||
}
|
||||
if !util.IsExported(typeName) {
|
||||
continue
|
||||
}
|
||||
|
||||
for m := 0; m < rt.NumMethod(); m++ {
|
||||
method := rt.Method(m)
|
||||
mt := method.Type
|
||||
mn := method.Name
|
||||
if isHandlerMethod(method) {
|
||||
raw := false
|
||||
if mt.In(2) == typeOfBytes {
|
||||
raw = true
|
||||
}
|
||||
router := fmt.Sprintf("%s.%s", strings.ToLower(typeName), strings.ToLower(mn))
|
||||
methods[router] = &Handler{
|
||||
Receiver: rv,
|
||||
Method: method,
|
||||
Type: mt.In(2),
|
||||
IsRawArg: raw,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for router, handler := range methods {
|
||||
codec.RegisterMessage(router, handler.Type)
|
||||
}
|
||||
|
||||
return methods
|
||||
}
|
26
pkg/websocket/codec/codec.go
Normal file
26
pkg/websocket/codec/codec.go
Normal file
@ -0,0 +1,26 @@
|
||||
package codec
|
||||
|
||||
type Codec interface {
|
||||
Marshal(router string, dataPtr any, err error) ([]byte, error)
|
||||
Unmarshal([]byte) (int, *MsgPack, error)
|
||||
ToString(any) string
|
||||
}
|
||||
|
||||
var codecsList = make(map[string]Codec)
|
||||
|
||||
func RegisterCodec(name string, codec Codec) {
|
||||
if codec == nil {
|
||||
panic("codec: Register provide is nil")
|
||||
}
|
||||
if _, dup := codecsList[name]; dup {
|
||||
panic("codec: Register called twice for provide " + name)
|
||||
}
|
||||
codecsList[name] = codec
|
||||
}
|
||||
|
||||
func GetCodec(name string) Codec {
|
||||
if v, ok := codecsList[name]; ok {
|
||||
return v
|
||||
}
|
||||
return nil
|
||||
}
|
97
pkg/websocket/codec/json/json.go
Normal file
97
pkg/websocket/codec/json/json.go
Normal file
@ -0,0 +1,97 @@
|
||||
package json
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"git.bvbej.com/bvbej/base-golang/pkg/websocket/codec"
|
||||
)
|
||||
|
||||
type jsonCodec struct{}
|
||||
|
||||
type jsonReq struct {
|
||||
Router string `json:"router"`
|
||||
Data []byte `json:"data"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type jsonAck struct {
|
||||
Router string `json:"router"`
|
||||
Data string `json:"data"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
codec.RegisterCodec("json_codec", new(jsonCodec))
|
||||
}
|
||||
|
||||
func (*jsonCodec) Marshal(router string, dataPtr any, retErr error) ([]byte, error) {
|
||||
if router == "" {
|
||||
return nil, fmt.Errorf("marshal: router is empty")
|
||||
}
|
||||
|
||||
if dataPtr == nil && retErr == nil {
|
||||
return nil, fmt.Errorf("marshal data in package is nil. router:%s dt:%T", router, dataPtr)
|
||||
}
|
||||
|
||||
ack := &jsonAck{
|
||||
Router: router,
|
||||
}
|
||||
|
||||
if dataPtr != nil {
|
||||
data, err := json.Marshal(dataPtr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal json marshal failed. routerr:%s dt:%T err:%v", router, dataPtr, err)
|
||||
}
|
||||
ack.Data = base64.StdEncoding.EncodeToString(data)
|
||||
}
|
||||
|
||||
if retErr != nil {
|
||||
ack.Error = retErr.Error()
|
||||
}
|
||||
|
||||
ackByte, err := json.Marshal(ack)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal json marshal failed. routerr:%s dt:%T err:%v", router, dataPtr, err)
|
||||
}
|
||||
return ackByte, nil
|
||||
}
|
||||
|
||||
func (*jsonCodec) Unmarshal(msg []byte) (int, *codec.MsgPack, error) {
|
||||
var i = len(msg)
|
||||
|
||||
req := &jsonReq{}
|
||||
err := json.Unmarshal(msg, req)
|
||||
if err != nil {
|
||||
return i, nil, errors.New("unmarshal split message id failed")
|
||||
}
|
||||
|
||||
var router = req.Router
|
||||
msgPack := &codec.MsgPack{Router: router}
|
||||
dt := codec.GetMessage(router)
|
||||
if dt == nil {
|
||||
return i, nil, fmt.Errorf("unmarshal message not registed. router:%s", router)
|
||||
}
|
||||
|
||||
if req.Data != nil {
|
||||
err = json.Unmarshal(req.Data, dt)
|
||||
if err != nil {
|
||||
return i, nil, fmt.Errorf("unmarshal json unmarshal failed. dt:%T msg:%s err:%v", dt, string(msg), err)
|
||||
}
|
||||
}
|
||||
msgPack.DataPtr = dt
|
||||
if req.Error != "" {
|
||||
msgPack.Err = errors.New(req.Error)
|
||||
}
|
||||
|
||||
return i, msgPack, nil
|
||||
}
|
||||
|
||||
func (*jsonCodec) ToString(data any) string {
|
||||
ab, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("invalid type %T", data)
|
||||
}
|
||||
return string(ab)
|
||||
}
|
43
pkg/websocket/codec/meta.go
Normal file
43
pkg/websocket/codec/meta.go
Normal file
@ -0,0 +1,43 @@
|
||||
package codec
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type MsgPack struct {
|
||||
Router any
|
||||
DataPtr any
|
||||
Err error
|
||||
}
|
||||
|
||||
var modelMap = make(map[any]reflect.Type)
|
||||
var modelMapLock sync.RWMutex
|
||||
|
||||
func RegisterMessage(router any, datePtr any) {
|
||||
modelMapLock.Lock()
|
||||
defer modelMapLock.Unlock()
|
||||
if _, ok := modelMap[router]; ok {
|
||||
fmt.Println(fmt.Sprintf("codec: repeat registration. router:%s ", router))
|
||||
return
|
||||
}
|
||||
if t, ok := datePtr.(reflect.Type); ok {
|
||||
modelMap[router] = t.Elem()
|
||||
} else {
|
||||
t := reflect.TypeOf(datePtr)
|
||||
if t.Kind() != reflect.Ptr {
|
||||
panic(fmt.Errorf("codec: cannot use non-ptr message struct `%s`", t))
|
||||
}
|
||||
modelMap[router] = t.Elem()
|
||||
}
|
||||
}
|
||||
|
||||
func GetMessage(router any) any {
|
||||
modelMapLock.RLock()
|
||||
defer modelMapLock.RUnlock()
|
||||
if ptr, ok := modelMap[router]; ok {
|
||||
return reflect.New(ptr).Interface()
|
||||
}
|
||||
return nil
|
||||
}
|
85
pkg/websocket/codec/protobuf/protobuf.go
Normal file
85
pkg/websocket/codec/protobuf/protobuf.go
Normal file
@ -0,0 +1,85 @@
|
||||
package protobuf
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"git.bvbej.com/bvbej/base-golang/pkg/websocket/codec"
|
||||
"git.bvbej.com/bvbej/base-golang/pkg/websocket/codec/protobuf/protocol"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
type protobufCodec struct{}
|
||||
|
||||
func init() {
|
||||
codec.RegisterCodec("protobuf_codec", new(protobufCodec))
|
||||
}
|
||||
|
||||
func (*protobufCodec) Marshal(router string, dataPtr any, retErr error) ([]byte, error) {
|
||||
if router == "" {
|
||||
return nil, fmt.Errorf("marshal: empty router")
|
||||
}
|
||||
if dataPtr == nil && retErr == nil {
|
||||
return nil, fmt.Errorf("marshal: empty data")
|
||||
}
|
||||
ack := &protocol.TransPack{
|
||||
Router: router,
|
||||
}
|
||||
if dataPtr != nil {
|
||||
pbMsg, ok := dataPtr.(proto.Message)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("marshal: dataptr only support proto.Message type. router:%s dt:%T ",
|
||||
router, dataPtr)
|
||||
}
|
||||
data, err := proto.Marshal(pbMsg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal:protocol buffer marshal failed. router:%s dt:%T err:%v",
|
||||
router, dataPtr, err)
|
||||
}
|
||||
ack.Data = data
|
||||
} else {
|
||||
ack.Error = retErr.Error()
|
||||
}
|
||||
ackByte, err := proto.Marshal(ack)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal:protocol buffer marshal failed. router:%s dt:%T err:%v",
|
||||
router, ack, err)
|
||||
}
|
||||
return ackByte, nil
|
||||
}
|
||||
|
||||
func (*protobufCodec) Unmarshal(msg []byte) (int, *codec.MsgPack, error) {
|
||||
var l = len(msg)
|
||||
req := &protocol.TransPack{}
|
||||
err := proto.Unmarshal(msg, req)
|
||||
if err != nil {
|
||||
return l, nil, errors.New("unmarshal split message id failed.")
|
||||
}
|
||||
var router = req.Router
|
||||
msgPack := &codec.MsgPack{Router: router}
|
||||
dt := codec.GetMessage(router)
|
||||
if dt == nil {
|
||||
return l, nil, fmt.Errorf("unmarshal message not registed. router:%s",
|
||||
router)
|
||||
}
|
||||
if req.Data != nil {
|
||||
err = proto.Unmarshal(req.Data, dt.(proto.Message))
|
||||
if err != nil {
|
||||
return l, nil, fmt.Errorf("unmarshal failed. router:%s", router)
|
||||
}
|
||||
}
|
||||
msgPack.DataPtr = dt
|
||||
if req.Error != "" {
|
||||
msgPack.Err = errors.New(req.Error)
|
||||
}
|
||||
return l, msgPack, nil
|
||||
}
|
||||
|
||||
func (*protobufCodec) ToString(data any) string {
|
||||
pbMsg, ok := data.(proto.Message)
|
||||
if !ok {
|
||||
return fmt.Sprintf("invalid type %T", data)
|
||||
}
|
||||
marshal, _ := proto.Marshal(pbMsg)
|
||||
return string(marshal)
|
||||
}
|
226
pkg/websocket/codec/protobuf/protocol/base.pb.go
Normal file
226
pkg/websocket/codec/protobuf/protocol/base.pb.go
Normal file
@ -0,0 +1,226 @@
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.26.0
|
||||
// protoc v3.21.1
|
||||
// source: base.proto
|
||||
|
||||
package protocol
|
||||
|
||||
import (
|
||||
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||
reflect "reflect"
|
||||
sync "sync"
|
||||
)
|
||||
|
||||
const (
|
||||
// Verify that this generated code is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
|
||||
// Verify that runtime/protoimpl is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||
)
|
||||
|
||||
// 通信包装
|
||||
type TransPack struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
Router string `protobuf:"bytes,1,opt,name=router,proto3" json:"router,omitempty"`
|
||||
Data []byte `protobuf:"bytes,2,opt,name=data,proto3" json:"data,omitempty"`
|
||||
Error string `protobuf:"bytes,3,opt,name=error,proto3" json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (x *TransPack) Reset() {
|
||||
*x = TransPack{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_base_proto_msgTypes[0]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *TransPack) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*TransPack) ProtoMessage() {}
|
||||
|
||||
func (x *TransPack) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_base_proto_msgTypes[0]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use TransPack.ProtoReflect.Descriptor instead.
|
||||
func (*TransPack) Descriptor() ([]byte, []int) {
|
||||
return file_base_proto_rawDescGZIP(), []int{0}
|
||||
}
|
||||
|
||||
func (x *TransPack) GetRouter() string {
|
||||
if x != nil {
|
||||
return x.Router
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *TransPack) GetData() []byte {
|
||||
if x != nil {
|
||||
return x.Data
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *TransPack) GetError() string {
|
||||
if x != nil {
|
||||
return x.Error
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// 连接检测
|
||||
type PingPang struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
Timestamp int64 `protobuf:"varint,1,opt,name=timestamp,proto3" json:"timestamp,omitempty"` //时间戳
|
||||
}
|
||||
|
||||
func (x *PingPang) Reset() {
|
||||
*x = PingPang{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_base_proto_msgTypes[1]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *PingPang) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*PingPang) ProtoMessage() {}
|
||||
|
||||
func (x *PingPang) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_base_proto_msgTypes[1]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use PingPang.ProtoReflect.Descriptor instead.
|
||||
func (*PingPang) Descriptor() ([]byte, []int) {
|
||||
return file_base_proto_rawDescGZIP(), []int{1}
|
||||
}
|
||||
|
||||
func (x *PingPang) GetTimestamp() int64 {
|
||||
if x != nil {
|
||||
return x.Timestamp
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
var File_base_proto protoreflect.FileDescriptor
|
||||
|
||||
var file_base_proto_rawDesc = []byte{
|
||||
0x0a, 0x0a, 0x62, 0x61, 0x73, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x08, 0x70, 0x72,
|
||||
0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x22, 0x4d, 0x0a, 0x09, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x50,
|
||||
0x61, 0x63, 0x6b, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x72, 0x18, 0x01, 0x20,
|
||||
0x01, 0x28, 0x09, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x72, 0x12, 0x12, 0x0a, 0x04, 0x64,
|
||||
0x61, 0x74, 0x61, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x64, 0x61, 0x74, 0x61, 0x12,
|
||||
0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05,
|
||||
0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x28, 0x0a, 0x08, 0x50, 0x69, 0x6e, 0x67, 0x50, 0x61, 0x6e,
|
||||
0x67, 0x12, 0x1c, 0x0a, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x01,
|
||||
0x20, 0x01, 0x28, 0x03, 0x52, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x42,
|
||||
0x27, 0x5a, 0x25, 0x70, 0x6b, 0x67, 0x2f, 0x77, 0x65, 0x62, 0x73, 0x6f, 0x63, 0x6b, 0x65, 0x74,
|
||||
0x2f, 0x63, 0x6f, 0x64, 0x65, 0x63, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f,
|
||||
0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
}
|
||||
|
||||
var (
|
||||
file_base_proto_rawDescOnce sync.Once
|
||||
file_base_proto_rawDescData = file_base_proto_rawDesc
|
||||
)
|
||||
|
||||
func file_base_proto_rawDescGZIP() []byte {
|
||||
file_base_proto_rawDescOnce.Do(func() {
|
||||
file_base_proto_rawDescData = protoimpl.X.CompressGZIP(file_base_proto_rawDescData)
|
||||
})
|
||||
return file_base_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_base_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
|
||||
var file_base_proto_goTypes = []any{
|
||||
(*TransPack)(nil), // 0: protocol.TransPack
|
||||
(*PingPang)(nil), // 1: protocol.PingPang
|
||||
}
|
||||
var file_base_proto_depIdxs = []int32{
|
||||
0, // [0:0] is the sub-list for method output_type
|
||||
0, // [0:0] is the sub-list for method input_type
|
||||
0, // [0:0] is the sub-list for extension type_name
|
||||
0, // [0:0] is the sub-list for extension extendee
|
||||
0, // [0:0] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_base_proto_init() }
|
||||
func file_base_proto_init() {
|
||||
if File_base_proto != nil {
|
||||
return
|
||||
}
|
||||
if !protoimpl.UnsafeEnabled {
|
||||
file_base_proto_msgTypes[0].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*TransPack); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_base_proto_msgTypes[1].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*PingPang); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
type x struct{}
|
||||
out := protoimpl.TypeBuilder{
|
||||
File: protoimpl.DescBuilder{
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: file_base_proto_rawDesc,
|
||||
NumEnums: 0,
|
||||
NumMessages: 2,
|
||||
NumExtensions: 0,
|
||||
NumServices: 0,
|
||||
},
|
||||
GoTypes: file_base_proto_goTypes,
|
||||
DependencyIndexes: file_base_proto_depIdxs,
|
||||
MessageInfos: file_base_proto_msgTypes,
|
||||
}.Build()
|
||||
File_base_proto = out.File
|
||||
file_base_proto_rawDesc = nil
|
||||
file_base_proto_goTypes = nil
|
||||
file_base_proto_depIdxs = nil
|
||||
}
|
15
pkg/websocket/codec/protobuf/protocol/base.proto
Normal file
15
pkg/websocket/codec/protobuf/protocol/base.proto
Normal file
@ -0,0 +1,15 @@
|
||||
syntax = "proto3";
|
||||
package protocol;
|
||||
option go_package = "pkg/websocket/codec/protobuf/protocol";
|
||||
|
||||
//通信包装
|
||||
message TransPack {
|
||||
string router = 1;
|
||||
bytes data = 2;
|
||||
string error = 3;
|
||||
}
|
||||
|
||||
//连接检测
|
||||
message PingPang {
|
||||
int64 timestamp = 1; //时间戳
|
||||
}
|
6
pkg/websocket/peer/callback.go
Normal file
6
pkg/websocket/peer/callback.go
Normal file
@ -0,0 +1,6 @@
|
||||
package peer
|
||||
|
||||
type ConnectionCallBack interface {
|
||||
OnClosed(*Session)
|
||||
OnReceive(*Session, []byte) error
|
||||
}
|
98
pkg/websocket/peer/connect/acceptor.go
Normal file
98
pkg/websocket/peer/connect/acceptor.go
Normal file
@ -0,0 +1,98 @@
|
||||
package connect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"git.bvbej.com/bvbej/base-golang/pkg/mux"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"git.bvbej.com/bvbej/base-golang/pkg/websocket/peer"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
var _ WsAcceptor = (*wsAcceptor)(nil)
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
|
||||
type WsAcceptor interface {
|
||||
Start(addr string) error
|
||||
Stop()
|
||||
GinHandle(ctx *gin.Context)
|
||||
HandlerFunc() mux.HandlerFunc
|
||||
}
|
||||
|
||||
type wsAcceptor struct {
|
||||
server *http.Server
|
||||
sessMgr *peer.SessionManager
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
func NewWsAcceptor(sessMgr *peer.SessionManager, loggers *zap.Logger) WsAcceptor {
|
||||
return &wsAcceptor{
|
||||
sessMgr: sessMgr,
|
||||
logger: loggers,
|
||||
}
|
||||
}
|
||||
|
||||
func (ws *wsAcceptor) Start(addr string) error {
|
||||
urlObj, err := url.Parse(addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("websocket urlparse failed. url(%s) %v", addr, err)
|
||||
}
|
||||
if urlObj.Path == "" {
|
||||
return fmt.Errorf("websocket start failed. expect path in url to listen addr:%s", addr)
|
||||
}
|
||||
|
||||
http.HandleFunc(urlObj.Path, func(w http.ResponseWriter, r *http.Request) {
|
||||
c, upgradeErr := upgrader.Upgrade(w, r, nil)
|
||||
if upgradeErr != nil {
|
||||
ws.logger.Sugar().Errorf("upgrade http failed: %s", upgradeErr)
|
||||
return
|
||||
}
|
||||
ws.sessMgr.Register <- peer.NewSession(NewConnection(c, ws.sessMgr))
|
||||
})
|
||||
|
||||
ws.server = &http.Server{Addr: urlObj.Host}
|
||||
err = ws.server.ListenAndServe()
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
return fmt.Errorf("websocket ListenAndServe addr:%s failed:%v", addr, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ws *wsAcceptor) Stop() {
|
||||
ws.sessMgr.CloseAllSession()
|
||||
|
||||
if ws.server != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer cancel()
|
||||
if err := ws.server.Shutdown(ctx); err != nil {
|
||||
ws.logger.Sugar().Errorf("server shutdown err:[%s]", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ws *wsAcceptor) GinHandle(ctx *gin.Context) {
|
||||
c, upgradeErr := upgrader.Upgrade(ctx.Writer, ctx.Request, nil)
|
||||
if upgradeErr != nil {
|
||||
ws.logger.Sugar().Errorf("upgrade http failed: %s", upgradeErr)
|
||||
return
|
||||
}
|
||||
ws.sessMgr.Register <- peer.NewSession(NewConnection(c, ws.sessMgr))
|
||||
}
|
||||
|
||||
func (ws *wsAcceptor) HandlerFunc() mux.HandlerFunc {
|
||||
return func(c mux.Context) {
|
||||
ws.GinHandle(c.Context())
|
||||
}
|
||||
}
|
134
pkg/websocket/peer/connect/connection.go
Normal file
134
pkg/websocket/peer/connect/connection.go
Normal file
@ -0,0 +1,134 @@
|
||||
package connect
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"git.bvbej.com/bvbej/base-golang/pkg/websocket/peer"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
const (
|
||||
writeWait = 20 * time.Second
|
||||
pongWait = 60 * time.Second
|
||||
pingPeriod = (pongWait * 9) / 10
|
||||
maxFrameMessageLen = 16 * 1024 //4 * 4096
|
||||
maxSendBuffer = 16
|
||||
)
|
||||
|
||||
var (
|
||||
ErrBrokenPipe = errors.New("send to broken pipe")
|
||||
ErrBufferPoolExceed = errors.New("send buffer exceed")
|
||||
)
|
||||
|
||||
type wsConnection struct {
|
||||
peer.ConnectionIdentify
|
||||
pm *peer.SessionManager
|
||||
conn *websocket.Conn
|
||||
send chan []byte
|
||||
running bool
|
||||
}
|
||||
|
||||
func NewConnection(conn *websocket.Conn, p *peer.SessionManager) *wsConnection {
|
||||
wsc := &wsConnection{
|
||||
conn: conn,
|
||||
pm: p,
|
||||
send: make(chan []byte, maxSendBuffer),
|
||||
running: true,
|
||||
}
|
||||
|
||||
go wsc.acceptLoop()
|
||||
go wsc.sendLoop()
|
||||
|
||||
return wsc
|
||||
}
|
||||
|
||||
func (ws *wsConnection) Peer() *peer.SessionManager {
|
||||
return ws.pm
|
||||
}
|
||||
|
||||
func (ws *wsConnection) Raw() any {
|
||||
if ws.conn == nil {
|
||||
return nil
|
||||
}
|
||||
return ws.conn
|
||||
}
|
||||
|
||||
func (ws *wsConnection) RemoteAddr() string {
|
||||
if ws.conn == nil {
|
||||
return ""
|
||||
}
|
||||
return ws.conn.RemoteAddr().String()
|
||||
}
|
||||
|
||||
func (ws *wsConnection) Close() {
|
||||
_ = ws.conn.Close()
|
||||
ws.running = false
|
||||
}
|
||||
|
||||
func (ws *wsConnection) Send(msg []byte) (err error) {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
err = ErrBrokenPipe
|
||||
}
|
||||
}()
|
||||
if !ws.running {
|
||||
return ErrBrokenPipe
|
||||
}
|
||||
if len(ws.send) >= maxSendBuffer {
|
||||
return ErrBufferPoolExceed
|
||||
}
|
||||
if len(msg) > maxFrameMessageLen {
|
||||
return
|
||||
}
|
||||
ws.send <- msg
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ws *wsConnection) acceptLoop() {
|
||||
defer func() {
|
||||
ws.pm.Unregister <- ws.ID()
|
||||
_ = ws.conn.Close()
|
||||
ws.running = false
|
||||
}()
|
||||
_ = ws.conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
ws.conn.SetPongHandler(func(string) error {
|
||||
_ = ws.conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
return nil
|
||||
})
|
||||
for ws.conn != nil {
|
||||
_, data, err := ws.conn.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
ws.pm.ProcessMessage(ws.ID(), data)
|
||||
}
|
||||
}
|
||||
|
||||
func (ws *wsConnection) sendLoop() {
|
||||
ticker := time.NewTicker(pingPeriod)
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
_ = ws.conn.Close()
|
||||
ws.running = false
|
||||
close(ws.send)
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case msg := <-ws.send:
|
||||
_ = ws.conn.SetWriteDeadline(time.Now().Add(writeWait))
|
||||
if err := ws.conn.WriteMessage(websocket.BinaryMessage, msg); err != nil {
|
||||
return
|
||||
}
|
||||
case <-ticker.C:
|
||||
_ = ws.conn.SetWriteDeadline(time.Now().Add(writeWait))
|
||||
if err := ws.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ws *wsConnection) IsClosed() bool {
|
||||
return !ws.running
|
||||
}
|
23
pkg/websocket/peer/connection.go
Normal file
23
pkg/websocket/peer/connection.go
Normal file
@ -0,0 +1,23 @@
|
||||
package peer
|
||||
|
||||
type Connection interface {
|
||||
Raw() any
|
||||
Peer() *SessionManager
|
||||
Send(msg []byte) error
|
||||
Close()
|
||||
ID() int64
|
||||
RemoteAddr() string
|
||||
IsClosed() bool
|
||||
}
|
||||
|
||||
type ConnectionIdentify struct {
|
||||
id int64
|
||||
}
|
||||
|
||||
func (ci *ConnectionIdentify) ID() int64 {
|
||||
return ci.id
|
||||
}
|
||||
|
||||
func (ci *ConnectionIdentify) SetID(id int64) {
|
||||
ci.id = id
|
||||
}
|
16
pkg/websocket/peer/session.go
Normal file
16
pkg/websocket/peer/session.go
Normal file
@ -0,0 +1,16 @@
|
||||
package peer
|
||||
|
||||
import "time"
|
||||
|
||||
type Session struct {
|
||||
Conn Connection
|
||||
Time time.Time
|
||||
}
|
||||
|
||||
func NewSession(conn Connection) *Session {
|
||||
s := &Session{
|
||||
Conn: conn,
|
||||
Time: time.Now(),
|
||||
}
|
||||
return s
|
||||
}
|
119
pkg/websocket/peer/session_manager.go
Normal file
119
pkg/websocket/peer/session_manager.go
Normal file
@ -0,0 +1,119 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type SessionManager struct {
|
||||
sessionList sync.Map // 使用Id关联会话
|
||||
|
||||
connIDGen int64 // 记录已经生成的会话ID流水号
|
||||
|
||||
count int64 // 记录当前在使用的会话数量
|
||||
|
||||
callback ConnectionCallBack
|
||||
Register chan *Session
|
||||
Unregister chan int64
|
||||
}
|
||||
|
||||
func (mgr *SessionManager) SetIDBase(base int64) {
|
||||
|
||||
atomic.StoreInt64(&mgr.connIDGen, base)
|
||||
}
|
||||
|
||||
func (mgr *SessionManager) Count() int {
|
||||
return int(atomic.LoadInt64(&mgr.count))
|
||||
}
|
||||
|
||||
func (mgr *SessionManager) Add(sess *Session) {
|
||||
|
||||
id := atomic.AddInt64(&mgr.connIDGen, 1)
|
||||
|
||||
atomic.AddInt64(&mgr.count, 1)
|
||||
|
||||
sess.Conn.(interface {
|
||||
SetID(int64)
|
||||
}).SetID(id)
|
||||
|
||||
mgr.sessionList.Store(id, sess)
|
||||
}
|
||||
|
||||
func (mgr *SessionManager) Close(id int64) {
|
||||
if v, ok := mgr.sessionList.Load(id); ok {
|
||||
if mgr.callback != nil {
|
||||
go mgr.callback.OnClosed(v.(*Session))
|
||||
}
|
||||
}
|
||||
mgr.sessionList.Delete(id)
|
||||
atomic.AddInt64(&mgr.count, -1)
|
||||
}
|
||||
|
||||
func (mgr *SessionManager) ProcessMessage(id int64, msg []byte) {
|
||||
if v, ok := mgr.sessionList.Load(id); ok {
|
||||
if mgr.callback != nil {
|
||||
go func() {
|
||||
err := mgr.callback.OnReceive(v.(*Session), msg)
|
||||
if err != nil {
|
||||
v.(*Session).Conn.Close()
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (mgr *SessionManager) run() {
|
||||
for {
|
||||
select {
|
||||
case client := <-mgr.Register:
|
||||
mgr.connIDGen++
|
||||
mgr.count++
|
||||
client.Conn.(interface {
|
||||
SetID(int64)
|
||||
}).SetID(mgr.connIDGen)
|
||||
mgr.sessionList.Store(mgr.connIDGen, client)
|
||||
case clientID := <-mgr.Unregister:
|
||||
if v, ok := mgr.sessionList.Load(clientID); ok {
|
||||
if mgr.callback != nil {
|
||||
go mgr.callback.OnClosed(v.(*Session))
|
||||
}
|
||||
}
|
||||
mgr.sessionList.Delete(clientID)
|
||||
mgr.count--
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (mgr *SessionManager) GetSession(id int64) *Session {
|
||||
if v, ok := mgr.sessionList.Load(id); ok {
|
||||
return v.(*Session)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mgr *SessionManager) VisitSession(callback func(*Session) bool) {
|
||||
mgr.sessionList.Range(func(key, value any) bool {
|
||||
return callback(value.(*Session))
|
||||
})
|
||||
}
|
||||
|
||||
func (mgr *SessionManager) CloseAllSession() {
|
||||
mgr.VisitSession(func(sess *Session) bool {
|
||||
sess.Conn.Close()
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func (mgr *SessionManager) SessionCount() int64 {
|
||||
return atomic.LoadInt64(&mgr.count)
|
||||
}
|
||||
|
||||
func NewSessionMgr(callback ConnectionCallBack) *SessionManager {
|
||||
s := &SessionManager{
|
||||
callback: callback,
|
||||
Register: make(chan *Session),
|
||||
Unregister: make(chan int64),
|
||||
}
|
||||
go s.run()
|
||||
return s
|
||||
}
|
27
pkg/websocket/service/component.go
Normal file
27
pkg/websocket/service/component.go
Normal file
@ -0,0 +1,27 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"git.bvbej.com/bvbej/base-golang/pkg/websocket/peer"
|
||||
)
|
||||
|
||||
type Component interface {
|
||||
Init()
|
||||
OnSessionClose(*peer.Session) bool
|
||||
OnRequestFinished(*peer.Session, string, any, string, time.Duration)
|
||||
}
|
||||
|
||||
type ComponentBase struct{}
|
||||
|
||||
func (c *ComponentBase) Init() {
|
||||
|
||||
}
|
||||
|
||||
func (c *ComponentBase) OnSessionClose(session *peer.Session) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *ComponentBase) OnRequestFinished(session *peer.Session, router string, req any, errMsg string, delta time.Duration) {
|
||||
|
||||
}
|
33
pkg/websocket/service/method.go
Normal file
33
pkg/websocket/service/method.go
Normal file
@ -0,0 +1,33 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"git.bvbej.com/bvbej/base-golang/pkg/websocket/peer"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
var (
|
||||
typeOfError = reflect.TypeOf((*error)(nil)).Elem()
|
||||
typeOfBytes = reflect.TypeOf(([]byte)(nil))
|
||||
typeOfSession = reflect.TypeOf(peer.NewSession(nil))
|
||||
)
|
||||
|
||||
// 方法检测
|
||||
func isHandlerMethod(method reflect.Method) bool {
|
||||
mt := method.Type
|
||||
if method.PkgPath != "" {
|
||||
return false
|
||||
}
|
||||
if mt.NumIn() != 3 {
|
||||
return false
|
||||
}
|
||||
if mt.NumOut() != 2 {
|
||||
return false
|
||||
}
|
||||
if t1 := mt.In(1); t1.Kind() != reflect.Ptr || t1 != typeOfSession {
|
||||
return false
|
||||
}
|
||||
if (mt.In(2).Kind() != reflect.Ptr && mt.In(2) != typeOfBytes) || mt.Out(1) != typeOfError || mt.Out(0).Kind() != reflect.Ptr {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
98
pkg/websocket/service/service.go
Normal file
98
pkg/websocket/service/service.go
Normal file
@ -0,0 +1,98 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"git.bvbej.com/bvbej/base-golang/pkg/websocket/util"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"git.bvbej.com/bvbej/base-golang/pkg/websocket/peer"
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
Receiver reflect.Value // 值
|
||||
Method reflect.Method // 方法
|
||||
Type reflect.Type // 类型
|
||||
IsRawArg bool // 数据是否需要序列化
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
Name string // 服务名
|
||||
Type reflect.Type // 服务类型
|
||||
Receiver reflect.Value // 服务值
|
||||
Handlers map[string]*Handler // 注册的方法列表
|
||||
Component Component
|
||||
}
|
||||
|
||||
func NewService(comp Component) *Service {
|
||||
s := &Service{
|
||||
Type: reflect.TypeOf(comp),
|
||||
Receiver: reflect.ValueOf(comp),
|
||||
Component: comp,
|
||||
}
|
||||
s.Name = strings.ToLower(reflect.Indirect(s.Receiver).Type().Name())
|
||||
//调用初始化方法
|
||||
s.Component.Init()
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Service) SuitableHandlerMethods(typ reflect.Type) map[string]*Handler {
|
||||
methods := make(map[string]*Handler)
|
||||
for m := 0; m < typ.NumMethod(); m++ {
|
||||
method := typ.Method(m)
|
||||
mt := method.Type
|
||||
mn := method.Name
|
||||
if isHandlerMethod(method) {
|
||||
raw := false
|
||||
if mt.In(2) == typeOfBytes {
|
||||
raw = true
|
||||
}
|
||||
mn = strings.ToLower(mn)
|
||||
methods[mn] = &Handler{Method: method, Type: mt.In(2), IsRawArg: raw}
|
||||
}
|
||||
}
|
||||
return methods
|
||||
}
|
||||
|
||||
func (s *Service) ExtractHandler() error {
|
||||
typeName := reflect.Indirect(s.Receiver).Type().Name()
|
||||
if typeName == "" {
|
||||
return errors.New("no service name for type " + s.Type.String())
|
||||
}
|
||||
if !util.IsExported(typeName) {
|
||||
return errors.New("type " + typeName + " is not exported")
|
||||
}
|
||||
s.Handlers = s.SuitableHandlerMethods(s.Type)
|
||||
for i := range s.Handlers {
|
||||
s.Handlers[i].Receiver = s.Receiver
|
||||
}
|
||||
if reflect.Indirect(s.Receiver).NumField() > 0 {
|
||||
filedNum := reflect.Indirect(s.Receiver).NumField()
|
||||
for i := 0; i < filedNum; i++ {
|
||||
ty := reflect.Indirect(s.Receiver).Field(i).Type().Name()
|
||||
if ty == ChildName {
|
||||
h := s.SuitableHandlerMethods(reflect.Indirect(s.Receiver).Field(i).Elem().Type())
|
||||
for ih, v := range h {
|
||||
s.Handlers[ih] = v
|
||||
s.Handlers[ih].Receiver = reflect.Indirect(s.Receiver).Field(i).Elem()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(s.Handlers) == 0 {
|
||||
str := "service: "
|
||||
method := s.SuitableHandlerMethods(reflect.PtrTo(s.Type))
|
||||
if len(method) != 0 {
|
||||
str = "type " + s.Name + " has no exported methods of suitable type (hint: pass a pointer to value of that type)"
|
||||
} else {
|
||||
str = "type " + s.Name + " has no exported methods of suitable type"
|
||||
}
|
||||
return errors.New(str)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) OnSessionClose(session *peer.Session) bool {
|
||||
return s.Component.OnSessionClose(session)
|
||||
}
|
59
pkg/websocket/service/service_manager.go
Normal file
59
pkg/websocket/service/service_manager.go
Normal file
@ -0,0 +1,59 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"git.bvbej.com/bvbej/base-golang/pkg/websocket/codec"
|
||||
_ "git.bvbej.com/bvbej/base-golang/pkg/websocket/codec/json"
|
||||
"git.bvbej.com/bvbej/base-golang/pkg/websocket/peer"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
var (
|
||||
serviceLogger *zap.Logger
|
||||
|
||||
RegisteredServiceList = make(map[string]*Service) // all registered service
|
||||
RouterCodec codec.Codec
|
||||
)
|
||||
|
||||
const ChildName = "Hbase"
|
||||
|
||||
type Hbase any
|
||||
|
||||
func RegisterService(logger *zap.Logger, comp ...Component) {
|
||||
serviceLogger = logger
|
||||
for _, v := range comp {
|
||||
s := NewService(v)
|
||||
if _, ok := RegisteredServiceList[s.Name]; ok {
|
||||
serviceLogger.Sugar().Errorf("service: service already defined: %s", s.Name)
|
||||
}
|
||||
if err := s.ExtractHandler(); err != nil {
|
||||
serviceLogger.Sugar().Errorf("service: extract handler function failed: %v", err)
|
||||
}
|
||||
RegisteredServiceList[s.Name] = s
|
||||
for name, handler := range s.Handlers {
|
||||
router := fmt.Sprintf("%s.%s", s.Name, name)
|
||||
//注册消息 用于解码
|
||||
codec.RegisterMessage(router, handler.Type)
|
||||
serviceLogger.Sugar().Debugf("service: router %s param %s registed", router, handler.Type)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func SetCodec(name string) error {
|
||||
RouterCodec = codec.GetCodec(name)
|
||||
if RouterCodec == nil {
|
||||
return fmt.Errorf("service: codec %s not registered", name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func Send(session *peer.Session, router string, data any) error {
|
||||
if RouterCodec == nil {
|
||||
return fmt.Errorf("service: codec not set")
|
||||
}
|
||||
rb, err := RouterCodec.Marshal(router, data, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("service: %v", err)
|
||||
}
|
||||
return session.Conn.Send(rb)
|
||||
}
|
102
pkg/websocket/service/session_callback.go
Normal file
102
pkg/websocket/service/session_callback.go
Normal file
@ -0,0 +1,102 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.bvbej.com/bvbej/base-golang/pkg/websocket/peer"
|
||||
)
|
||||
|
||||
type callBackEntity struct{}
|
||||
|
||||
func GetSessionManager() *peer.SessionManager {
|
||||
return peer.NewSessionMgr(&callBackEntity{})
|
||||
}
|
||||
|
||||
func (cb *callBackEntity) OnClosed(session *peer.Session) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
fmt.Println(fmt.Sprintf("OnClosed: session:%d err:%v", session.Conn.ID(), err))
|
||||
}
|
||||
}()
|
||||
for _, v := range RegisteredServiceList {
|
||||
if ok := v.OnSessionClose(session); ok {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cb *callBackEntity) OnReceive(session *peer.Session, msg []byte) error {
|
||||
_, msgPack, err := RouterCodec.Unmarshal(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("onreceive: %v", err)
|
||||
}
|
||||
router, ok := msgPack.Router.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("onreceive: invalid router:%v", msgPack.Router)
|
||||
}
|
||||
routerArr := strings.Split(router, ".")
|
||||
if len(routerArr) != 2 {
|
||||
return fmt.Errorf("onreceive: invalid router:%s", msgPack.Router)
|
||||
}
|
||||
s, ok := RegisteredServiceList[routerArr[0]]
|
||||
if !ok {
|
||||
return fmt.Errorf("onreceive: function not registed router:%s err:%v", msgPack.Router, err)
|
||||
}
|
||||
h, ok := s.Handlers[routerArr[1]]
|
||||
if !ok {
|
||||
return fmt.Errorf("onreceive: function not registed router:%s err:%v", msgPack.Router, err)
|
||||
}
|
||||
t1 := time.Now()
|
||||
|
||||
var args = []reflect.Value{h.Receiver, reflect.ValueOf(session), reflect.ValueOf(msgPack.DataPtr)}
|
||||
var res any
|
||||
var rb []byte
|
||||
res, err = CallHandlerFunc(h.Method, args)
|
||||
if res != nil && !reflect.ValueOf(res).IsNil() {
|
||||
rb, err = RouterCodec.Marshal(router, res, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("service: %v", err)
|
||||
}
|
||||
err = session.Conn.Send(rb)
|
||||
if err != nil {
|
||||
serviceLogger.Sugar().Warnf("warn! service send msg failed router:%s err:%v", router, err)
|
||||
}
|
||||
} else {
|
||||
rb, err = RouterCodec.Marshal(router, nil, err)
|
||||
if err != nil {
|
||||
return fmt.Errorf("service: %v", err)
|
||||
}
|
||||
err = session.Conn.Send(rb)
|
||||
if err != nil {
|
||||
serviceLogger.Sugar().Warnf("warn! service send msg failed router:%s err:%v", router, err)
|
||||
}
|
||||
}
|
||||
var errs string
|
||||
if err != nil {
|
||||
errs = err.Error()
|
||||
}
|
||||
dt := time.Since(t1)
|
||||
go s.Component.OnRequestFinished(session, router, RouterCodec.ToString(msgPack.DataPtr), errs, dt)
|
||||
return nil
|
||||
}
|
||||
|
||||
func CallHandlerFunc(foo reflect.Method, args []reflect.Value) (retValue any, retErr error) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
fmt.Println(fmt.Sprintf("CallHandlerFunc: %v", err))
|
||||
retValue = nil
|
||||
retErr = fmt.Errorf("CallHandlerFunc: call method pkg:%s method:%s err:%v", foo.PkgPath, foo.Name, err)
|
||||
}
|
||||
}()
|
||||
if ret := foo.Func.Call(args); len(ret) > 0 {
|
||||
var err error = nil
|
||||
if r1 := ret[1].Interface(); r1 != nil {
|
||||
err = r1.(error)
|
||||
}
|
||||
return ret[0].Interface(), err
|
||||
}
|
||||
return nil, fmt.Errorf("CallHandlerFunc: call method pkg:%s method:%s", foo.PkgPath, foo.Name)
|
||||
}
|
11
pkg/websocket/util/util.go
Normal file
11
pkg/websocket/util/util.go
Normal file
@ -0,0 +1,11 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
func IsExported(name string) bool {
|
||||
w, _ := utf8.DecodeRuneInString(name)
|
||||
return unicode.IsUpper(w)
|
||||
}
|
Reference in New Issue
Block a user