// Package client implements a WebSocket client that encodes outgoing messages
// with a router-aware codec, dispatches incoming messages to registered
// handlers by router name, keeps the connection alive with pings, and redials
// after the connection drops.
package client

import (
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"reflect"
	"sync"
	"sync/atomic"
	"time"

	"gitea.bvbej.com/bvbej/base-golang/pkg/ticker"
	"gitea.bvbej.com/bvbej/base-golang/pkg/websocket/client/service"
	"gitea.bvbej.com/bvbej/base-golang/pkg/websocket/codec"
	_ "gitea.bvbej.com/bvbej/base-golang/pkg/websocket/codec/json"
	_ "gitea.bvbej.com/bvbej/base-golang/pkg/websocket/codec/protobuf"
	"gitea.bvbej.com/bvbej/base-golang/tool"
	"github.com/gorilla/websocket"
	"go.uber.org/zap"
)

const (
	// writeWait bounds how long a single frame write may take.
	writeWait = 20 * time.Second
	// pongWait is the read deadline; readLoop extends it whenever a pong arrives.
	pongWait = 60 * time.Second
	// reconnectWait is the period of the connection-check ticker used by reconnect.
	reconnectWait = 3 * time.Second
	// pingPeriod is shorter than pongWait so a pong can arrive before the
	// read deadline expires.
	pingPeriod = (pongWait * 9) / 10
	// maxFrameMessageLen is the largest encoded message Send accepts.
	maxFrameMessageLen = 16 * 1024
	// maxSendBuffer is the capacity of the outgoing message queue.
	maxSendBuffer = 32
)

var (
	_ Client = (*client)(nil)

	// ErrBrokenPipe is returned by Send when the client is not connected or
	// has been closed.
	ErrBrokenPipe = errors.New("send to broken pipe")
	// ErrBufferPoolExceed is returned by Send when the outgoing queue is full.
	ErrBufferPoolExceed = errors.New("send buffer exceed")
	// ErrMessageTooLarge is returned by Send when the encoded message is
	// longer than maxFrameMessageLen bytes.
	ErrMessageTooLarge = errors.New("send message too large")
)

// Client is a WebSocket client bound to one URL, codec and handler set.
type Client interface {
	readLoop()
	writeLoop()
	ping()
	reconnect()
	onReceive(msg []byte) error
	onSend(msg []byte) error
	connect() error

	// Send encodes data for router and queues it for writing.
	Send(router string, data any) error
	// Connect dials the server with requestHeader and starts the background
	// ping, write and reconnect loops.
	Connect(requestHeader http.Header) error
	// OnReceiveError registers a callback for errors raised while decoding or
	// dispatching incoming messages.
	OnReceiveError(f func(error))
	// OnReconnected registers a callback invoked after every reconnect
	// attempt with that attempt's error (nil on success).
	OnReconnected(f func(error))
	// Close stops the background loops and closes the connection.
	Close()
}

type client struct {
	url           url.URL
	requestHeader http.Header
	logger        *zap.Logger
	routerCodec   codec.Codec

	// sessionMu guards session: connect replaces it from the reconnect loop
	// while readLoop, writeLoop, ping, onReceive and Close read it.
	sessionMu sync.RWMutex
	session   *service.Session

	isConnected atomic.Bool // true between a successful dial and the end of its readLoop
	closed      atomic.Bool // set once by Close

	send     chan []byte                 // outgoing queue drained by writeLoop
	handlers map[string]*service.Handler // registered handler methods keyed by router name

	// Callbacks are read by the background loops without synchronization;
	// register them before calling Connect.
	onReceiveErr func(error)
	onReconnect  func(error)

	pingTicker      ticker.Ticker
	checkConnTicker ticker.Ticker
}

// New validates its arguments and returns an unconnected Client.
//
// url must use the ws or wss scheme, decoder is the name looked up with
// codec.GetCodec, and handlers is passed to service.RegisterHandler, which
// must yield at least one handler. A nil logger is replaced by a no-op
// logger. Call Connect to dial.
func New(logger *zap.Logger, url url.URL, decoder string, handlers any) (Client, error) {
	if !tool.InArray(url.Scheme, []string{"ws", "wss"}) {
		return nil, errors.New(`param: scheme not supported`)
	}
	routerCodec := codec.GetCodec(decoder)
	if routerCodec == nil {
		return nil, errors.New(`param: codec not supported`)
	}
	components := service.RegisterHandler(handlers)
	if len(components) == 0 {
		return nil, errors.New(`param: handlers unqualified`)
	}
	if logger == nil {
		// writeLoop logs through c.logger; avoid a nil dereference there.
		logger = zap.NewNop()
	}
	c := &client{
		logger:          logger,
		routerCodec:     routerCodec,
		url:             url,
		send:            make(chan []byte, maxSendBuffer),
		handlers:        components,
		pingTicker:      ticker.New(pingPeriod),
		checkConnTicker: ticker.New(reconnectWait),
	}
	return c, nil
}

// getSession returns the current session, or nil before the first
// successful connect.
func (c *client) getSession() *service.Session {
	c.sessionMu.RLock()
	defer c.sessionMu.RUnlock()
	return c.session
}

// readLoop reads frames from the connection that was current when it started
// and dispatches them via onReceive until a read fails. It then closes that
// connection and marks the client disconnected so reconnect can redial.
func (c *client) readLoop() {
	conn := c.getSession().Conn
	defer func() {
		// Close the dead connection so each reconnect does not leak a socket.
		_ = conn.Close()
		c.isConnected.Store(false)
	}()

	_ = conn.SetReadDeadline(time.Now().Add(pongWait))
	conn.SetPongHandler(func(string) error {
		// A pong proves the peer is alive; push the read deadline out.
		_ = conn.SetReadDeadline(time.Now().Add(pongWait))
		return nil
	})
	for {
		_, data, err := conn.ReadMessage()
		if err != nil {
			return
		}
		if err := c.onReceive(data); err != nil && c.onReceiveErr != nil {
			c.onReceiveErr(err)
		}
	}
}

