base-golang/pkg/proxy/io.go
2024-07-23 10:23:43 +08:00

143 lines
2.8 KiB
Go

package proxy
import (
"context"
"golang.org/x/time/rate"
"io"
"net"
"runtime/debug"
"sync"
"time"
)
// burstLimit is the burst size given to rate.NewLimiter in
// Reader.SetRateLimit. The same amount is consumed immediately after the
// limiter is created, so the value mainly bounds how many bytes a single
// WaitN call may request.
const burstLimit = 1_000_000_000
// Reader wraps an io.Reader and, once SetRateLimit has been called, throttles
// the bytes it returns through a rate.Limiter.
type Reader struct {
// r is the underlying reader that every Read call is delegated to.
r io.Reader
// limiter is left nil by NewReader; Read treats a nil limiter as "no throttling".
limiter *rate.Limiter
// ctx is handed to limiter.WaitN; NewReader sets it to context.Background().
ctx context.Context
}
// NewReader returns a Reader that delegates to r. No rate limit is installed;
// call SetRateLimit to enable throttling.
func NewReader(r io.Reader) *Reader {
	wrapped := new(Reader)
	wrapped.r = r
	wrapped.ctx = context.Background()
	return wrapped
}
// SetRateLimit installs a fresh limiter that admits bytesPerSec bytes per
// second with a burst of burstLimit. Any previously installed limiter is
// replaced.
func (s *Reader) SetRateLimit(bytesPerSec float64) {
	limiter := rate.NewLimiter(rate.Limit(bytesPerSec), burstLimit)
	// Spend the initial burst so the limiter does not start with a full bucket.
	limiter.AllowN(time.Now(), burstLimit)
	s.limiter = limiter
}
// Read reads from the underlying reader into p. When a limiter is installed
// and the read succeeded, it then waits until the limiter grants n tokens,
// returning the limiter's error (if any) together with n.
//
// A failed read is returned as is, without consulting the limiter.
func (s *Reader) Read(p []byte) (int, error) {
	n, err := s.r.Read(p)
	if err != nil || s.limiter == nil {
		// Nothing to throttle: either the read failed or no limit is set.
		return n, err
	}
	return n, s.limiter.WaitN(s.ctx, n)
}
// ConnectHost dials hostAndPort over TCP. timeout is interpreted as a number
// of milliseconds and is passed to net.DialTimeout.
func ConnectHost(hostAndPort string, timeout int) (conn net.Conn, err error) {
	dialTimeout := time.Duration(timeout) * time.Millisecond
	return net.DialTimeout("tcp", hostAndPort, dialTimeout)
}
// CloseConn closes conn on a best-effort basis; a nil conn is a no-op.
// Before closing it sets a deadline one millisecond in the future.
// Errors from both calls are intentionally discarded.
func CloseConn(conn net.Conn) {
	if conn == nil {
		return
	}
	deadline := time.Now().Add(time.Millisecond)
	_ = conn.SetDeadline(deadline)
	_ = conn.Close()
}
// IoBind relays data between dst and src in both directions and returns
// immediately; the copying runs in two background goroutines, one per
// direction, each driven by IoCopy.
//
// Parameters:
//   - fn is invoked at most once, with the first error reported by either
//     direction. IoCopy never finishes with a nil error, so a normal end of
//     stream (for example io.EOF from the reader) also triggers fn. A nil fn
//     is ignored.
//   - cfn is invoked after every chunk written, with the number of bytes and
//     isPositive=false for src->dst traffic, isPositive=true for dst->src
//     traffic. A nil cfn is ignored.
//   - bytesPreSec, when greater than zero, wraps the read side of each
//     direction in its own rate-limited Reader.
//
// isSrcErr is forwarded unchanged from IoCopy, where it means "the read side
// failed". For the dst->src direction the read side is dst.
// NOTE(review): confirm callers expect that meaning for the reverse direction.
//
// IoBind does not close dst or src. A panic inside a relay goroutine is
// recovered and logged; in that case fn is not called.
func IoBind(dst io.ReadWriter, src io.ReadWriter, fn func(isSrcErr bool, err error), cfn func(count int, isPositive bool), bytesPreSec float64) {
	var once sync.Once

	// relay pumps one direction (from -> to) until IoCopy stops.
	relay := func(to io.Writer, from io.Reader, isPositive bool) {
		defer func() {
			if e := recover(); e != nil {
				// %v rather than %s: the recovered value may be of any type.
				logger.Sugar().Errorf("IoBind crashed , err : %v , \ntrace:%s", e, string(debug.Stack()))
			}
		}()

		if bytesPreSec > 0 {
			limited := NewReader(from)
			limited.SetRateLimit(bytesPreSec)
			from = limited
		}

		_, isSrcErr, err := IoCopy(to, from, func(c int) {
			// Guard against a nil counter callback, which would otherwise
			// panic on the first chunk and silently kill this direction.
			if cfn != nil {
				cfn(c, isPositive)
			}
		})
		if err != nil && fn != nil {
			// Only the first direction to fail reports; the other is dropped.
			once.Do(func() {
				fn(isSrcErr, err)
			})
		}
	}

	go relay(dst, src, false)
	go relay(src, dst, true)
}
// IoCopy copies from src to dst in chunks of up to 32 KiB until a read or a
// write fails.
//
// Unlike io.Copy it never returns a nil error: the loop only stops on a
// failure, so a normal end of stream is reported as the reader's own error
// (typically io.EOF).
//
// Parameters:
//   - fn is an optional progress hook. It is honoured only when exactly one
//     non-nil callback is passed, and is called after each write that stored
//     at least one byte, with the number of bytes written.
//
// Returns:
//   - written: total bytes successfully written to dst.
//   - isSrcErr: true when the loop stopped because src.Read failed; false
//     when it stopped because of a write error or a short write.
//   - err: the read error, the write error, or io.ErrShortWrite.
func IoCopy(dst io.Writer, src io.Reader, fn ...func(count int)) (written int64, isSrcErr bool, err error) {
	var onWrite func(count int)
	if len(fn) == 1 && fn[0] != nil {
		onWrite = fn[0]
	}

	buf := make([]byte, 32*1024)
	for {
		nr, er := src.Read(buf)
		if nr > 0 {
			nw, ew := dst.Write(buf[:nr])
			if nw < 0 || nw > nr {
				// The writer violated the io.Writer contract; its count
				// cannot be trusted, so do not add it to the totals.
				nw = 0
				if ew == nil {
					ew = io.ErrShortWrite
				}
			}
			if nw > 0 {
				written += int64(nw)
				if onWrite != nil {
					onWrite(nw)
				}
			}
			if ew != nil {
				return written, false, ew
			}
			if nw != nr {
				return written, false, io.ErrShortWrite
			}
		}
		// Checked after the write so bytes returned together with an error
		// are still delivered.
		if er != nil {
			return written, true, er
		}
	}
}