265 lines
5.9 KiB
Go
265 lines
5.9 KiB
Go
package database
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"time"
|
|
|
|
"go.mongodb.org/mongo-driver/mongo"
|
|
"go.mongodb.org/mongo-driver/mongo/options"
|
|
"go.mongodb.org/mongo-driver/mongo/readpref"
|
|
)
|
|
|
|
// MongoDB MongoDB接口
|
|
type MongoDB interface {
|
|
// GetDB 获取数据库实例
|
|
GetDB() *mongo.Database
|
|
|
|
// GetClient 获取客户端实例
|
|
GetClient() *mongo.Client
|
|
|
|
// GetCollection 获取集合
|
|
GetCollection(name string) *mongo.Collection
|
|
|
|
// Ping 检查连接
|
|
Ping(ctx context.Context) error
|
|
|
|
// Close 关闭连接
|
|
Close(ctx context.Context) error
|
|
|
|
// WithContext 创建带超时的上下文
|
|
WithContext(ctx context.Context) (context.Context, context.CancelFunc)
|
|
}
|
|
|
|
// MongoDBConfig MongoDB配置
|
|
type MongoDBConfig struct {
|
|
// 地址,支持多个地址: "localhost:27017,localhost:27018"
|
|
Addr string `yaml:"addr" json:"addr"`
|
|
|
|
// 用户名
|
|
User string `yaml:"user" json:"user"`
|
|
|
|
// 密码
|
|
Pass string `yaml:"pass" json:"pass"`
|
|
|
|
// 数据库名称
|
|
Name string `yaml:"name" json:"name"`
|
|
|
|
// 连接超时(秒)
|
|
Timeout time.Duration `yaml:"timeout" json:"timeout"`
|
|
|
|
// 最大连接池大小
|
|
MaxPoolSize uint64 `yaml:"max_pool_size" json:"max_pool_size"`
|
|
|
|
// 最小连接池大小
|
|
MinPoolSize uint64 `yaml:"min_pool_size" json:"min_pool_size"`
|
|
|
|
// 最大连接空闲时间(秒)
|
|
MaxConnIdleTime time.Duration `yaml:"max_conn_idle_time" json:"max_conn_idle_time"`
|
|
|
|
// 是否使用副本集
|
|
ReplicaSet string `yaml:"replica_set" json:"replica_set"`
|
|
|
|
// 是否使用TLS
|
|
UseTLS bool `yaml:"use_tls" json:"use_tls"`
|
|
|
|
// 认证数据库
|
|
AuthSource string `yaml:"auth_source" json:"auth_source"`
|
|
}
|
|
|
|
// DefaultMongoDBConfig 默认配置
|
|
func DefaultMongoDBConfig() *MongoDBConfig {
|
|
return &MongoDBConfig{
|
|
Timeout: 10, // 10秒
|
|
MaxPoolSize: 100,
|
|
MinPoolSize: 10,
|
|
MaxConnIdleTime: 60, // 60秒
|
|
AuthSource: "admin",
|
|
}
|
|
}
|
|
|
|
// mongoDB MongoDB实现
|
|
type mongoDB struct {
|
|
client *mongo.Client
|
|
db *mongo.Database
|
|
config *MongoDBConfig
|
|
timeout time.Duration
|
|
}
|
|
|
|
// NewMongoDB 创建MongoDB实例
|
|
func NewMongoDB(cfg *MongoDBConfig) (MongoDB, error) {
|
|
if cfg == nil {
|
|
return nil, fmt.Errorf("配置不能为空")
|
|
}
|
|
|
|
// 验证必填参数
|
|
if cfg.Addr == "" {
|
|
return nil, fmt.Errorf("地址不能为空")
|
|
}
|
|
if cfg.Name == "" {
|
|
return nil, fmt.Errorf("数据库名称不能为空")
|
|
}
|
|
|
|
timeout := cfg.Timeout * time.Second
|
|
if timeout == 0 {
|
|
timeout = 10 * time.Second
|
|
}
|
|
|
|
// 构建连接选项
|
|
clientOpts := options.Client()
|
|
|
|
// 构建URI
|
|
uri := buildMongoURI(cfg)
|
|
clientOpts.ApplyURI(uri)
|
|
|
|
// 设置连接池
|
|
if cfg.MaxPoolSize > 0 {
|
|
clientOpts.SetMaxPoolSize(cfg.MaxPoolSize)
|
|
}
|
|
if cfg.MinPoolSize > 0 {
|
|
clientOpts.SetMinPoolSize(cfg.MinPoolSize)
|
|
}
|
|
if cfg.MaxConnIdleTime > 0 {
|
|
clientOpts.SetMaxConnIdleTime(cfg.MaxConnIdleTime * time.Second)
|
|
}
|
|
|
|
// 设置超时
|
|
clientOpts.SetConnectTimeout(timeout)
|
|
clientOpts.SetServerSelectionTimeout(timeout)
|
|
|
|
// 连接MongoDB
|
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
|
defer cancel()
|
|
|
|
client, err := mongo.Connect(ctx, clientOpts)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("连接MongoDB失败: %w", err)
|
|
}
|
|
|
|
// Ping测试连接
|
|
pingCtx, pingCancel := context.WithTimeout(context.Background(), timeout)
|
|
defer pingCancel()
|
|
|
|
if err = client.Ping(pingCtx, readpref.Primary()); err != nil {
|
|
_ = client.Disconnect(context.Background())
|
|
return nil, fmt.Errorf("Ping MongoDB失败: %w", err)
|
|
}
|
|
|
|
return &mongoDB{
|
|
client: client,
|
|
db: client.Database(cfg.Name),
|
|
config: cfg,
|
|
timeout: timeout,
|
|
}, nil
|
|
}
|
|
|
|
// GetDB 获取数据库实例
|
|
func (m *mongoDB) GetDB() *mongo.Database {
|
|
return m.db
|
|
}
|
|
|
|
// GetClient 获取客户端实例
|
|
func (m *mongoDB) GetClient() *mongo.Client {
|
|
return m.client
|
|
}
|
|
|
|
// GetCollection 获取集合
|
|
func (m *mongoDB) GetCollection(name string) *mongo.Collection {
|
|
return m.db.Collection(name)
|
|
}
|
|
|
|
// Ping 检查连接
|
|
func (m *mongoDB) Ping(ctx context.Context) error {
|
|
pingCtx, cancel := m.WithContext(ctx)
|
|
defer cancel()
|
|
|
|
return m.client.Ping(pingCtx, readpref.Primary())
|
|
}
|
|
|
|
// Close 关闭连接
|
|
func (m *mongoDB) Close(ctx context.Context) error {
|
|
if ctx == nil {
|
|
ctx = context.Background()
|
|
}
|
|
|
|
disconnectCtx, cancel := context.WithTimeout(ctx, m.timeout)
|
|
defer cancel()
|
|
|
|
if err := m.client.Disconnect(disconnectCtx); err != nil {
|
|
return fmt.Errorf("断开MongoDB连接失败: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// WithContext 创建带超时的上下文
|
|
func (m *mongoDB) WithContext(ctx context.Context) (context.Context, context.CancelFunc) {
|
|
if ctx == nil {
|
|
ctx = context.Background()
|
|
}
|
|
return context.WithTimeout(ctx, m.timeout)
|
|
}
|
|
|
|
// buildMongoURI 构建MongoDB URI
|
|
func buildMongoURI(cfg *MongoDBConfig) string {
|
|
var auth string
|
|
if cfg.User != "" && cfg.Pass != "" {
|
|
auth = fmt.Sprintf("%s:%s@", cfg.User, cfg.Pass)
|
|
}
|
|
|
|
uri := fmt.Sprintf("mongodb://%s%s/%s", auth, cfg.Addr, cfg.Name)
|
|
|
|
// 添加查询参数
|
|
params := []string{}
|
|
|
|
if cfg.AuthSource != "" {
|
|
params = append(params, fmt.Sprintf("authSource=%s", cfg.AuthSource))
|
|
}
|
|
|
|
if cfg.ReplicaSet != "" {
|
|
params = append(params, fmt.Sprintf("replicaSet=%s", cfg.ReplicaSet))
|
|
}
|
|
|
|
if cfg.UseTLS {
|
|
params = append(params, "tls=true")
|
|
}
|
|
|
|
if len(params) > 0 {
|
|
uri += "?"
|
|
for i, param := range params {
|
|
if i > 0 {
|
|
uri += "&"
|
|
}
|
|
uri += param
|
|
}
|
|
}
|
|
|
|
return uri
|
|
}
|
|
|
|
// Helper functions
|
|
|
|
// EnsureIndexes 确保索引存在
|
|
func EnsureIndexes(ctx context.Context, collection *mongo.Collection, indexes []mongo.IndexModel) error {
|
|
if len(indexes) == 0 {
|
|
return nil
|
|
}
|
|
|
|
_, err := collection.Indexes().CreateMany(ctx, indexes)
|
|
if err != nil {
|
|
return fmt.Errorf("创建索引失败: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DropIndexes 删除索引
|
|
func DropIndexes(ctx context.Context, collection *mongo.Collection, indexNames []string) error {
|
|
for _, name := range indexNames {
|
|
if _, err := collection.Indexes().DropOne(ctx, name); err != nil {
|
|
return fmt.Errorf("删除索引 %s 失败: %w", name, err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|