first commit

This commit is contained in:
2024-07-23 10:23:43 +08:00
commit 7b4c2521a3
126 changed files with 15931 additions and 0 deletions

View File

@@ -0,0 +1,6 @@
package peer
type ConnectionCallBack interface {
OnClosed(*Session)
OnReceive(*Session, []byte) error
}

View File

@@ -0,0 +1,98 @@
package connect
import (
"context"
"errors"
"fmt"
"git.bvbej.com/bvbej/base-golang/pkg/mux"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"net/http"
"net/url"
"time"
"git.bvbej.com/bvbej/base-golang/pkg/websocket/peer"
"github.com/gorilla/websocket"
)
var _ WsAcceptor = (*wsAcceptor)(nil)
var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true
},
}
type WsAcceptor interface {
Start(addr string) error
Stop()
GinHandle(ctx *gin.Context)
HandlerFunc() mux.HandlerFunc
}
type wsAcceptor struct {
server *http.Server
sessMgr *peer.SessionManager
logger *zap.Logger
}
func NewWsAcceptor(sessMgr *peer.SessionManager, loggers *zap.Logger) WsAcceptor {
return &wsAcceptor{
sessMgr: sessMgr,
logger: loggers,
}
}
func (ws *wsAcceptor) Start(addr string) error {
urlObj, err := url.Parse(addr)
if err != nil {
return fmt.Errorf("websocket urlparse failed. url(%s) %v", addr, err)
}
if urlObj.Path == "" {
return fmt.Errorf("websocket start failed. expect path in url to listen addr:%s", addr)
}
http.HandleFunc(urlObj.Path, func(w http.ResponseWriter, r *http.Request) {
c, upgradeErr := upgrader.Upgrade(w, r, nil)
if upgradeErr != nil {
ws.logger.Sugar().Errorf("upgrade http failed: %s", upgradeErr)
return
}
ws.sessMgr.Register <- peer.NewSession(NewConnection(c, ws.sessMgr))
})
ws.server = &http.Server{Addr: urlObj.Host}
err = ws.server.ListenAndServe()
if err != nil && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("websocket ListenAndServe addr:%s failed:%v", addr, err)
}
return nil
}
func (ws *wsAcceptor) Stop() {
ws.sessMgr.CloseAllSession()
if ws.server != nil {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
if err := ws.server.Shutdown(ctx); err != nil {
ws.logger.Sugar().Errorf("server shutdown err:[%s]", err)
}
}
}
func (ws *wsAcceptor) GinHandle(ctx *gin.Context) {
c, upgradeErr := upgrader.Upgrade(ctx.Writer, ctx.Request, nil)
if upgradeErr != nil {
ws.logger.Sugar().Errorf("upgrade http failed: %s", upgradeErr)
return
}
ws.sessMgr.Register <- peer.NewSession(NewConnection(c, ws.sessMgr))
}
func (ws *wsAcceptor) HandlerFunc() mux.HandlerFunc {
return func(c mux.Context) {
ws.GinHandle(c.Context())
}
}

View File

@@ -0,0 +1,134 @@
package connect
import (
"errors"
"time"
"git.bvbej.com/bvbej/base-golang/pkg/websocket/peer"
"github.com/gorilla/websocket"
)
const (
writeWait = 20 * time.Second
pongWait = 60 * time.Second
pingPeriod = (pongWait * 9) / 10
maxFrameMessageLen = 16 * 1024 //4 * 4096
maxSendBuffer = 16
)
var (
ErrBrokenPipe = errors.New("send to broken pipe")
ErrBufferPoolExceed = errors.New("send buffer exceed")
)
type wsConnection struct {
peer.ConnectionIdentify
pm *peer.SessionManager
conn *websocket.Conn
send chan []byte
running bool
}
func NewConnection(conn *websocket.Conn, p *peer.SessionManager) *wsConnection {
wsc := &wsConnection{
conn: conn,
pm: p,
send: make(chan []byte, maxSendBuffer),
running: true,
}
go wsc.acceptLoop()
go wsc.sendLoop()
return wsc
}
func (ws *wsConnection) Peer() *peer.SessionManager {
return ws.pm
}
func (ws *wsConnection) Raw() any {
if ws.conn == nil {
return nil
}
return ws.conn
}
func (ws *wsConnection) RemoteAddr() string {
if ws.conn == nil {
return ""
}
return ws.conn.RemoteAddr().String()
}
func (ws *wsConnection) Close() {
_ = ws.conn.Close()
ws.running = false
}
func (ws *wsConnection) Send(msg []byte) (err error) {
defer func() {
if e := recover(); e != nil {
err = ErrBrokenPipe
}
}()
if !ws.running {
return ErrBrokenPipe
}
if len(ws.send) >= maxSendBuffer {
return ErrBufferPoolExceed
}
if len(msg) > maxFrameMessageLen {
return
}
ws.send <- msg
return nil
}
func (ws *wsConnection) acceptLoop() {
defer func() {
ws.pm.Unregister <- ws.ID()
_ = ws.conn.Close()
ws.running = false
}()
_ = ws.conn.SetReadDeadline(time.Now().Add(pongWait))
ws.conn.SetPongHandler(func(string) error {
_ = ws.conn.SetReadDeadline(time.Now().Add(pongWait))
return nil
})
for ws.conn != nil {
_, data, err := ws.conn.ReadMessage()
if err != nil {
break
}
ws.pm.ProcessMessage(ws.ID(), data)
}
}
func (ws *wsConnection) sendLoop() {
ticker := time.NewTicker(pingPeriod)
defer func() {
ticker.Stop()
_ = ws.conn.Close()
ws.running = false
close(ws.send)
}()
for {
select {
case msg := <-ws.send:
_ = ws.conn.SetWriteDeadline(time.Now().Add(writeWait))
if err := ws.conn.WriteMessage(websocket.BinaryMessage, msg); err != nil {
return
}
case <-ticker.C:
_ = ws.conn.SetWriteDeadline(time.Now().Add(writeWait))
if err := ws.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
return
}
}
}
}
func (ws *wsConnection) IsClosed() bool {
return !ws.running
}

