From 2415f59f7f93a6391c486b6359df0943279e8b5e Mon Sep 17 00:00:00 2001
From: bvbej
Date: Wed, 31 Jul 2024 17:02:30 +0800
Subject: [PATCH] =?UTF-8?q?[=F0=9F=9A=80]=20downloader?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 pkg/downloader/base/constants.go        |  29 ++
 pkg/downloader/base/model.go            |  38 ++
 pkg/downloader/controller/controller.go | 211 ++++++++++
 pkg/downloader/fetcher/fetcher.go       |  57 +++
 pkg/downloader/protocol/http/fetcher.go | 504 ++++++++++++++++++++++++
 pkg/downloader/protocol/http/model.go   |  29 ++
 pkg/downloader/util/timer.go            |  47 +++
 7 files changed, 915 insertions(+)
 create mode 100644 pkg/downloader/base/constants.go
 create mode 100644 pkg/downloader/base/model.go
 create mode 100644 pkg/downloader/controller/controller.go
 create mode 100644 pkg/downloader/fetcher/fetcher.go
 create mode 100644 pkg/downloader/protocol/http/fetcher.go
 create mode 100644 pkg/downloader/protocol/http/model.go
 create mode 100644 pkg/downloader/util/timer.go

diff --git a/pkg/downloader/base/constants.go b/pkg/downloader/base/constants.go
new file mode 100644
index 0000000..9e1a542
--- /dev/null
+++ b/pkg/downloader/base/constants.go
@@ -0,0 +1,29 @@
+// Package base holds the status codes, HTTP constants and data models
+// shared by the downloader packages.
+package base
+
+// Status is the state of a download task or of one of its chunks.
+type Status int
+
+// Download states, numbered from zero in declaration order.
+const (
+	DownloadStatusReady Status = iota
+	DownloadStatusStart
+	DownloadStatusPause
+	DownloadStatusError
+	DownloadStatusDone
+)
+
+// HTTP status codes, header names and the Range header format used by the
+// HTTP fetcher.
+const (
+	HttpCodeOK             = 200
+	HttpCodePartialContent = 206
+
+	HttpHeaderRange              = "Range"
+	HttpHeaderContentLength      = "Content-Length"
+	HttpHeaderContentRange       = "Content-Range"
+	HttpHeaderContentDisposition = "Content-Disposition"
+
+	HttpHeaderRangeFormat = "bytes=%d-%d"
+)
diff --git a/pkg/downloader/base/model.go b/pkg/downloader/base/model.go
new file mode 100644
index 0000000..05a352e
--- /dev/null
+++ b/pkg/downloader/base/model.go
@@ -0,0 +1,38 @@
+package base
+
+// Request describes what to download.
+type Request struct {
+	// URL is the download link.
+	URL string
+	// Extra carries protocol-specific additional information.
+	Extra any
+}
+
+// Resource is the information resolved for a request.
+type Resource struct {
+	// Req is the request this resource was resolved from.
+	Req *Request
+	// TotalSize is the total size of the resource in bytes; 0 when unknown.
+	TotalSize int64
+	// Range reports whether resumable (ranged) download is supported.
+	Range bool
+	// Files lists the files contained in the resource.
+	Files []*FileInfo
+}
+
+// FileInfo describes a single file of a resource.
+type FileInfo struct {
+	Name string
+	Path string
+	Size int64
+}
+
+// Options are the download options.
+type Options struct {
+	// Name is the file name to save as.
+	Name string
+	// Path is the directory to save into.
+	Path string
+	// Connections is the number of concurrent connections.
+	Connections int
+}
diff --git a/pkg/downloader/controller/controller.go b/pkg/downloader/controller/controller.go
new file mode 100644
index 0000000..c97c238
--- /dev/null
+++ b/pkg/downloader/controller/controller.go
@@ -0,0 +1,211 @@
+// Package controller provides the file handles and network settings that
+// fetchers rely on while downloading.
+package controller
+
+import (
+	"net"
+	"net/http"
+	"net/url"
+	"os"
+	"time"
+
+	"golang.org/x/net/proxy"
+)
+
+// Controller gives a fetcher access to the target files and to the network
+// configuration of outgoing requests.
+type Controller interface {
+	Touch(name string, size int64) (file *os.File, err error)
+	Open(name string) (file *os.File, err error)
+	Write(name string, offset int64, buf []byte) (int, error)
+	Close(name string) error
+	ContextDialer() (proxy.Dialer, error)
+	ContextCookie() http.CookieJar
+	ContextTimeout() time.Duration
+	ContextProxy() func(*http.Request) (*url.URL, error)
+}
+
+// Option customizes the DefaultController built by NewController.
+type Option func(*option)
+
+type option struct {
+	CookieJar http.CookieJar
+	Timeout   time.Duration
+	Dialer    proxy.Dialer
+	Proxy     func(*http.Request) (*url.URL, error)
+}
+
+// WithCookie sets the cookie jar returned by ContextCookie.
+func WithCookie(cookieJar http.CookieJar) Option {
+	return func(opt *option) {
+		opt.CookieJar = cookieJar
+	}
+}
+
+// WithTimeout sets the timeout returned by ContextTimeout.
+func WithTimeout(timeout time.Duration) Option {
+	return func(opt *option) {
+		opt.Timeout = timeout
+	}
+}
+
+// WithDialer sets the dialer wrapped by ContextDialer.
+func WithDialer(dialer proxy.Dialer) Option {
+	return func(opt *option) {
+		opt.Dialer = dialer
+	}
+}
+
+// WithProxy sets the proxy function returned by ContextProxy.
+func WithProxy(fn func(*http.Request) (*url.URL, error)) Option {
+	return func(opt *option) {
+		opt.Proxy = fn
+	}
+}
+
+// DefaultController is the default Controller. Files maps a file name to the
+// handle that Touch or Open obtained for it.
+type DefaultController struct {
+	*option
+	Files map[string]*os.File
+}
+
+// NewController applies options and fills in the defaults: a 30 second
+// timeout and the dialer from proxy.FromEnvironment.
+func NewController(options ...Option) *DefaultController {
+	opt := new(option)
+	for _, f := range options {
+		f(opt)
+	}
+	if opt.Timeout == 0 {
+		opt.Timeout = time.Second * 30
+	}
+	if opt.Dialer == nil {
+		opt.Dialer = proxy.FromEnvironment()
+	}
+	return &DefaultController{
+		Files:  make(map[string]*os.File),
+		option: opt,
+	}
+}
+
+// Touch creates (or truncates) the file name, resizes it to size bytes when
+// size is positive, and registers the handle under name.
+func (c *DefaultController) Touch(name string, size int64) (file *os.File, err error) {
+	file, err = os.Create(name)
+	if err != nil {
+		return nil, err
+	}
+	if size > 0 {
+		if err = file.Truncate(size); err != nil {
+			// Do not leak the handle of a file that could not be sized.
+			_ = file.Close()
+			return nil, err
+		}
+	}
+	c.Files[name] = file
+	return file, nil
+}
+
+// Open opens the existing file name for reading and writing and registers
+// the handle under name.
+func (c *DefaultController) Open(name string) (file *os.File, err error) {
+	file, err = os.OpenFile(name, os.O_RDWR, os.ModePerm)
+	if err != nil {
+		return nil, err
+	}
+	c.Files[name] = file
+	return file, nil
+}
+
+// Write writes buf at offset into the file registered under name.
+func (c *DefaultController) Write(name string, offset int64, buf []byte) (int, error) {
+	return c.Files[name].WriteAt(buf, offset)
+}
+
+// Close closes the file registered under name and forgets its handle.
+func (c *DefaultController) Close(name string) error {
+	err := c.Files[name].Close()
+	delete(c.Files, name)
+	return err
+}
+
+// ContextDialer returns the configured dialer wrapped in a DialerWarp.
+func (c *DefaultController) ContextDialer() (proxy.Dialer, error) {
+	return &DialerWarp{dialer: c.Dialer}, nil
+}
+
+// ContextCookie returns the configured cookie jar, which may be nil.
+func (c *DefaultController) ContextCookie() http.CookieJar {
+	return c.CookieJar
+}
+
+// ContextTimeout returns the configured timeout.
+func (c *DefaultController) ContextTimeout() time.Duration {
+	return c.Timeout
+}
+
+// ContextProxy returns the configured proxy function, which may be nil.
+func (c *DefaultController) ContextProxy() func(*http.Request) (*url.URL, error) {
+	return c.Proxy
+}
+
+// DialerWarp wraps a proxy.Dialer so that the connections it dials are
+// returned as ConnWarp values.
+type DialerWarp struct {
+	dialer proxy.Dialer
+}
+
+// ConnWarp is a net.Conn that forwards every call to the wrapped connection.
+type ConnWarp struct {
+	conn net.Conn
+}
+
+// Read forwards to the wrapped connection.
+func (c *ConnWarp) Read(b []byte) (n int, err error) {
+	return c.conn.Read(b)
+}
+
+// Write forwards to the wrapped connection.
+func (c *ConnWarp) Write(b []byte) (n int, err error) {
+	return c.conn.Write(b)
+}
+
+// Close forwards to the wrapped connection.
+func (c *ConnWarp) Close() error {
+	return c.conn.Close()
+}
+
+// LocalAddr forwards to the wrapped connection.
+func (c *ConnWarp) LocalAddr() net.Addr {
+	return c.conn.LocalAddr()
+}
+
+// RemoteAddr forwards to the wrapped connection.
+func (c *ConnWarp) RemoteAddr() net.Addr {
+	return c.conn.RemoteAddr()
+}
+
+// SetDeadline forwards to the wrapped connection.
+func (c *ConnWarp) SetDeadline(t time.Time) error {
+	return c.conn.SetDeadline(t)
+}
+
+// SetReadDeadline forwards to the wrapped connection.
+func (c *ConnWarp) SetReadDeadline(t time.Time) error {
+	return c.conn.SetReadDeadline(t)
+}
+
+// SetWriteDeadline forwards to the wrapped connection.
+func (c *ConnWarp) SetWriteDeadline(t time.Time) error {
+	return c.conn.SetWriteDeadline(t)
+}
+
+// Dial dials addr through the wrapped dialer and wraps the connection.
+func (d *DialerWarp) Dial(network, addr string) (c net.Conn, err error) {
+	conn, err := d.dialer.Dial(network, addr)
+	if err != nil {
+		return nil, err
+	}
+	return &ConnWarp{conn: conn}, nil
+}
diff --git a/pkg/downloader/fetcher/fetcher.go b/pkg/downloader/fetcher/fetcher.go
new file mode 100644
index 0000000..14d4e0a
--- /dev/null
+++ b/pkg/downloader/fetcher/fetcher.go
@@ -0,0 +1,57 @@
+// Package fetcher defines the interface every download protocol implements.
+package fetcher
+
+import (
+	"gitea.bvbej.com/bvbej/base-golang/pkg/downloader/base"
+	"gitea.bvbej.com/bvbej/base-golang/pkg/downloader/controller"
+)
+
+// Fetcher is the download support for one protocol.
+type Fetcher interface {
+	// Setup hands the fetcher its controller.
+	Setup(ctl controller.Controller)
+	// Resolve inspects the request and describes the resource behind it.
+	Resolve(req *base.Request) (res *base.Resource, err error)
+	// Create creates the download task.
+	Create(res *base.Resource, opts *base.Options) (err error)
+	// Start starts the task.
+	Start() (err error)
+	// Pause pauses the task.
+	Pause() (err error)
+	// Continue resumes the task.
+	Continue() (err error)
+	// Progress reports the download progress of each file of the task.
+	Progress() Progress
+	// Wait blocks until the task has finished downloading.
+	Wait() (err error)
+}
+
+// DefaultFetcher carries the state shared by fetcher implementations: the
+// controller and the channel on which the final result is delivered.
+type DefaultFetcher struct {
+	Ctl    controller.Controller
+	DoneCh chan error
+}
+
+// Setup stores the controller and creates the result channel.
+func (f *DefaultFetcher) Setup(ctl controller.Controller) {
+	f.Ctl = ctl
+	f.DoneCh = make(chan error, 1)
+}
+
+// Wait blocks until a result arrives on DoneCh and returns it.
+func (f *DefaultFetcher) Wait() (err error) {
+	return <-f.DoneCh
+}
+
+// Progress holds the number of bytes downloaded for each file of a task.
+type Progress []int64
+
+// TotalDownloaded returns the number of bytes downloaded for the whole task.
+func (p Progress) TotalDownloaded() int64 {
+	total := int64(0)
+	for _, d := range p {
+		total += d
+	}
+	return total
+}
diff --git a/pkg/downloader/protocol/http/fetcher.go b/pkg/downloader/protocol/http/fetcher.go
new file mode 100644
index 0000000..d9ed596
--- /dev/null
+++ b/pkg/downloader/protocol/http/fetcher.go
@@ -0,0 +1,504 @@
+// Package http implements the downloader fetcher for HTTP and HTTPS.
+package http
+
+import (
+	"bytes"
+	"context"
+	"fmt"
+	"io"
+	"mime"
+	"net"
+	"net/http"
+	"net/url"
+	"path"
+	"path/filepath"
+	"strconv"
+	"strings"
+	"time"
+
+	"gitea.bvbej.com/bvbej/base-golang/pkg/downloader/base"
+	"gitea.bvbej.com/bvbej/base-golang/pkg/downloader/fetcher"
+	"golang.org/x/sync/errgroup"
+)
+
+// RequestError reports a response whose status code is neither 200 nor 206.
+type RequestError struct {
+	Code int
+	Msg  string
+}
+
+// NewRequestError builds a RequestError from a status code and status text.
+func NewRequestError(code int, msg string) *RequestError {
+	return &RequestError{Code: code, Msg: msg}
+}
+
+// Error implements the error interface.
+func (re *RequestError) Error() string {
+	return fmt.Sprintf("http request fail,code:%d", re.Code)
+}
+
+// Fetcher downloads one resource over HTTP or HTTPS. When the server
+// supports ranges the resource is split into chunks fetched concurrently.
+type Fetcher struct {
+	*fetcher.DefaultFetcher
+
+	res     *base.Resource
+	opts    *base.Options
+	status  base.Status
+	clients []*http.Response
+	chunks  []*Chunk
+
+	ctx     context.Context
+	cancel  context.CancelFunc
+	pauseCh chan any
+}
+
+// NewFetcher returns a Fetcher; call Setup, Resolve and Create before Start.
+func NewFetcher() *Fetcher {
+	return &Fetcher{
+		DefaultFetcher: new(fetcher.DefaultFetcher),
+		pauseCh:        make(chan any),
+	}
+}
+
+var protocols = []string{"HTTP", "HTTPS"}
+
+// FetcherBuilder returns the protocols handled by this package and a factory
+// for new fetchers.
+func FetcherBuilder() ([]string, func() fetcher.Fetcher) {
+	return protocols, func() fetcher.Fetcher {
+		return NewFetcher()
+	}
+}
+
+// Resolve probes req with a one-byte Range request and reports the resource
+// size, whether ranged download is possible and the file name to save as.
+// It returns a *RequestError for any status other than 200 or 206.
+func (f *Fetcher) Resolve(req *base.Request) (*base.Resource, error) {
+	httpReq, err := f.buildRequest(nil, req)
+	if err != nil {
+		return nil, err
+	}
+	client, err := f.buildClient()
+	if err != nil {
+		return nil, err
+	}
+	// Ask for a single byte to find out whether the server honours Range.
+	httpReq.Header.Set(base.HttpHeaderRange, fmt.Sprintf(base.HttpHeaderRangeFormat, 0, 0))
+	httpResp, err := client.Do(httpReq)
+	if err != nil {
+		return nil, err
+	}
+	// Only the headers are needed, so close right away instead of deferring.
+	_ = httpResp.Body.Close()
+	res := &base.Resource{
+		Req:   req,
+		Range: false,
+		Files: []*base.FileInfo{},
+	}
+	switch httpResp.StatusCode {
+	case base.HttpCodePartialContent:
+		// 206: the total follows the slash, "bytes 0-0/1001" => 1001.
+		// Without a usable total the chunks cannot be laid out, so the
+		// resource is then treated as not supporting ranges.
+		total, ok, err := parseContentRangeTotal(httpResp.Header.Get(base.HttpHeaderContentRange))
+		if err != nil {
+			return nil, err
+		}
+		if ok {
+			res.Range = true
+			res.TotalSize = total
+		}
+	case base.HttpCodeOK:
+		// 200: no range support. Content-Length may be missing (possibly
+		// chunked encoding), in which case the size stays unknown.
+		contentLength := httpResp.Header.Get(base.HttpHeaderContentLength)
+		if contentLength != "" {
+			parse, err := strconv.ParseInt(contentLength, 10, 64)
+			if err != nil {
+				return nil, err
+			}
+			res.TotalSize = parse
+		}
+	default:
+		return nil, NewRequestError(httpResp.StatusCode, httpResp.Status)
+	}
+	file := &base.FileInfo{
+		Size: res.TotalSize,
+	}
+	contentDisposition := httpResp.Header.Get(base.HttpHeaderContentDisposition)
+	if contentDisposition != "" {
+		// A malformed header yields nil params, i.e. no filename.
+		_, params, _ := mime.ParseMediaType(contentDisposition)
+		file.Name = safeFileName(params["filename"])
+	}
+	// Otherwise take the last element of the URL path (query excluded).
+	if file.Name == "" {
+		file.Name = safeFileName(path.Base(httpReq.URL.Path))
+	}
+	// No usable name could be derived.
+	if file.Name == "" {
+		file.Name = "unknown"
+	}
+	res.Files = append(res.Files, file)
+	return res, nil
+}
+
+// parseContentRangeTotal extracts the complete length from a Content-Range
+// value such as "bytes 0-0/1001". ok is false when the header is missing or
+// the length is unknown ("*") or not positive; err reports a length that is
+// not a number.
+func parseContentRangeTotal(contentRange string) (total int64, ok bool, err error) {
+	idx := strings.LastIndex(contentRange, "/")
+	if idx < 0 {
+		return 0, false, nil
+	}
+	value := strings.TrimSpace(contentRange[idx+1:])
+	if value == "" || value == "*" {
+		return 0, false, nil
+	}
+	total, err = strconv.ParseInt(value, 10, 64)
+	if err != nil {
+		return 0, false, err
+	}
+	return total, total > 0, nil
+}
+
+// safeFileName reduces a server-supplied name to its last path element so
+// that it cannot point outside the download directory. It returns "" when
+// nothing usable is left.
+func safeFileName(name string) string {
+	name = filepath.Base(name)
+	if name == "." || name == ".." || name == string(filepath.Separator) {
+		return ""
+	}
+	return name
+}
+
+// Create binds the resolved resource and the download options to the
+// fetcher and marks it ready.
+func (f *Fetcher) Create(res *base.Resource, opts *base.Options) error {
+	f.res = res
+	f.opts = opts
+	f.status = base.DownloadStatusReady
+	return nil
+}
+
+// Start creates the target file, lays out the chunks and starts downloading
+// in the background; use Wait to block until the download has ended.
+func (f *Fetcher) Start() (err error) {
+	// Create the target file.
+	name := f.filename()
+	_, err = f.Ctl.Touch(name, f.res.TotalSize)
+	if err != nil {
+		return err
+	}
+	f.status = base.DownloadStatusStart
+	if f.ranged() {
+		// Use at least one connection and at most one per byte, so that
+		// the division is safe and no chunk is empty.
+		connections := int64(f.opts.Connections)
+		if connections < 1 {
+			connections = 1
+		}
+		if connections > f.res.TotalSize {
+			connections = f.res.TotalSize
+		}
+		// Average number of bytes each connection has to download.
+		chunkSize := f.res.TotalSize / connections
+		f.chunks = make([]*Chunk, connections)
+		f.clients = make([]*http.Response, connections)
+		for i := int64(0); i < connections; i++ {
+			begin := chunkSize * i
+			end := begin + chunkSize - 1
+			if i == connections-1 {
+				// The last chunk takes whatever is left of the file.
+				end = f.res.TotalSize - 1
+			}
+			f.chunks[i] = NewChunk(begin, end)
+		}
+	} else {
+		// Only a single connection can be used.
+		f.chunks = []*Chunk{NewChunk(0, 0)}
+		f.clients = make([]*http.Response, 1)
+	}
+	f.fetch()
+	return nil
+}
+
+// ranged reports whether the download is split into Range requests, which
+// needs both range support and a known total size.
+func (f *Fetcher) ranged() bool {
+	return f.res.Range && f.res.TotalSize > 0
+}
+
+// Pause cancels the running requests and blocks until every chunk worker
+// has stopped. It does nothing unless the download is running.
+func (f *Fetcher) Pause() (err error) {
+	if base.DownloadStatusStart != f.status {
+		return
+	}
+	f.status = base.DownloadStatusPause
+	f.cancel()
+	<-f.pauseCh
+	return
+}
+
+// Continue reopens the target file and resumes a paused or failed download.
+// It does nothing in any other state.
+func (f *Fetcher) Continue() (err error) {
+	if base.DownloadStatusPause != f.status && base.DownloadStatusError != f.status {
+		return nil
+	}
+	if _, err = f.Ctl.Open(f.filename()); err != nil {
+		return err
+	}
+	// Flip the status only once the file is open, so that a failed Continue
+	// can simply be called again.
+	f.status = base.DownloadStatusStart
+	f.fetch()
+	return nil
+}
+
+// Progress returns the bytes downloaded so far, summed over all chunks, as
+// the single entry of the result; it is empty before Start.
+func (f *Fetcher) Progress() fetcher.Progress {
+	p := make(fetcher.Progress, 0)
+	if len(f.chunks) > 0 {
+		total := int64(0)
+		for _, chunk := range f.chunks {
+			total += chunk.Downloaded
+		}
+		p = append(p, total)
+	}
+	return p
+}
+
+// filename returns the path of the target file: opts.Path joined with
+// opts.Name or, when that is empty, with the resolved file name.
+func (f *Fetcher) filename() string {
+	var filename = f.opts.Name
+	if filename == "" {
+		filename = f.res.Files[0].Name
+	}
+	return filepath.Join(f.opts.Path, filename)
+}
+
+// fetch starts one worker per chunk plus a goroutine that waits for them,
+// closes the file and then reports either the pause or the final result.
+func (f *Fetcher) fetch() {
+	f.ctx, f.cancel = context.WithCancel(context.Background())
+	eg, _ := errgroup.WithContext(f.ctx)
+	// Iterate over the chunks laid out by Start: a resource without range
+	// support has exactly one, whatever opts.Connections says.
+	for i := range f.chunks {
+		i := i // every worker needs its own copy before Go 1.22
+		eg.Go(func() error {
+			return f.fetchChunk(i)
+		})
+	}
+
+	go func() {
+		err := eg.Wait()
+		// The download has stopped: release the file handle.
+		_ = f.Ctl.Close(f.filename())
+		if f.status == base.DownloadStatusPause {
+			f.pauseCh <- nil
+		} else {
+			if err != nil {
+				f.status = base.DownloadStatusError
+			} else {
+				f.status = base.DownloadStatusDone
+			}
+			f.DoneCh <- err
+		}
+	}()
+}
+
+// fetchChunk downloads chunk index into the target file. A failed attempt
+// is retried after three seconds and ten consecutive failures give up. On
+// leaving the retry loop the chunk status is updated.
+func (f *Fetcher) fetchChunk(index int) (err error) {
+	filename := f.filename()
+	chunk := f.chunks[index]
+
+	httpReq, err := f.buildRequest(f.ctx, f.res.Req)
+	if err != nil {
+		return err
+	}
+
+	client, err := f.buildClient()
+	if err != nil {
+		return err
+	}
+
+	var buf = make([]byte, 8192)
+
+	// Up to 10 attempts.
+	for i := 0; i < 10; i++ {
+		// Nothing to do if the chunk has already been completed.
+		if chunk.Status == base.DownloadStatusDone {
+			return nil
+		}
+		// Stop as soon as the task has been paused.
+		if f.status == base.DownloadStatusPause {
+			break
+		}
+		if f.ranged() {
+			if chunk.Downloaded >= chunk.End-chunk.Begin+1 {
+				// Every byte has arrived (e.g. paused just before EOF);
+				// asking for an empty range would only fail.
+				err = nil
+				break
+			}
+			httpReq.Header.Set(base.HttpHeaderRange,
+				fmt.Sprintf(base.HttpHeaderRangeFormat, chunk.Begin+chunk.Downloaded, chunk.End))
+		} else {
+			// Without ranges every attempt starts from the beginning.
+			chunk.Downloaded = 0
+		}
+		// Each attempt consumes the request body; rewind it for the next.
+		if httpReq.GetBody != nil {
+			httpReq.Body, err = httpReq.GetBody()
+			if err != nil {
+				break
+			}
+		}
+
+		var resp *http.Response
+		resp, err = client.Do(httpReq)
+		if err == nil {
+			f.clients[index] = resp
+			if resp.StatusCode != base.HttpCodeOK && resp.StatusCode != base.HttpCodePartialContent {
+				// Close the body so the connection is not leaked.
+				_ = resp.Body.Close()
+				err = NewRequestError(resp.StatusCode, resp.Status)
+			}
+		}
+		if err != nil {
+			// Retry in 3s, or stop waiting at once when cancelled.
+			select {
+			case <-f.ctx.Done():
+			case <-time.After(time.Second * 3):
+			}
+			continue
+		}
+
+		// A successful request resets the counter, so only 10 failures in a
+		// row end the chunk.
+		i = 0
+
+		var retry bool
+		retry, err = func() (bool, error) {
+			defer func() {
+				_ = resp.Body.Close()
+			}()
+			for {
+				n, readErr := resp.Body.Read(buf)
+				if n > 0 {
+					_, writeErr := f.Ctl.Write(filename, chunk.Begin+chunk.Downloaded, buf[:n])
+					if writeErr != nil {
+						// A failing local write will not heal by retrying.
+						return false, writeErr
+					}
+					chunk.Downloaded += int64(n)
+				}
+				if readErr == io.EOF {
+					return false, nil
+				}
+				if readErr != nil {
+					return true, readErr
+				}
+			}
+		}()
+		if !retry {
+			// Finished, or failed for good: leave the retry loop.
+			break
+		}
+	}
+
+	switch {
+	case f.status == base.DownloadStatusPause:
+		chunk.Status = base.DownloadStatusPause
+	case f.ranged() && chunk.Downloaded >= chunk.End-chunk.Begin+1:
+		chunk.Status = base.DownloadStatusDone
+	case err != nil:
+		chunk.Status = base.DownloadStatusError
+	default:
+		chunk.Status = base.DownloadStatusDone
+	}
+	return err
+}
+
+// buildClient assembles an HTTP client from the controller's dialer, proxy
+// function, cookie jar and timeout.
+func (f *Fetcher) buildClient() (*http.Client, error) {
+	dialer, err := f.Ctl.ContextDialer()
+	if err != nil {
+		return nil, err
+	}
+	transport := &http.Transport{
+		DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+			return dialer.Dial(network, addr)
+		},
+	}
+	if proxyFn := f.Ctl.ContextProxy(); proxyFn != nil {
+		transport.Proxy = proxyFn
+	}
+	// NOTE(review): http.Client.Timeout also covers reading the response
+	// body, so a transfer that outlasts it is cut off and retried; confirm
+	// that this is intended for large downloads.
+	return &http.Client{
+		Jar:       f.Ctl.ContextCookie(),
+		Timeout:   f.Ctl.ContextTimeout(),
+		Transport: transport,
+	}, nil
+}
+
+// buildRequest turns req into an *http.Request bound to ctx, which may be
+// nil. req.Extra may be nil, an Extra or a *Extra and supplies the method,
+// headers and body; any other type is rejected with an error.
+func (f *Fetcher) buildRequest(ctx context.Context, req *base.Request) (httpReq *http.Request, err error) {
+	reqUrl, err := url.Parse(req.URL)
+	if err != nil {
+		return nil, err
+	}
+
+	var extra Extra
+	switch v := req.Extra.(type) {
+	case nil:
+		// No extra information: a plain GET.
+	case Extra:
+		extra = v
+	case *Extra:
+		if v != nil {
+			extra = *v
+		}
+	default:
+		return nil, fmt.Errorf("unsupported request extra type %T", req.Extra)
+	}
+
+	method := extra.Method
+	if method == "" {
+		method = http.MethodGet
+	}
+	headers := make(map[string][]string, len(extra.Header))
+	for k, v := range extra.Header {
+		headers[k] = []string{v}
+	}
+	var body io.Reader
+	if extra.Body != "" {
+		// Pass the *bytes.Buffer itself so that NewRequest can set
+		// ContentLength and GetBody.
+		body = bytes.NewBufferString(extra.Body)
+	}
+	if ctx == nil {
+		ctx = context.Background()
+	}
+	httpReq, err = http.NewRequestWithContext(ctx, method, reqUrl.String(), body)
+	if err != nil {
+		return nil, err
+	}
+	httpReq.Header = headers
+	return httpReq, nil
+}
diff --git a/pkg/downloader/protocol/http/model.go b/pkg/downloader/protocol/http/model.go
new file mode 100644
index 0000000..8349c69
--- /dev/null
+++ b/pkg/downloader/protocol/http/model.go
@@ -0,0 +1,29 @@
+package http
+
+import "gitea.bvbej.com/bvbej/base-golang/pkg/downloader/base"
+
+// Chunk is one contiguous byte range of the resource. Begin and End are
+// inclusive offsets and Downloaded counts the bytes already written.
+type Chunk struct {
+	Status     base.Status
+	Begin      int64
+	End        int64
+	Downloaded int64
+}
+
+// NewChunk returns a ready chunk covering the offsets begin to end.
+func NewChunk(begin int64, end int64) *Chunk {
+	return &Chunk{
+		Status: base.DownloadStatusReady,
+		Begin:  begin,
+		End:    end,
+	}
+}
+
+// Extra is the HTTP-specific information a base.Request may carry in its
+// Extra field: the method, the headers and the body of the request.
+type Extra struct {
+	Method string
+	Header map[string]string
+	Body   string
+}
diff --git a/pkg/downloader/util/timer.go b/pkg/downloader/util/timer.go
new file mode 100644
index 0000000..d57ddff
--- /dev/null
+++ b/pkg/downloader/util/timer.go
@@ -0,0 +1,47 @@
+// Package util contains small helpers for the downloader.
+package util
+
+import "time"
+
+// Timer measures elapsed time in nanoseconds, excluding the periods during
+// which it was paused. The zero value is a stopped timer with nothing used.
+type Timer struct {
+	t       int64 // UnixNano at which the current running period began
+	used    int64 // nanoseconds accumulated by finished running periods
+	running bool  // whether a running period is in progress
+}
+
+// Start begins a running period at the current time.
+func (t *Timer) Start() {
+	t.t = time.Now().UnixNano()
+	t.running = true
+}
+
+// Pause ends the current running period and adds it to the total. It does
+// nothing when the timer is not running, so pausing twice counts once.
+func (t *Timer) Pause() {
+	if !t.running {
+		return
+	}
+	t.used += time.Now().UnixNano() - t.t
+	t.running = false
+}
+
+// Continue begins a new running period. It does nothing when the timer is
+// already running, so that the time measured so far is not dropped.
+func (t *Timer) Continue() {
+	if t.running {
+		return
+	}
+	t.t = time.Now().UnixNano()
+	t.running = true
+}
+
+// Used returns the total running time in nanoseconds. While paused only the
+// accumulated time is reported, not the time spent paused.
+func (t *Timer) Used() int64 {
+	if !t.running {
+		return t.used
+	}
+	return (time.Now().UnixNano() - t.t) + t.used
+}