first commit
This commit is contained in:
6
pkg/websocket/peer/callback.go
Normal file
6
pkg/websocket/peer/callback.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package peer
|
||||
|
||||
type ConnectionCallBack interface {
|
||||
OnClosed(*Session)
|
||||
OnReceive(*Session, []byte) error
|
||||
}
|
98
pkg/websocket/peer/connect/acceptor.go
Normal file
98
pkg/websocket/peer/connect/acceptor.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package connect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"git.bvbej.com/bvbej/base-golang/pkg/mux"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"git.bvbej.com/bvbej/base-golang/pkg/websocket/peer"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
var _ WsAcceptor = (*wsAcceptor)(nil)
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
|
||||
type WsAcceptor interface {
|
||||
Start(addr string) error
|
||||
Stop()
|
||||
GinHandle(ctx *gin.Context)
|
||||
HandlerFunc() mux.HandlerFunc
|
||||
}
|
||||
|
||||
type wsAcceptor struct {
|
||||
server *http.Server
|
||||
sessMgr *peer.SessionManager
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
func NewWsAcceptor(sessMgr *peer.SessionManager, loggers *zap.Logger) WsAcceptor {
|
||||
return &wsAcceptor{
|
||||
sessMgr: sessMgr,
|
||||
logger: loggers,
|
||||
}
|
||||
}
|
||||
|
||||
func (ws *wsAcceptor) Start(addr string) error {
|
||||
urlObj, err := url.Parse(addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("websocket urlparse failed. url(%s) %v", addr, err)
|
||||
}
|
||||
if urlObj.Path == "" {
|
||||
return fmt.Errorf("websocket start failed. expect path in url to listen addr:%s", addr)
|
||||
}
|
||||
|
||||
http.HandleFunc(urlObj.Path, func(w http.ResponseWriter, r *http.Request) {
|
||||
c, upgradeErr := upgrader.Upgrade(w, r, nil)
|
||||
if upgradeErr != nil {
|
||||
ws.logger.Sugar().Errorf("upgrade http failed: %s", upgradeErr)
|
||||
return
|
||||
}
|
||||
ws.sessMgr.Register <- peer.NewSession(NewConnection(c, ws.sessMgr))
|
||||
})
|
||||
|
||||
ws.server = &http.Server{Addr: urlObj.Host}
|
||||
err = ws.server.ListenAndServe()
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
return fmt.Errorf("websocket ListenAndServe addr:%s failed:%v", addr, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ws *wsAcceptor) Stop() {
|
||||
ws.sessMgr.CloseAllSession()
|
||||
|
||||
if ws.server != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer cancel()
|
||||
if err := ws.server.Shutdown(ctx); err != nil {
|
||||
ws.logger.Sugar().Errorf("server shutdown err:[%s]", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ws *wsAcceptor) GinHandle(ctx *gin.Context) {
|
||||
c, upgradeErr := upgrader.Upgrade(ctx.Writer, ctx.Request, nil)
|
||||
if upgradeErr != nil {
|
||||
ws.logger.Sugar().Errorf("upgrade http failed: %s", upgradeErr)
|
||||
return
|
||||
}
|
||||
ws.sessMgr.Register <- peer.NewSession(NewConnection(c, ws.sessMgr))
|
||||
}
|
||||
|
||||
func (ws *wsAcceptor) HandlerFunc() mux.HandlerFunc {
|
||||
return func(c mux.Context) {
|
||||
ws.GinHandle(c.Context())
|
||||
}
|
||||
}
|
134
pkg/websocket/peer/connect/connection.go
Normal file
134
pkg/websocket/peer/connect/connection.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package connect
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"git.bvbej.com/bvbej/base-golang/pkg/websocket/peer"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
const (
|
||||
writeWait = 20 * time.Second
|
||||
pongWait = 60 * time.Second
|
||||
pingPeriod = (pongWait * 9) / 10
|
||||
maxFrameMessageLen = 16 * 1024 //4 * 4096
|
||||
maxSendBuffer = 16
|
||||
)
|
||||
|
||||
var (
|
||||
ErrBrokenPipe = errors.New("send to broken pipe")
|
||||
ErrBufferPoolExceed = errors.New("send buffer exceed")
|
||||
)
|
||||
|
||||
type wsConnection struct {
|
||||
peer.ConnectionIdentify
|
||||
pm *peer.SessionManager
|
||||
conn *websocket.Conn
|
||||
send chan []byte
|
||||
running bool
|
||||
}
|
||||
|
||||
func NewConnection(conn *websocket.Conn, p *peer.SessionManager) *wsConnection {
|
||||
wsc := &wsConnection{
|
||||
conn: conn,
|
||||
pm: p,
|
||||
send: make(chan []byte, maxSendBuffer),
|
||||
running: true,
|
||||
}
|
||||
|
||||
go wsc.acceptLoop()
|
||||
go wsc.sendLoop()
|
||||
|
||||
return wsc
|
||||
}
|
||||
|
||||
func (ws *wsConnection) Peer() *peer.SessionManager {
|
||||
return ws.pm
|
||||
}
|
||||
|
||||
func (ws *wsConnection) Raw() any {
|
||||
if ws.conn == nil {
|
||||
return nil
|
||||
}
|
||||
return ws.conn
|
||||
}
|
||||
|
||||
func (ws *wsConnection) RemoteAddr() string {
|
||||
if ws.conn == nil {
|
||||
return ""
|
||||
}
|
||||
return ws.conn.RemoteAddr().String()
|
||||
}
|
||||
|
||||
func (ws *wsConnection) Close() {
|
||||
_ = ws.conn.Close()
|
||||
ws.running = false
|
||||
}
|
||||
|
||||
func (ws *wsConnection) Send(msg []byte) (err error) {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
err = ErrBrokenPipe
|
||||
}
|
||||
}()
|
||||
if !ws.running {
|
||||
return ErrBrokenPipe
|
||||
}
|
||||
if len(ws.send) >= maxSendBuffer {
|
||||
return ErrBufferPoolExceed
|
||||
}
|
||||
if len(msg) > maxFrameMessageLen {
|
||||
return
|
||||
}
|
||||
ws.send <- msg
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ws *wsConnection) acceptLoop() {
|
||||
defer func() {
|
||||
ws.pm.Unregister <- ws.ID()
|
||||
_ = ws.conn.Close()
|
||||
ws.running = false
|
||||
}()
|
||||
_ = ws.conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
ws.conn.SetPongHandler(func(string) error {
|
||||
_ = ws.conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
return nil
|
||||
})
|
||||
for ws.conn != nil {
|
||||
_, data, err := ws.conn.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
ws.pm.ProcessMessage(ws.ID(), data)
|
||||
}
|
||||
}
|
||||
|
||||
func (ws *wsConnection) sendLoop() {
|
||||
ticker := time.NewTicker(pingPeriod)
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
_ = ws.conn.Close()
|
||||
ws.running = false
|
||||
close(ws.send)
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case msg := <-ws.send:
|
||||
_ = ws.conn.SetWriteDeadline(time.Now().Add(writeWait))
|
||||
if err := ws.conn.WriteMessage(websocket.BinaryMessage, msg); err != nil {
|
||||
return
|
||||
}
|
||||
case <-ticker.C:
|
||||
_ = ws.conn.SetWriteDeadline(time.Now().Add(writeWait))
|
||||
if err := ws.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ws *wsConnection) IsClosed() bool {
|
||||
return !ws.running
|
||||
}
|
23
pkg/websocket/peer/connection.go
Normal file
23
pkg/websocket/peer/connection.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package peer
|
||||
|
||||
type Connection interface {
|
||||
Raw() any
|
||||
Peer() *SessionManager
|
||||
Send(msg []byte) error
|
||||
Close()
|
||||
ID() int64
|
||||
RemoteAddr() string
|
||||
IsClosed() bool
|
||||
}
|
||||
|
||||
type ConnectionIdentify struct {
|
||||
id int64
|
||||
}
|
||||
|
||||
func (ci *ConnectionIdentify) ID() int64 {
|
||||
return ci.id
|
||||
}
|
||||
|
||||
func (ci *ConnectionIdentify) SetID(id int64) {
|
||||
ci.id = id
|
||||
}
|
16
pkg/websocket/peer/session.go
Normal file
16
pkg/websocket/peer/session.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package peer
|
||||
|
||||
import "time"
|
||||
|
||||
type Session struct {
|
||||
Conn Connection
|
||||
Time time.Time
|
||||
}
|
||||
|
||||
func NewSession(conn Connection) *Session {
|
||||
s := &Session{
|
||||
Conn: conn,
|
||||
Time: time.Now(),
|
||||
}
|
||||
return s
|
||||
}
|
119
pkg/websocket/peer/session_manager.go
Normal file
119
pkg/websocket/peer/session_manager.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type SessionManager struct {
|
||||
sessionList sync.Map // 使用Id关联会话
|
||||
|
||||
connIDGen int64 // 记录已经生成的会话ID流水号
|
||||
|
||||
count int64 // 记录当前在使用的会话数量
|
||||
|
||||
callback ConnectionCallBack
|
||||
Register chan *Session
|
||||
Unregister chan int64
|
||||
}
|
||||
|
||||
func (mgr *SessionManager) SetIDBase(base int64) {
|
||||
|
||||
atomic.StoreInt64(&mgr.connIDGen, base)
|
||||
}
|
||||
|
||||
func (mgr *SessionManager) Count() int {
|
||||
return int(atomic.LoadInt64(&mgr.count))
|
||||
}
|
||||
|
||||
func (mgr *SessionManager) Add(sess *Session) {
|
||||
|
||||
id := atomic.AddInt64(&mgr.connIDGen, 1)
|
||||
|
||||
atomic.AddInt64(&mgr.count, 1)
|
||||
|
||||
sess.Conn.(interface {
|
||||
SetID(int64)
|
||||
}).SetID(id)
|
||||
|
||||
mgr.sessionList.Store(id, sess)
|
||||
}
|
||||
|
||||
func (mgr *SessionManager) Close(id int64) {
|
||||
if v, ok := mgr.sessionList.Load(id); ok {
|
||||
if mgr.callback != nil {
|
||||
go mgr.callback.OnClosed(v.(*Session))
|
||||
}
|
||||
}
|
||||
mgr.sessionList.Delete(id)
|
||||
atomic.AddInt64(&mgr.count, -1)
|
||||
}
|
||||
|
||||
func (mgr *SessionManager) ProcessMessage(id int64, msg []byte) {
|
||||
if v, ok := mgr.sessionList.Load(id); ok {
|
||||
if mgr.callback != nil {
|
||||
go func() {
|
||||
err := mgr.callback.OnReceive(v.(*Session), msg)
|
||||
if err != nil {
|
||||
v.(*Session).Conn.Close()
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (mgr *SessionManager) run() {
|
||||
for {
|
||||
select {
|
||||
case client := <-mgr.Register:
|
||||
mgr.connIDGen++
|
||||
mgr.count++
|
||||
client.Conn.(interface {
|
||||
SetID(int64)
|
||||
}).SetID(mgr.connIDGen)
|
||||
mgr.sessionList.Store(mgr.connIDGen, client)
|
||||
case clientID := <-mgr.Unregister:
|
||||
if v, ok := mgr.sessionList.Load(clientID); ok {
|
||||
if mgr.callback != nil {
|
||||
go mgr.callback.OnClosed(v.(*Session))
|
||||
}
|
||||
}
|
||||
mgr.sessionList.Delete(clientID)
|
||||
mgr.count--
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (mgr *SessionManager) GetSession(id int64) *Session {
|
||||
if v, ok := mgr.sessionList.Load(id); ok {
|
||||
return v.(*Session)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mgr *SessionManager) VisitSession(callback func(*Session) bool) {
|
||||
mgr.sessionList.Range(func(key, value any) bool {
|
||||
return callback(value.(*Session))
|
||||
})
|
||||
}
|
||||
|
||||
func (mgr *SessionManager) CloseAllSession() {
|
||||
mgr.VisitSession(func(sess *Session) bool {
|
||||
sess.Conn.Close()
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func (mgr *SessionManager) SessionCount() int64 {
|
||||
return atomic.LoadInt64(&mgr.count)
|
||||
}
|
||||
|
||||
func NewSessionMgr(callback ConnectionCallBack) *SessionManager {
|
||||
s := &SessionManager{
|
||||
callback: callback,
|
||||
Register: make(chan *Session),
|
||||
Unregister: make(chan int64),
|
||||
}
|
||||
go s.run()
|
||||
return s
|
||||
}
|
Reference in New Issue
Block a user