View File

@@ -0,0 +1,23 @@
package peer
type Connection interface {
Raw() any
Peer() *SessionManager
Send(msg []byte) error
Close()
ID() int64
RemoteAddr() string
IsClosed() bool
}
type ConnectionIdentify struct {
id int64
}
func (ci *ConnectionIdentify) ID() int64 {
return ci.id
}
func (ci *ConnectionIdentify) SetID(id int64) {
ci.id = id
}

View File

@@ -0,0 +1,16 @@
package peer
import "time"
type Session struct {
Conn Connection
Time time.Time
}
func NewSession(conn Connection) *Session {
s := &Session{
Conn: conn,
Time: time.Now(),
}
return s
}

View File

@@ -0,0 +1,119 @@
package peer
import (
"sync"
"sync/atomic"
)
type SessionManager struct {
sessionList sync.Map // 使用Id关联会话
connIDGen int64 // 记录已经生成的会话ID流水号
count int64 // 记录当前在使用的会话数量
callback ConnectionCallBack
Register chan *Session
Unregister chan int64
}
func (mgr *SessionManager) SetIDBase(base int64) {
atomic.StoreInt64(&mgr.connIDGen, base)
}
func (mgr *SessionManager) Count() int {
return int(atomic.LoadInt64(&mgr.count))
}
func (mgr *SessionManager) Add(sess *Session) {
id := atomic.AddInt64(&mgr.connIDGen, 1)
atomic.AddInt64(&mgr.count, 1)
sess.Conn.(interface {
SetID(int64)
}).SetID(id)
mgr.sessionList.Store(id, sess)
}
func (mgr *SessionManager) Close(id int64) {
if v, ok := mgr.sessionList.Load(id); ok {
if mgr.callback != nil {
go mgr.callback.OnClosed(v.(*Session))
}
}
mgr.sessionList.Delete(id)
atomic.AddInt64(&mgr.count, -1)
}
func (mgr *SessionManager) ProcessMessage(id int64, msg []byte) {
if v, ok := mgr.sessionList.Load(id); ok {
if mgr.callback != nil {
go func() {
err := mgr.callback.OnReceive(v.(*Session), msg)
if err != nil {
v.(*Session).Conn.Close()
}
}()
}
}
}
func (mgr *SessionManager) run() {
for {
select {
case client := <-mgr.Register:
mgr.connIDGen++
mgr.count++
client.Conn.(interface {
SetID(int64)
}).SetID(mgr.connIDGen)
mgr.sessionList.Store(mgr.connIDGen, client)
case clientID := <-mgr.Unregister:
if v, ok := mgr.sessionList.Load(clientID); ok {
if mgr.callback != nil {
go mgr.callback.OnClosed(v.(*Session))
}
}
mgr.sessionList.Delete(clientID)
mgr.count--
}
}
}
func (mgr *SessionManager) GetSession(id int64) *Session {
if v, ok := mgr.sessionList.Load(id); ok {
return v.(*Session)
}
return nil
}
func (mgr *SessionManager) VisitSession(callback func(*Session) bool) {
mgr.sessionList.Range(func(key, value any) bool {
return callback(value.(*Session))
})
}
func (mgr *SessionManager) CloseAllSession() {
mgr.VisitSession(func(sess *Session) bool {
sess.Conn.Close()
return true
})
}
func (mgr *SessionManager) SessionCount() int64 {
return atomic.LoadInt64(&mgr.count)
}
func NewSessionMgr(callback ConnectionCallBack) *SessionManager {
s := &SessionManager{
callback: callback,
Register: make(chan *Session),
Unregister: make(chan int64),
}
go s.run()
return s
}