Files
base-golang/pkg/storage/s3_client.go
2026-02-25 15:00:32 +08:00

442 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
//go:build !minio
package storage
import (
"context"
"crypto/md5"
"encoding/hex"
"errors"
"fmt"
"io"
"path/filepath"
"strings"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
"github.com/aws/smithy-go"
)
// Config holds the S3 connection settings used by NewClient.
type Config struct {
	// Endpoint of an S3-compatible store (e.g. MinIO). Optional: leave empty
	// to use AWS S3. When set, the client uses path-style addressing.
	Endpoint string `yaml:"endpoint" json:"endpoint"`
	// AccessKeyID is the access key. Static credentials are used only when
	// both AccessKeyID and SecretAccessKey are non-empty.
	AccessKeyID string `yaml:"access_key_id" json:"access_key_id"`
	// SecretAccessKey is the secret key paired with AccessKeyID.
	SecretAccessKey string `yaml:"secret_access_key" json:"secret_access_key"`
	// UseSSL selects HTTPS; when false, HTTPS is disabled in the S3 client
	// options and direct access URLs are built with "http".
	UseSSL bool `yaml:"use_ssl" json:"use_ssl"`
	// BucketName is the default bucket, used when a call does not name one.
	// If set, NewClient ensures it exists.
	BucketName string `yaml:"bucket_name" json:"bucket_name"`
	// Region of the bucket; NewClient defaults it to "us-east-1".
	Region string `yaml:"region" json:"region"`
	// CDNDomain is an optional CDN base URL; when set, access URLs are built
	// as "<CDNDomain>/<bucket>/<key>".
	CDNDomain string `yaml:"cdn_domain" json:"cdn_domain"`
	// PresignExpires is the lifetime of presigned URLs; NewClient defaults it
	// to 15 minutes when zero.
	PresignExpires time.Duration `yaml:"presign_expires" json:"presign_expires"`
}
// Client is an S3 client wrapper. It bundles the SDK client, a presign client
// derived from it, and the Config it was created with.
type Client struct {
	// client performs the object and bucket API calls.
	client *s3.Client
	// presignClient generates presigned upload/download URLs.
	presignClient *s3.PresignClient
	// config is the configuration, with defaults filled in by NewClient.
	config *Config
}
// NewClient creates an S3 client from cfg.
//
// cfg is updated in place with defaults before being stored on the Client:
//   - PresignExpires defaults to 15 minutes;
//   - Region defaults to "us-east-1";
//   - UseSSL is forced to true when Endpoint is empty, because the client then
//     talks to AWS S3 itself.
//
// Endpoint may be a bare host[:port] (e.g. "minio:9000"), in which case the
// scheme is derived from UseSSL, or a full URL that already has a scheme.
// Setting an Endpoint also switches the client to path-style addressing for
// MinIO-compatible stores.
//
// When cfg.BucketName is set, the bucket is checked and created if missing,
// using a 10-second timeout derived from ctx. If the bucket was created
// concurrently by another process (BucketAlreadyOwnedByYou), that counts as
// success rather than a startup failure.
func NewClient(ctx context.Context, cfg *Config) (*Client, error) {
	if cfg == nil {
		return nil, errors.New("S3配置不能为空")
	}

	// Fill in defaults.
	if cfg.PresignExpires == 0 {
		cfg.PresignExpires = 15 * time.Minute
	}
	if cfg.Endpoint == "" {
		// No custom endpoint means AWS S3, which is always reached over HTTPS.
		// If UseSSL were left at its zero value, HTTPS would be disabled towards
		// AWS below.
		cfg.UseSSL = true
	}
	if cfg.Region == "" {
		cfg.Region = "us-east-1"
	}

	opts := []func(*config.LoadOptions) error{
		config.WithRegion(cfg.Region),
	}
	// Use static credentials only when both halves are present; otherwise
	// LoadDefaultConfig resolves credentials from its default sources.
	if cfg.AccessKeyID != "" && cfg.SecretAccessKey != "" {
		opts = append(opts, config.WithCredentialsProvider(
			credentials.NewStaticCredentialsProvider(cfg.AccessKeyID, cfg.SecretAccessKey, ""),
		))
	}
	awsCfg, err := config.LoadDefaultConfig(ctx, opts...)
	if err != nil {
		return nil, fmt.Errorf("加载AWS配置失败: %w", err)
	}

	// The SDK's BaseEndpoint must be an absolute URL. Add a scheme when the
	// endpoint is configured as a bare host[:port] (the form buildAccessURL
	// also accepts).
	baseEndpoint := cfg.Endpoint
	if baseEndpoint != "" && !strings.Contains(baseEndpoint, "://") {
		scheme := "http"
		if cfg.UseSSL {
			scheme = "https"
		}
		baseEndpoint = scheme + "://" + baseEndpoint
	}

	s3Client := s3.NewFromConfig(awsCfg, func(o *s3.Options) {
		if baseEndpoint != "" {
			// Custom S3-compatible endpoint; path-style access for MinIO.
			o.BaseEndpoint = aws.String(baseEndpoint)
			o.UsePathStyle = true
		}
		if !cfg.UseSSL {
			o.EndpointOptions.DisableHTTPS = true
		}
	})

	c := &Client{
		client:        s3Client,
		presignClient: s3.NewPresignClient(s3Client),
		config:        cfg,
	}

	// Make sure the default bucket exists.
	if cfg.BucketName != "" {
		timeoutCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
		defer cancel()

		exists, err := c.BucketExists(timeoutCtx, cfg.BucketName)
		if err != nil {
			return nil, fmt.Errorf("检查桶失败: %w", err)
		}
		if !exists {
			if err := c.CreateBucket(timeoutCtx, cfg.BucketName); err != nil {
				// Another instance may have created the bucket between the
				// existence check and this call; already owning it is fine.
				var apiErr smithy.APIError
				if !errors.As(err, &apiErr) || apiErr.ErrorCode() != "BucketAlreadyOwnedByYou" {
					return nil, fmt.Errorf("创建桶失败: %w", err)
				}
			}
		}
	}
	return c, nil
}
// UploadToken is an upload credential: a presigned PUT URL for one object.
type UploadToken struct {
	Key        string    `json:"key"`         // object key the file will be stored under
	UploadURL  string    `json:"upload_url"`  // presigned PUT URL
	ExpiresAt  time.Time `json:"expires_at"`  // approximate expiry time of UploadURL
	BucketName string    `json:"bucket_name"` // target bucket
	AccessURL  string    `json:"access_url"`  // access URL (optional; filled only when a CDN domain is configured)
}
// DownloadToken is a download credential: a presigned GET URL for one object.
type DownloadToken struct {
	Key         string    `json:"key"`          // object key
	DownloadURL string    `json:"download_url"` // presigned GET URL
	ExpiresAt   time.Time `json:"expires_at"`   // approximate expiry time of DownloadURL
	Filename    string    `json:"filename"`     // file name (optional; the last path element of Key)
}
// FileInfo describes a stored object as returned by a HEAD request.
type FileInfo struct {
	Key  string `json:"key"`  // object key
	Size int64  `json:"size"` // object size in bytes
	// ETag with surrounding quotes stripped; VerifyFile compares it against an
	// expected MD5. NOTE(review): S3 ETags are not the content MD5 for
	// multipart or SSE-KMS uploads — confirm callers only rely on it for
	// simple uploads.
	ETag         string            `json:"etag"`
	ContentType  string            `json:"content_type"`  // Content-Type header
	LastModified time.Time         `json:"last_modified"` // last modification time
	Metadata     map[string]string `json:"metadata"`      // user-defined metadata
	URL          string            `json:"url"`           // access URL (CDN URL when configured, otherwise direct)
	Exists       bool              `json:"exists"`        // false when the object was not found; other fields are then empty
}
// GenerateUploadToken returns a presigned PUT URL that lets a client upload
// an object under key directly to storage.
//
// The optional bucketName overrides the configured default bucket when its
// first element is non-empty. The URL is valid for Config.PresignExpires.
// AccessURL is filled in only when a CDN domain is configured.
func (c *Client) GenerateUploadToken(ctx context.Context, key string, bucketName ...string) (*UploadToken, error) {
	target := c.config.BucketName
	if len(bucketName) != 0 && bucketName[0] != "" {
		target = bucketName[0]
	}

	lifetime := c.config.PresignExpires
	input := &s3.PutObjectInput{
		Bucket: aws.String(target),
		Key:    aws.String(key),
	}
	signed, err := c.presignClient.PresignPutObject(ctx, input, s3.WithPresignExpires(lifetime))
	if err != nil {
		return nil, fmt.Errorf("生成上传凭证失败: %w", err)
	}
	expiresAt := time.Now().Add(lifetime)

	var accessURL string
	if c.config.CDNDomain != "" {
		accessURL = c.buildCDNURL(target, key)
	}

	return &UploadToken{
		Key:        key,
		UploadURL:  signed.URL,
		ExpiresAt:  expiresAt,
		BucketName: target,
		AccessURL:  accessURL,
	}, nil
}
// GenerateDownloadToken returns a presigned GET URL for an existing object.
//
// The optional bucketName overrides the configured default bucket when its
// first element is non-empty. An error is returned if the object does not
// exist or the existence check fails. Filename is the last path element of key.
func (c *Client) GenerateDownloadToken(ctx context.Context, key string, bucketName ...string) (*DownloadToken, error) {
	target := c.config.BucketName
	if len(bucketName) != 0 && bucketName[0] != "" {
		target = bucketName[0]
	}

	// Refuse to sign a URL for an object that is not there.
	switch found, err := c.FileExists(ctx, key, target); {
	case err != nil:
		return nil, err
	case !found:
		return nil, fmt.Errorf("文件不存在: %s", key)
	}

	lifetime := c.config.PresignExpires
	input := &s3.GetObjectInput{
		Bucket: aws.String(target),
		Key:    aws.String(key),
	}
	signed, err := c.presignClient.PresignGetObject(ctx, input, s3.WithPresignExpires(lifetime))
	if err != nil {
		return nil, fmt.Errorf("生成下载凭证失败: %w", err)
	}

	return &DownloadToken{
		Key:         key,
		DownloadURL: signed.URL,
		ExpiresAt:   time.Now().Add(lifetime),
		Filename:    filepath.Base(key),
	}, nil
}
// VerifyFile fetches an object's metadata via HEAD and, when expectedMD5 is
// non-empty, checks it against the object's ETag (case-insensitive, quotes
// ignored).
//
// A missing object is not an error: the result has Exists == false and a nil
// error. This holds even when expectedMD5 is set, so callers must check Exists.
// On an MD5 mismatch, the populated FileInfo is returned together with an
// error.
//
// NOTE(review): S3 ETags equal the content MD5 only for simple, non-SSE-KMS
// uploads; multipart uploads have "<hash>-<parts>" ETags and will always be
// reported as mismatches here. CalculateFileMD5 computes the real digest.
func (c *Client) VerifyFile(ctx context.Context, key string, expectedMD5 string, bucketName ...string) (*FileInfo, error) {
	target := c.config.BucketName
	if len(bucketName) != 0 && bucketName[0] != "" {
		target = bucketName[0]
	}

	head, err := c.client.HeadObject(ctx, &s3.HeadObjectInput{
		Bucket: aws.String(target),
		Key:    aws.String(key),
	})
	if err != nil {
		if !c.isNotFoundError(err) {
			return nil, fmt.Errorf("获取文件信息失败: %w", err)
		}
		return &FileInfo{Key: key, Exists: false}, nil
	}

	// Copy the user metadata so the result does not alias the SDK's map.
	meta := make(map[string]string, len(head.Metadata))
	for name, value := range head.Metadata {
		meta[name] = value
	}

	info := &FileInfo{
		Key:          key,
		Size:         aws.ToInt64(head.ContentLength),
		ETag:         strings.Trim(aws.ToString(head.ETag), "\""), // S3 returns the ETag quoted
		ContentType:  aws.ToString(head.ContentType),
		LastModified: aws.ToTime(head.LastModified),
		Metadata:     meta,
		URL:          c.buildAccessURL(target, key),
		Exists:       true,
	}

	if expectedMD5 == "" || c.compareMD5(info.ETag, expectedMD5) {
		return info, nil
	}
	return info, fmt.Errorf("文件MD5不匹配期望: %s, 实际: %s", expectedMD5, info.ETag)
}
// CalculateFileMD5 downloads the object and returns the lowercase hex MD5 of
// its full content. The body is streamed into the hash, so the whole object
// is not held in memory, but it is transferred in full.
func (c *Client) CalculateFileMD5(ctx context.Context, key string, bucketName ...string) (string, error) {
	target := c.config.BucketName
	if len(bucketName) != 0 && bucketName[0] != "" {
		target = bucketName[0]
	}

	obj, err := c.client.GetObject(ctx, &s3.GetObjectInput{
		Bucket: aws.String(target),
		Key:    aws.String(key),
	})
	if err != nil {
		return "", fmt.Errorf("下载文件失败: %w", err)
	}
	body := obj.Body
	// Close error ignored: the stream is read-only and the outcome is already
	// decided by the copy below.
	defer func() { _ = body.Close() }()

	digest := md5.New()
	if _, copyErr := io.Copy(digest, body); copyErr != nil {
		return "", fmt.Errorf("计算MD5失败: %w", copyErr)
	}
	sum := digest.Sum(nil)
	return hex.EncodeToString(sum), nil
}
// FileExists reports whether key exists in the bucket, using a HEAD request.
// A not-found response yields (false, nil); any other failure is returned
// as an error.
func (c *Client) FileExists(ctx context.Context, key string, bucketName ...string) (bool, error) {
	target := c.config.BucketName
	if len(bucketName) != 0 && bucketName[0] != "" {
		target = bucketName[0]
	}

	input := &s3.HeadObjectInput{
		Bucket: aws.String(target),
		Key:    aws.String(key),
	}
	_, err := c.client.HeadObject(ctx, input)
	switch {
	case err == nil:
		return true, nil
	case c.isNotFoundError(err):
		return false, nil
	default:
		return false, fmt.Errorf("检查文件失败: %w", err)
	}
}
// DeleteFile deletes key from the bucket (the configured default, unless the
// first element of bucketName is non-empty).
func (c *Client) DeleteFile(ctx context.Context, key string, bucketName ...string) error {
	target := c.config.BucketName
	if len(bucketName) != 0 && bucketName[0] != "" {
		target = bucketName[0]
	}

	input := &s3.DeleteObjectInput{
		Bucket: aws.String(target),
		Key:    aws.String(key),
	}
	if _, err := c.client.DeleteObject(ctx, input); err != nil {
		return fmt.Errorf("删除文件失败: %w", err)
	}
	return nil
}
// GetFileInfo returns an object's metadata without any MD5 verification.
// It is VerifyFile with an empty expected MD5, so a missing object yields a
// FileInfo with Exists == false and a nil error.
func (c *Client) GetFileInfo(ctx context.Context, key string, bucketName ...string) (*FileInfo, error) {
	const noChecksum = ""
	info, err := c.VerifyFile(ctx, key, noChecksum, bucketName...)
	return info, err
}
// BucketExists reports whether bucketName exists, using a HEAD bucket
// request. A not-found response yields (false, nil); any other failure is
// returned as an error.
func (c *Client) BucketExists(ctx context.Context, bucketName string) (bool, error) {
	input := &s3.HeadBucketInput{Bucket: aws.String(bucketName)}
	_, err := c.client.HeadBucket(ctx, input)
	switch {
	case err == nil:
		return true, nil
	case c.isNotFoundError(err):
		return false, nil
	default:
		return false, fmt.Errorf("检查桶失败: %w", err)
	}
}
// CreateBucket creates bucketName in the configured region.
func (c *Client) CreateBucket(ctx context.Context, bucketName string) error {
	input := &s3.CreateBucketInput{Bucket: aws.String(bucketName)}

	// Outside us-east-1 the region has to be given as a LocationConstraint;
	// for us-east-1 (or an empty region) no configuration is sent.
	if region := c.config.Region; region != "" && region != "us-east-1" {
		input.CreateBucketConfiguration = &types.CreateBucketConfiguration{
			LocationConstraint: types.BucketLocationConstraint(region),
		}
	}

	if _, err := c.client.CreateBucket(ctx, input); err != nil {
		return fmt.Errorf("创建桶失败: %w", err)
	}
	return nil
}
// SetBucketPublic installs a bucket policy that grants s3:GetObject on every
// object in bucketName to any principal ("*"), i.e. anonymous read access.
// The call replaces the bucket's existing policy document.
func (c *Client) SetBucketPublic(ctx context.Context, bucketName string) error {
	doc := fmt.Sprintf(`{
"Version": "2012-10-17",
"Statement": [{
"Effect": "Allow",
"Principal": {"AWS": ["*"]},
"Action": ["s3:GetObject"],
"Resource": ["arn:aws:s3:::%s/*"]
}]
}`, bucketName)

	req := &s3.PutBucketPolicyInput{
		Bucket: aws.String(bucketName),
		Policy: aws.String(doc),
	}
	if _, err := c.client.PutBucketPolicy(ctx, req); err != nil {
		return fmt.Errorf("设置桶策略失败: %w", err)
	}
	return nil
}
// buildAccessURL returns a direct (unsigned) URL for bucket/key.
//
// Precedence:
//   - if a CDN domain is configured, the CDN URL is returned;
//   - with a custom Endpoint, a path-style URL "<endpoint>/<bucket>/<key>" is
//     built. The endpoint may be a bare host[:port], in which case the scheme
//     comes from UseSSL, or a full URL with its own scheme. Trailing slashes
//     are trimmed so no "//" appears;
//   - otherwise the AWS virtual-hosted URL
//     "https://<bucket>.s3.<region>.amazonaws.com/<key>" is used. AWS is always
//     addressed over HTTPS, regardless of UseSSL.
//
// The key is inserted verbatim (not URL-escaped).
func (c *Client) buildAccessURL(bucket, key string) string {
	if c.config.CDNDomain != "" {
		return c.buildCDNURL(bucket, key)
	}

	if c.config.Endpoint == "" {
		return fmt.Sprintf("https://%s.s3.%s.amazonaws.com/%s", bucket, c.config.Region, key)
	}

	base := strings.TrimRight(c.config.Endpoint, "/")
	if !strings.Contains(base, "://") {
		scheme := "http"
		if c.config.UseSSL {
			scheme = "https"
		}
		base = scheme + "://" + base
	}
	return fmt.Sprintf("%s/%s/%s", base, bucket, key)
}
// buildCDNURL builds "<CDNDomain>/<bucket>/<key>". Trailing slashes on the
// configured CDN domain are dropped; the key is inserted verbatim (not
// URL-escaped).
func (c *Client) buildCDNURL(bucket, key string) string {
	domain := strings.TrimRight(c.config.CDNDomain, "/")
	// All operands are strings, so Sprint inserts no separators of its own.
	return fmt.Sprint(domain, "/", bucket, "/", key)
}
// compareMD5 reports whether an ETag and an expected MD5 digest are equal
// after stripping surrounding double quotes and lower-casing both.
func (c *Client) compareMD5(etag, want string) bool {
	normalize := func(s string) string {
		return strings.ToLower(strings.Trim(s, "\""))
	}
	return normalize(etag) == normalize(want)
}
// isNotFoundError reports whether err (or any error it wraps) is a smithy API
// error whose code is NotFound, NoSuchKey or NoSuchBucket. A nil error is
// never "not found".
func (c *Client) isNotFoundError(err error) bool {
	var apiErr smithy.APIError
	if err == nil || !errors.As(err, &apiErr) {
		return false
	}
	switch apiErr.ErrorCode() {
	case "NotFound", "NoSuchKey", "NoSuchBucket":
		return true
	default:
		return false
	}
}