143 lines
2.8 KiB
Go
143 lines
2.8 KiB
Go
package proxy
|
|
|
|
import (
|
|
"context"
|
|
"golang.org/x/time/rate"
|
|
"io"
|
|
"net"
|
|
"runtime/debug"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// burstLimit is the token-bucket capacity handed to every limiter created by
// SetRateLimit: one billion tokens, where one token is one byte passed to
// WaitN by Reader.Read. It is presumably chosen to be far larger than any
// single read, because rate.Limiter.WaitN rejects requests above the burst.
const burstLimit = 1000000000
|
|
|
|
// Reader wraps an io.Reader and, once SetRateLimit has been called, throttles
// Read so that data is delivered at no more than the configured bytes per
// second.
type Reader struct {
	// r is the underlying reader that Read delegates to.
	r io.Reader

	// limiter throttles reads; nil (the state left by NewReader) means reads
	// pass straight through without any throttling.
	limiter *rate.Limiter

	// ctx is passed to limiter.WaitN. NewReader sets it to
	// context.Background(). NOTE(review): nothing visible here ever replaces
	// or cancels it, so waits cannot be interrupted.
	ctx context.Context
}
|
|
|
|
// NewReader returns a Reader that delegates every Read to r. No throttling is
// applied until SetRateLimit is called. Limiter waits use
// context.Background(), which is never cancelled.
func NewReader(r io.Reader) *Reader {
	wrapped := new(Reader)
	wrapped.r = r
	wrapped.ctx = context.Background()
	return wrapped
}
|
|
|
|
// SetRateLimit throttles subsequent reads to bytesPerSec bytes per second,
// replacing any previously configured limiter.
//
// The new limiter's bucket holds burstLimit tokens and is emptied right after
// creation. Reads are therefore throttled from the very first byte instead of
// being granted an initial burst.
//
// NOTE(review): because the bucket capacity is burstLimit, tokens keep
// refilling while the reader sits idle (up to burstLimit). A long-idle reader
// may then read a correspondingly large amount without waiting. Confirm this
// is acceptable.
func (s *Reader) SetRateLimit(bytesPerSec float64) {
	lim := rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit)
	// Spend the whole initial burst so throttling starts immediately.
	lim.AllowN(time.Now(), burstLimit)
	s.limiter = lim
}
|
|
|
|
func (s *Reader) Read(p []byte) (int, error) {
|
|
if s.limiter == nil {
|
|
return s.r.Read(p)
|
|
}
|
|
n, err := s.r.Read(p)
|
|
if err != nil {
|
|
return n, err
|
|
}
|
|
if err := s.limiter.WaitN(s.ctx, n); err != nil {
|
|
return n, err
|
|
}
|
|
return n, nil
|
|
}
|
|
|
|
// ConnectHost opens a TCP connection to hostAndPort ("host:port").
//
// timeout is in milliseconds and bounds the whole dial, including name
// resolution. Following net.Dialer semantics, a zero timeout means no timeout.
// On failure conn is nil and err describes the problem.
func ConnectHost(hostAndPort string, timeout int) (conn net.Conn, err error) {
	dialTimeout := time.Millisecond * time.Duration(timeout)
	return net.DialTimeout("tcp", hostAndPort, dialTimeout)
}
|
|
|
|
// CloseConn closes conn on a best-effort basis; a nil conn is ignored.
//
// Before closing, it sets a deadline 1ms in the future. Presumably this makes
// any I/O still in flight on conn fail promptly; confirm with callers. Errors
// from SetDeadline and Close are intentionally discarded.
func CloseConn(conn net.Conn) {
	if conn == nil {
		return
	}
	deadline := time.Now().Add(time.Millisecond)
	_ = conn.SetDeadline(deadline)
	_ = conn.Close()
}
|
|
|
|
// IoBind pipes data in both directions between dst and src. Each direction
// runs in its own goroutine, and IoBind returns immediately without waiting
// for them.
//
//   - src -> dst: cfn(count, false) is called after each write to dst.
//   - dst -> src: cfn(count, true) is called after each write to src.
//
// When bytesPerSec > 0, the reading side of each direction is throttled to
// bytesPerSec bytes per second through a Reader.
//
// fn is invoked at most once, with the error of whichever direction stops
// first; the other direction's error is dropped. isSrcErr is IoCopy's flag
// for that direction: true when the copy stopped because its reader failed
// (including io.EOF), false when its writer failed.
// NOTE(review): in the dst -> src direction the reader is dst, so
// isSrcErr == true there refers to dst, not src. Confirm callers expect this.
//
// IoBind never closes dst or src itself. A panic in either direction is
// recovered and logged, and that direction then does not call fn. A nil fn or
// a nil cfn is tolerated.
func IoBind(dst io.ReadWriter, src io.ReadWriter, fn func(isSrcErr bool, err error), cfn func(count int, isPositive bool), bytesPerSec float64) {
	once := &sync.Once{}
	go ioBindPipe(dst, src, false, once, fn, cfn, bytesPerSec)
	go ioBindPipe(src, dst, true, once, fn, cfn, bytesPerSec)
}

// ioBindPipe runs one direction of IoBind: it copies r into w until either
// side fails, then reports the error through once/fn.
func ioBindPipe(w io.Writer, r io.Reader, isPositive bool, once *sync.Once, fn func(isSrcErr bool, err error), cfn func(count int, isPositive bool), bytesPerSec float64) {
	defer func() {
		if e := recover(); e != nil {
			logger.Sugar().Errorf("IoBind crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
		}
	}()

	if bytesPerSec > 0 {
		limited := NewReader(r)
		limited.SetRateLimit(bytesPerSec)
		r = limited
	}

	// A nil cfn would otherwise panic on the first write, and the recovered
	// panic would leave this direction without ever calling fn.
	var onCount []func(int)
	if cfn != nil {
		onCount = append(onCount, func(c int) { cfn(c, isPositive) })
	}

	_, isSrcErr, err := IoCopy(w, r, onCount...)
	if err != nil && fn != nil {
		once.Do(func() { fn(isSrcErr, err) })
	}
}
|
|
|
|
// IoCopy copies from src to dst through a 32 KiB buffer until a read or a
// write fails.
//
// Unlike io.Copy, reaching the end of src is reported as an error: a clean end
// of stream returns err == io.EOF with isSrcErr == true. As a result, IoCopy
// never returns a nil error.
//
// Every non-nil callback in fn is invoked with the byte count after each write
// that wrote at least one byte.
//
// Return values:
//   - written: the total number of bytes written to dst.
//   - isSrcErr: true when the copy stopped because of a read error, false when
//     it stopped because of a write error.
//   - err: the error that stopped the copy.
//
// A write that reports a different count than it was given ends the copy
// with io.ErrShortWrite. Data returned together with a read error is written
// before that error is reported.
func IoCopy(dst io.Writer, src io.Reader, fn ...func(count int)) (written int64, isSrcErr bool, err error) {
	buf := make([]byte, 32*1024)
	for {
		nr, er := src.Read(buf)
		if nr > 0 {
			nw, ew := dst.Write(buf[:nr])
			if nw > 0 {
				written += int64(nw)
				// Notify every callback, not only when exactly one is given.
				for _, f := range fn {
					if f != nil {
						f(nw)
					}
				}
			}
			if ew != nil {
				return written, false, ew
			}
			if nw != nr {
				return written, false, io.ErrShortWrite
			}
		}
		if er != nil {
			return written, true, er
		}
	}
}
|