// writeLoop drains the send queue, writing each message as a binary frame on
// the current connection. Write errors are logged and the message is dropped.
// It returns once Close closes the queue.
func (c *client) writeLoop() {
	for msg := range c.send {
		conn := c.getSession().Conn
		_ = conn.SetWriteDeadline(time.Now().Add(writeWait))
		if err := conn.WriteMessage(websocket.BinaryMessage, msg); err != nil {
			c.logger.Sugar().Errorf("writeLoop err: %s", err)
		}
	}
}

// ping sends a ping control frame on each pingTicker tick; a compliant peer
// answers with a pong, which extends the read deadline in readLoop.
func (c *client) ping() {
	c.pingTicker.Process(func() {
		sess := c.getSession()
		if sess == nil {
			return
		}
		// gorilla/websocket permits only one concurrent WriteMessage caller
		// (writeLoop); WriteControl is safe to call concurrently with it.
		_ = sess.Conn.WriteControl(websocket.PingMessage, nil, time.Now().Add(writeWait))
	})
}

// reconnect checks the connection on each checkConnTicker tick and redials
// when it has dropped, reporting every attempt to the onReconnect callback.
// It does nothing once Close has been called.
func (c *client) reconnect() {
	c.checkConnTicker.Process(func() {
		if c.closed.Load() || c.isConnected.Load() {
			return
		}
		err := c.connect()
		if c.onReconnect != nil {
			c.onReconnect(err)
		}
	})
}

// connect dials c.url with c.requestHeader, installs the new session, marks
// the client connected and starts a readLoop for it.
func (c *client) connect() error {
	conn, _, err := websocket.DefaultDialer.Dial(c.url.String(), c.requestHeader)
	if err != nil {
		return fmt.Errorf("dial: %w", err)
	}
	c.sessionMu.Lock()
	c.session = service.NewSession(conn)
	c.sessionMu.Unlock()
	c.isConnected.Store(true)
	go c.readLoop()
	return nil
}

// onReceive decodes msg with the router codec and calls the handler
// registered for its router with the current session and decoded payload.
// It returns an error when decoding fails, the router is not a string or has
// no registered handler, or the decoded message carries a non-nil Err.
func (c *client) onReceive(msg []byte) error {
	_, msgPack, err := c.routerCodec.Unmarshal(msg)
	if err != nil {
		return fmt.Errorf("onreceive: %w", err)
	}
	router, ok := msgPack.Router.(string)
	if !ok {
		return fmt.Errorf("onreceive: invalid router:%v", msgPack.Router)
	}
	s, ok := c.handlers[router]
	if !ok {
		return fmt.Errorf("onreceive: function not registered router:%s", router)
	}
	if msgPack.Err != nil {
		return fmt.Errorf("%s:%s", router, msgPack.Err)
	}
	args := []reflect.Value{s.Receiver, reflect.ValueOf(c.getSession()), reflect.ValueOf(msgPack.DataPtr)}
	s.Method.Func.Call(args)
	return nil
}

// onSend queues an encoded message for writeLoop without blocking.
//
// It returns ErrBrokenPipe when the client is closed or disconnected,
// ErrMessageTooLarge when msg is longer than maxFrameMessageLen, and
// ErrBufferPoolExceed when the queue is full.
func (c *client) onSend(msg []byte) (err error) {
	// Close may close c.send between the checks below and the send; a send
	// on a closed channel panics, which is reported as a broken pipe.
	defer func() {
		if e := recover(); e != nil {
			err = ErrBrokenPipe
		}
	}()
	if c.closed.Load() || !c.isConnected.Load() {
		return ErrBrokenPipe
	}
	if len(msg) > maxFrameMessageLen {
		return ErrMessageTooLarge
	}
	select {
	case c.send <- msg:
		return nil
	default:
		return ErrBufferPoolExceed
	}
}

// Connect stores requestHeader for this and every later redial, dials the
// server, and on success starts the ping, write and reconnect loops.
func (c *client) Connect(requestHeader http.Header) error {
	c.requestHeader = requestHeader
	if err := c.connect(); err != nil {
		return err
	}
	go c.ping()
	go c.writeLoop()
	go c.reconnect()
	return nil
}

// Send encodes data for router with the client's codec and queues it. It
// returns the codec's wrapped error, or one of ErrBrokenPipe,
// ErrMessageTooLarge or ErrBufferPoolExceed when the message cannot be queued.
func (c *client) Send(router string, data any) error {
	rb, err := c.routerCodec.Marshal(router, data, nil)
	if err != nil {
		return fmt.Errorf("service: %w", err)
	}
	return c.onSend(rb)
}

// OnReceiveError sets the callback for errors returned by onReceive.
func (c *client) OnReceiveError(f func(error)) {
	c.onReceiveErr = f
}

// OnReconnected sets the callback invoked after each reconnect attempt.
func (c *client) OnReconnected(f func(error)) {
	c.onReconnect = f
}

// Close stops the ping and reconnect tickers, ends writeLoop by closing the
// send queue, and closes the current connection if one was ever established.
// Calls after the first are no-ops.
func (c *client) Close() {
	if !c.closed.CompareAndSwap(false, true) {
		return
	}
	close(c.send)
	c.pingTicker.Stop()
	c.checkConnTicker.Stop()
	if sess := c.getSession(); sess != nil {
		_ = sess.Conn.Close()
	}
}