diff --git a/pkg/sse/server.go b/pkg/sse/server.go index 04dc2c0..dde6a99 100644 --- a/pkg/sse/server.go +++ b/pkg/sse/server.go @@ -12,8 +12,8 @@ import ( var _ Server = (*event)(nil) type Server interface { - GinHandle(ctx *gin.Context, user any) HandlerFunc() mux.HandlerFunc + GinHandlerFunc(auth func(c *gin.Context) (string, error)) gin.HandlerFunc Push(user any, name, msg string) bool Broadcast(name, msg string) } @@ -67,40 +67,69 @@ func (stream *event) listen() { } } -func (stream *event) GinHandle(ctx *gin.Context, user any) { - if user == nil { - ctx.AbortWithStatus(http.StatusUnauthorized) - return - } - e := make(chan msgChan) - client := clientChan{ - User: user, - Chan: e, - } - stream.Register <- client - defer func() { - stream.Unregister <- user - }() - - ctx.Writer.Header().Set("Content-Type", "text/event-stream") - ctx.Writer.Header().Set("Cache-Control", "no-cache") - ctx.Writer.Header().Set("Connection", "keep-alive") - ctx.Writer.Header().Set("Transfer-Encoding", "chunked") - - ctx.Stream(func(w io.Writer) bool { - if msg, ok := <-e; ok { - ctx.SSEvent(msg.Name, msg.Message) - return true - } - return false - }) - - ctx.Next() -} - func (stream *event) HandlerFunc() mux.HandlerFunc { return func(c mux.Context) { - stream.GinHandle(c.Context(), c.Auth()) + auth := c.Auth() + if auth == nil { + c.Context().AbortWithStatus(http.StatusBadRequest) + return + } + + e := make(chan msgChan) + client := clientChan{ + User: auth, + Chan: e, + } + stream.Register <- client + defer func() { + stream.Unregister <- auth + }() + + c.Context().Writer.Header().Set("Content-Type", "text/event-stream") + c.Context().Writer.Header().Set("Cache-Control", "no-cache") + c.Context().Writer.Header().Set("Connection", "keep-alive") + c.Context().Writer.Header().Set("Transfer-Encoding", "chunked") + + c.Context().Stream(func(w io.Writer) bool { + if msg, ok := <-e; ok { + c.Context().SSEvent(msg.Name, msg.Message) + return true + } + return false + }) + } +} + +func (stream *event) GinHandlerFunc(auth func(c *gin.Context) (string, error)) gin.HandlerFunc { + return func(c *gin.Context) { + user, err := auth(c) + if err != nil { + c.AbortWithStatus(http.StatusBadRequest) + return + } + + e := make(chan msgChan) + client := clientChan{ + User: user, + Chan: e, + } + stream.Register <- client + defer func() { + stream.Unregister <- user + }() + + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("Transfer-Encoding", "chunked") + + c.Stream(func(w io.Writer) bool { + if msg, ok := <-e; ok { + c.SSEvent(msg.Name, msg.Message) + return true + } + return false + }) } }