229 lines
5.5 KiB
Go
Raw Normal View History

2024-07-23 10:23:43 +08:00
package upload
import (
"context"
"crypto/sha256"
"errors"
"fmt"
2024-07-31 16:49:14 +08:00
"gitea.bvbej.com/bvbej/base-golang/pkg/color"
"gitea.bvbej.com/bvbej/base-golang/pkg/ticker"
"gitea.bvbej.com/bvbej/base-golang/pkg/token"
2024-07-23 10:23:43 +08:00
"github.com/rs/cors"
"github.com/tus/tusd/pkg/filestore"
tus "github.com/tus/tusd/pkg/handler"
"go.uber.org/zap"
"io"
"net/http"
"os"
"strings"
"sync"
"time"
)
// Compile-time check that *server satisfies Server.
var _ Server = (*server)(nil)

// Server is a tus (resumable upload) service whose uploads must be
// authorised with signed upload tokens.
type Server interface {
	// GetUploadToken signs a token for an upload of a file whose expected hex
	// SHA-256 is sha256, carrying the caller value param, valid for ttl.
	GetUploadToken(sha256, param string, ttl time.Duration) string
	// GetFileInfo returns the stored tusd metadata of the upload with this id.
	GetFileInfo(id string) (*tus.FileInfo, error)
	// Start begins serving uploads; completedEvent receives (sha256, param,
	// info) for completed uploads.
	Start(completedEvent func(sha256, param string, info tus.FileInfo)) error
	// Stop shuts the service down.
	Stop() error
}
// server is the Server implementation backed by tusd's filestore.
type server struct {
	// headerTokenKey names the request header that carries the upload token.
	headerTokenKey string
	// uploading maps a token string to its expiry (time.Time) from upload
	// creation until the upload finishes or the entry is purged.
	uploading sync.Map
	config    Config
	// token signs and parses upload tokens.
	token token.Token
	// store is the filestore in Config.Dir; it is set by Start.
	store      filestore.FileStore
	logger     *zap.Logger
	httpServer *http.Server

	// ctx is cancelled through done (by Stop) to end Start's event loops.
	ctx  context.Context
	done context.CancelFunc

	// checker runs the periodic purge of expired uploading entries.
	checker ticker.Ticker
	// completedEvent is the callback supplied to Start.
	completedEvent func(sha256, param string, info tus.FileInfo)
}
// Config configures an upload Server.
type Config struct {
	// ListenAddr is the HTTP listen address. When it contains "127.0.0.1",
	// tusd is told to respect forwarded headers.
	ListenAddr string
	// Path is the URL prefix of the tus endpoint (tusd BasePath).
	Path string
	// Dir is the directory holding uploaded files; Start creates it if needed.
	Dir string
	// Secret is passed to token.New to sign and verify upload tokens.
	Secret string
	// DisableDownload is forwarded to tusd's DisableDownload option.
	DisableDownload bool
	// Debug skips the SHA-256 check of finished uploads.
	Debug bool
}
// New returns an upload Server configured by conf. Nothing is served until
// Start is called. Upload tokens are signed with conf.Secret, and logger
// receives tusd and server diagnostics.
func New(conf Config, logger *zap.Logger) Server {
	ctx, cancel := context.WithCancel(context.Background())
	srv := &server{
		config:         conf,
		headerTokenKey: "Authorization",
		logger:         logger,
		token:          token.New(conf.Secret),
		checker:        ticker.New(time.Minute),
	}
	// uploading is left as the ready-to-use zero sync.Map.
	srv.ctx = ctx
	srv.done = cancel
	return srv
}
// GetUploadToken signs an upload token valid for ttl. sha256 is the expected
// hex digest of the file (Start compares it with the uploaded file's hash
// unless Config.Debug is set) and param is an opaque caller value that is
// handed back to the Start callback. Clients send the token in the
// Authorization header.
//
// If signing fails, the error is logged and "" is returned.
func (s *server) GetUploadToken(sha256, param string, ttl time.Duration) string {
	sign, err := s.token.JwtSign(sha256, param, ttl)
	if err != nil {
		s.logger.Error("sign upload token", zap.Error(err))
		return ""
	}
	return sign
}
func (s *server) GetFileInfo(id string) (*tus.FileInfo, error) {
upload, err := s.store.GetUpload(context.Background(), id)
if err != nil {
return nil, err
}
info, err := upload.GetInfo(context.Background())
if err != nil {
return nil, err
}
return &info, nil
}
// Start does the following:
//   - creates Config.Dir if needed;
//   - wires a tusd handler backed by a filestore in that directory;
//   - serves the handler under Config.Path on Config.ListenAddr from a
//     background goroutine.
//
// Creating an upload requires a valid token in the Authorization header, and
// a token may back only one upload at a time until that upload finishes or
// the token expires. When an upload finishes, the file's SHA-256 is compared
// with the token's ID (skipped when Config.Debug is set). On mismatch the
// file and its ".info" sidecar are deleted and an error is returned to tusd.
//
// If completedEvent is non-nil, it is called on its own goroutine for each
// upload reported on tusd's CompleteUploads channel. It receives the token's
// ID and Subject and the upload's FileInfo.
//
// Start returns an error if the directory cannot be created or the tusd
// handler cannot be built. A failure of the HTTP listener is reported with
// logger.Fatal.
func (s *server) Start(completedEvent func(sha256, param string, info tus.FileInfo)) error {
	s.completedEvent = completedEvent
	composer := tus.NewStoreComposer()
	if err := os.MkdirAll(s.config.Dir, os.ModePerm); err != nil {
		return err
	}
	s.store = filestore.New(s.config.Dir)
	s.store.UseIn(composer)
	handler, err := tus.NewHandler(tus.Config{
		StoreComposer:           composer,
		BasePath:                s.config.Path,
		Logger:                  zap.NewStdLog(s.logger),
		NotifyCompleteUploads:   true,
		NotifyTerminatedUploads: true,
		DisableTermination:      true,
		DisableDownload:         s.config.DisableDownload,
		// Forwarded headers are trusted only when bound to loopback,
		// presumably because a local reverse proxy sits in front.
		RespectForwardedHeaders: strings.Contains(s.config.ListenAddr, "127.0.0.1"),
		PreUploadCreateCallback: func(hook tus.HookEvent) error {
			authStr := hook.HTTPRequest.Header.Get(s.headerTokenKey)
			jwtClaims, err := s.token.JwtParse(authStr)
			if err != nil {
				return errors.New("unauthorized")
			}
			// LoadOrStore checks and claims the token atomically. A separate
			// Load followed by Store would let two concurrent requests that
			// carry the same token both succeed.
			if _, loaded := s.uploading.LoadOrStore(authStr, jwtClaims.ExpiresAt.Time); loaded {
				return errors.New("repeated")
			}
			return nil
		},
		PreFinishResponseCallback: func(hook tus.HookEvent) error {
			authStr := hook.HTTPRequest.Header.Get(s.headerTokenKey)
			jwtParse, err := s.token.JwtParse(authStr)
			if err != nil {
				return errors.New("token expired")
			}
			// The upload is finished, so release the token's reservation.
			s.uploading.Delete(authStr)

			upload, err := s.store.GetUpload(context.Background(), hook.Upload.ID)
			if err != nil {
				return err
			}
			info, err := upload.GetInfo(context.Background())
			if err != nil {
				return errors.New("file not found")
			}
			path, exist := info.Storage["Path"]
			if !exist {
				return errors.New("file not found")
			}
			sum, err := fileSHA256(path)
			if err != nil {
				return err
			}
			if !s.config.Debug && sum != strings.ToLower(jwtParse.ID) {
				removeUploadFiles(path)
				return errors.New("file check error")
			}
			return nil
		},
	})
	if err != nil {
		return err
	}

	// NOTE(review): in tusd v1 the CompleteUploads notification appears to be
	// sent before PreFinishResponseCallback runs. If so, completedEvent may
	// fire for an upload that then fails the SHA-256 check above. Confirm
	// against the tusd version in use.
	go func() {
		for {
			select {
			case event := <-handler.CompleteUploads:
				if s.completedEvent == nil {
					continue
				}
				authStr := event.HTTPRequest.Header.Get(s.headerTokenKey)
				jwtParse, err := s.token.JwtParse(authStr)
				if err != nil {
					// There are no claims to report. They are presumably nil on
					// error, so dereferencing them could panic this goroutine.
					s.logger.Warn("upload completed with invalid token",
						zap.String("id", event.Upload.ID), zap.Error(err))
					continue
				}
				go s.completedEvent(jwtParse.ID, jwtParse.Subject, event.Upload)
			case <-s.ctx.Done():
				return
			}
		}
	}()

	// Best-effort removal of files for terminated uploads. Errors are ignored.
	// NOTE(review): DisableTermination is true above, so this may rarely or
	// never receive events. Confirm whether it is still needed.
	go func() {
		for {
			select {
			case event := <-handler.TerminatedUploads:
				upload, _ := s.store.GetUpload(context.Background(), event.Upload.ID)
				if upload == nil {
					continue
				}
				info, _ := upload.GetInfo(context.Background())
				if path, exist := info.Storage["Path"]; exist {
					removeUploadFiles(path)
				}
			case <-s.ctx.Done():
				return
			}
		}
	}()

	// On each checker tick, drop reservations whose token expiry has passed,
	// so abandoned uploads do not keep entries in the map forever.
	s.checker.Process(func() {
		now := time.Now()
		s.uploading.Range(func(key, value any) bool {
			if value.(time.Time).Before(now) {
				s.uploading.Delete(key)
			}
			return true
		})
	})

	// Start the HTTP listener.
	addr := s.config.ListenAddr
	mux := http.NewServeMux()
	mux.Handle(s.config.Path, http.StripPrefix(s.config.Path, handler))
	srv := &http.Server{
		Addr:    addr,
		Handler: cors.AllowAll().Handler(mux),
	}
	s.httpServer = srv
	go func() {
		if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
			// Use the plain Logger so err is recorded as a structured field.
			// SugaredLogger.Fatal fmt.Sprints its args and would print the
			// zap.Field struct instead.
			s.logger.Fatal("upload server startup err", zap.Error(err))
		}
	}()
	fmt.Println(color.Green(fmt.Sprintf("* [register tusd listen %s]", addr)))
	return nil
}

// fileSHA256 returns the lowercase hex SHA-256 digest of the file at path.
// It streams the file through the hash instead of loading it into memory,
// so large uploads do not need file-sized buffers.
func fileSHA256(path string) (string, error) {
	f, err := os.Open(path)
	if err != nil {
		return "", err
	}
	defer f.Close()
	hash := sha256.New()
	if _, err := io.Copy(hash, f); err != nil {
		return "", err
	}
	return fmt.Sprintf("%x", hash.Sum(nil)), nil
}

// removeUploadFiles deletes an upload's data file and its ".info" sidecar.
// Removal is best-effort, and errors are deliberately ignored.
func removeUploadFiles(path string) {
	_ = os.Remove(path)
	_ = os.Remove(path + ".info")
}
// Stop cancels the server's context, which ends the event loops started by
// Start, and closes the HTTP server.
//
// If Start never created the HTTP server (it was not called, or it returned
// an error early), only the context is cancelled and nil is returned. This
// avoids a nil-pointer panic.
func (s *server) Stop() error {
	s.done()
	if s.httpServer == nil {
		return nil
	}
	return s.httpServer.Close()
}