package database import ( "context" "fmt" "time" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readpref" ) // MongoDB MongoDB接口 type MongoDB interface { // GetDB 获取数据库实例 GetDB() *mongo.Database // GetClient 获取客户端实例 GetClient() *mongo.Client // GetCollection 获取集合 GetCollection(name string) *mongo.Collection // Ping 检查连接 Ping(ctx context.Context) error // Close 关闭连接 Close(ctx context.Context) error // WithContext 创建带超时的上下文 WithContext(ctx context.Context) (context.Context, context.CancelFunc) } // MongoDBConfig MongoDB配置 type MongoDBConfig struct { // 地址,支持多个地址: "localhost:27017,localhost:27018" Addr string `yaml:"addr" json:"addr"` // 用户名 User string `yaml:"user" json:"user"` // 密码 Pass string `yaml:"pass" json:"pass"` // 数据库名称 Name string `yaml:"name" json:"name"` // 连接超时(秒) Timeout time.Duration `yaml:"timeout" json:"timeout"` // 最大连接池大小 MaxPoolSize uint64 `yaml:"max_pool_size" json:"max_pool_size"` // 最小连接池大小 MinPoolSize uint64 `yaml:"min_pool_size" json:"min_pool_size"` // 最大连接空闲时间(秒) MaxConnIdleTime time.Duration `yaml:"max_conn_idle_time" json:"max_conn_idle_time"` // 是否使用副本集 ReplicaSet string `yaml:"replica_set" json:"replica_set"` // 是否使用TLS UseTLS bool `yaml:"use_tls" json:"use_tls"` // 认证数据库 AuthSource string `yaml:"auth_source" json:"auth_source"` } // DefaultMongoDBConfig 默认配置 func DefaultMongoDBConfig() *MongoDBConfig { return &MongoDBConfig{ Timeout: 10, // 10秒 MaxPoolSize: 100, MinPoolSize: 10, MaxConnIdleTime: 60, // 60秒 AuthSource: "admin", } } // mongoDB MongoDB实现 type mongoDB struct { client *mongo.Client db *mongo.Database config *MongoDBConfig timeout time.Duration } // NewMongoDB 创建MongoDB实例 func NewMongoDB(cfg *MongoDBConfig) (MongoDB, error) { if cfg == nil { return nil, fmt.Errorf("配置不能为空") } // 验证必填参数 if cfg.Addr == "" { return nil, fmt.Errorf("地址不能为空") } if cfg.Name == "" { return nil, fmt.Errorf("数据库名称不能为空") } timeout := cfg.Timeout * time.Second if timeout == 0 { timeout = 10 * time.Second } // 构建连接选项 clientOpts := options.Client() // 构建URI uri := buildMongoURI(cfg) clientOpts.ApplyURI(uri) // 设置连接池 if cfg.MaxPoolSize > 0 { clientOpts.SetMaxPoolSize(cfg.MaxPoolSize) } if cfg.MinPoolSize > 0 { clientOpts.SetMinPoolSize(cfg.MinPoolSize) } if cfg.MaxConnIdleTime > 0 { clientOpts.SetMaxConnIdleTime(cfg.MaxConnIdleTime * time.Second) } // 设置超时 clientOpts.SetConnectTimeout(timeout) clientOpts.SetServerSelectionTimeout(timeout) // 连接MongoDB ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() client, err := mongo.Connect(ctx, clientOpts) if err != nil { return nil, fmt.Errorf("连接MongoDB失败: %w", err) } // Ping测试连接 pingCtx, pingCancel := context.WithTimeout(context.Background(), timeout) defer pingCancel() if err = client.Ping(pingCtx, readpref.Primary()); err != nil { _ = client.Disconnect(context.Background()) return nil, fmt.Errorf("Ping MongoDB失败: %w", err) } return &mongoDB{ client: client, db: client.Database(cfg.Name), config: cfg, timeout: timeout, }, nil } // GetDB 获取数据库实例 func (m *mongoDB) GetDB() *mongo.Database { return m.db } // GetClient 获取客户端实例 func (m *mongoDB) GetClient() *mongo.Client { return m.client } // GetCollection 获取集合 func (m *mongoDB) GetCollection(name string) *mongo.Collection { return m.db.Collection(name) } // Ping 检查连接 func (m *mongoDB) Ping(ctx context.Context) error { pingCtx, cancel := m.WithContext(ctx) defer cancel() return m.client.Ping(pingCtx, readpref.Primary()) } // Close 关闭连接 func (m *mongoDB) Close(ctx context.Context) error { if ctx == nil { ctx = context.Background() } disconnectCtx, cancel := context.WithTimeout(ctx, m.timeout) defer cancel() if err := m.client.Disconnect(disconnectCtx); err != nil { return fmt.Errorf("断开MongoDB连接失败: %w", err) } return nil } // WithContext 创建带超时的上下文 func (m *mongoDB) WithContext(ctx context.Context) (context.Context, context.CancelFunc) { if ctx == nil { ctx = context.Background() } return context.WithTimeout(ctx, m.timeout) } // buildMongoURI 构建MongoDB URI func buildMongoURI(cfg *MongoDBConfig) string { var auth string if cfg.User != "" && cfg.Pass != "" { auth = fmt.Sprintf("%s:%s@", cfg.User, cfg.Pass) } uri := fmt.Sprintf("mongodb://%s%s/%s", auth, cfg.Addr, cfg.Name) // 添加查询参数 params := []string{} if cfg.AuthSource != "" { params = append(params, fmt.Sprintf("authSource=%s", cfg.AuthSource)) } if cfg.ReplicaSet != "" { params = append(params, fmt.Sprintf("replicaSet=%s", cfg.ReplicaSet)) } if cfg.UseTLS { params = append(params, "tls=true") } if len(params) > 0 { uri += "?" for i, param := range params { if i > 0 { uri += "&" } uri += param } } return uri } // Helper functions // EnsureIndexes 确保索引存在 func EnsureIndexes(ctx context.Context, collection *mongo.Collection, indexes []mongo.IndexModel) error { if len(indexes) == 0 { return nil } _, err := collection.Indexes().CreateMany(ctx, indexes) if err != nil { return fmt.Errorf("创建索引失败: %w", err) } return nil } // DropIndexes 删除索引 func DropIndexes(ctx context.Context, collection *mongo.Collection, indexNames []string) error { for _, name := range indexNames { if _, err := collection.Indexes().DropOne(ctx, name); err != nil { return fmt.Errorf("删除索引 %s 失败: %w", name, err) } } return nil }