Files
base-golang/pkg/database/mongo.go
2026-02-07 15:36:21 +08:00

265 lines
5.9 KiB
Go

package database
import (
"context"
"fmt"
"time"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/mongo/readpref"
)
// MongoDB MongoDB接口
type MongoDB interface {
// GetDB 获取数据库实例
GetDB() *mongo.Database
// GetClient 获取客户端实例
GetClient() *mongo.Client
// GetCollection 获取集合
GetCollection(name string) *mongo.Collection
// Ping 检查连接
Ping(ctx context.Context) error
// Close 关闭连接
Close(ctx context.Context) error
// WithContext 创建带超时的上下文
WithContext(ctx context.Context) (context.Context, context.CancelFunc)
}
// MongoDBConfig MongoDB配置
type MongoDBConfig struct {
// 地址,支持多个地址: "localhost:27017,localhost:27018"
Addr string `yaml:"addr" json:"addr"`
// 用户名
User string `yaml:"user" json:"user"`
// 密码
Pass string `yaml:"pass" json:"pass"`
// 数据库名称
Name string `yaml:"name" json:"name"`
// 连接超时(秒)
Timeout time.Duration `yaml:"timeout" json:"timeout"`
// 最大连接池大小
MaxPoolSize uint64 `yaml:"max_pool_size" json:"max_pool_size"`
// 最小连接池大小
MinPoolSize uint64 `yaml:"min_pool_size" json:"min_pool_size"`
// 最大连接空闲时间(秒)
MaxConnIdleTime time.Duration `yaml:"max_conn_idle_time" json:"max_conn_idle_time"`
// 是否使用副本集
ReplicaSet string `yaml:"replica_set" json:"replica_set"`
// 是否使用TLS
UseTLS bool `yaml:"use_tls" json:"use_tls"`
// 认证数据库
AuthSource string `yaml:"auth_source" json:"auth_source"`
}
// DefaultMongoDBConfig 默认配置
func DefaultMongoDBConfig() *MongoDBConfig {
return &MongoDBConfig{
Timeout: 10, // 10秒
MaxPoolSize: 100,
MinPoolSize: 10,
MaxConnIdleTime: 60, // 60秒
AuthSource: "admin",
}
}
// mongoDB MongoDB实现
type mongoDB struct {
client *mongo.Client
db *mongo.Database
config *MongoDBConfig
timeout time.Duration
}
// NewMongoDB 创建MongoDB实例
func NewMongoDB(cfg *MongoDBConfig) (MongoDB, error) {
if cfg == nil {
return nil, fmt.Errorf("配置不能为空")
}
// 验证必填参数
if cfg.Addr == "" {
return nil, fmt.Errorf("地址不能为空")
}
if cfg.Name == "" {
return nil, fmt.Errorf("数据库名称不能为空")
}
timeout := cfg.Timeout * time.Second
if timeout == 0 {
timeout = 10 * time.Second
}
// 构建连接选项
clientOpts := options.Client()
// 构建URI
uri := buildMongoURI(cfg)
clientOpts.ApplyURI(uri)
// 设置连接池
if cfg.MaxPoolSize > 0 {
clientOpts.SetMaxPoolSize(cfg.MaxPoolSize)
}
if cfg.MinPoolSize > 0 {
clientOpts.SetMinPoolSize(cfg.MinPoolSize)
}
if cfg.MaxConnIdleTime > 0 {
clientOpts.SetMaxConnIdleTime(cfg.MaxConnIdleTime * time.Second)
}
// 设置超时
clientOpts.SetConnectTimeout(timeout)
clientOpts.SetServerSelectionTimeout(timeout)
// 连接MongoDB
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
client, err := mongo.Connect(ctx, clientOpts)
if err != nil {
return nil, fmt.Errorf("连接MongoDB失败: %w", err)
}
// Ping测试连接
pingCtx, pingCancel := context.WithTimeout(context.Background(), timeout)
defer pingCancel()
if err = client.Ping(pingCtx, readpref.Primary()); err != nil {
_ = client.Disconnect(context.Background())
return nil, fmt.Errorf("Ping MongoDB失败: %w", err)
}
return &mongoDB{
client: client,
db: client.Database(cfg.Name),
config: cfg,
timeout: timeout,
}, nil
}
// GetDB 获取数据库实例
func (m *mongoDB) GetDB() *mongo.Database {
return m.db
}
// GetClient 获取客户端实例
func (m *mongoDB) GetClient() *mongo.Client {
return m.client
}
// GetCollection 获取集合
func (m *mongoDB) GetCollection(name string) *mongo.Collection {
return m.db.Collection(name)
}
// Ping 检查连接
func (m *mongoDB) Ping(ctx context.Context) error {
pingCtx, cancel := m.WithContext(ctx)
defer cancel()
return m.client.Ping(pingCtx, readpref.Primary())
}
// Close 关闭连接
func (m *mongoDB) Close(ctx context.Context) error {
if ctx == nil {
ctx = context.Background()
}
disconnectCtx, cancel := context.WithTimeout(ctx, m.timeout)
defer cancel()
if err := m.client.Disconnect(disconnectCtx); err != nil {
return fmt.Errorf("断开MongoDB连接失败: %w", err)
}
return nil
}
// WithContext 创建带超时的上下文
func (m *mongoDB) WithContext(ctx context.Context) (context.Context, context.CancelFunc) {
if ctx == nil {
ctx = context.Background()
}
return context.WithTimeout(ctx, m.timeout)
}
// buildMongoURI 构建MongoDB URI
func buildMongoURI(cfg *MongoDBConfig) string {
var auth string
if cfg.User != "" && cfg.Pass != "" {
auth = fmt.Sprintf("%s:%s@", cfg.User, cfg.Pass)
}
uri := fmt.Sprintf("mongodb://%s%s/%s", auth, cfg.Addr, cfg.Name)
// 添加查询参数
params := []string{}
if cfg.AuthSource != "" {
params = append(params, fmt.Sprintf("authSource=%s", cfg.AuthSource))
}
if cfg.ReplicaSet != "" {
params = append(params, fmt.Sprintf("replicaSet=%s", cfg.ReplicaSet))
}
if cfg.UseTLS {
params = append(params, "tls=true")
}
if len(params) > 0 {
uri += "?"
for i, param := range params {
if i > 0 {
uri += "&"
}
uri += param
}
}
return uri
}
// Helper functions
// EnsureIndexes 确保索引存在
func EnsureIndexes(ctx context.Context, collection *mongo.Collection, indexes []mongo.IndexModel) error {
if len(indexes) == 0 {
return nil
}
_, err := collection.Indexes().CreateMany(ctx, indexes)
if err != nil {
return fmt.Errorf("创建索引失败: %w", err)
}
return nil
}
// DropIndexes 删除索引
func DropIndexes(ctx context.Context, collection *mongo.Collection, indexNames []string) error {
for _, name := range indexNames {
if _, err := collection.Indexes().DropOne(ctx, name); err != nil {
return fmt.Errorf("删除索引 %s 失败: %w", name, err)
}
}
return nil
}