2024-07-23 10:23:43 +08:00

78 lines
1.7 KiB
Go

package database
import (
	"context"
	"fmt"
	"net/url"
	"time"

	"go.mongodb.org/mongo-driver/mongo"
	"go.mongodb.org/mongo-driver/mongo/options"
	"go.mongodb.org/mongo-driver/mongo/readpref"
)
// Compile-time check that *mongoDB satisfies the MongoDB interface.
var _ MongoDB = (*mongoDB)(nil)
// MongoDB is a handle to one MongoDB database together with the client that
// owns its connections.
type MongoDB interface {
// i is an unexported marker method. Because it is unexported, it restricts
// implementations of MongoDB to package database.
i()
// GetDB returns the database handle.
GetDB() *mongo.Database
// Close shuts down the underlying client and reports any error from doing so.
Close() error
}
// MongoDBConfig holds the settings NewMongoDB uses to connect. The yaml tags
// give the keys used when the struct is decoded from YAML.
type MongoDBConfig struct {
// Addr is appended verbatim after "mongodb://" (and any credentials) to form
// the connection URI. NOTE(review): presumably host[:port] — confirm which
// forms callers pass.
Addr string `yaml:"addr"`
// User and Pass are added to the URI only when both are non-empty.
User string `yaml:"user"`
Pass string `yaml:"pass"`
// Name is the database selected on the client and returned by GetDB.
Name string `yaml:"name"`
// Timeout is multiplied by time.Second in NewMongoDB, so it is a count of
// seconds (e.g. `timeout: 5`), not a ready-made duration. NOTE(review): if
// the YAML decoder parses a string like "5s" into a Duration, the result here
// would be enormous or overflow — confirm configs use bare integers.
Timeout time.Duration `yaml:"timeout"`
}
// mongoDB is the concrete MongoDB implementation returned by NewMongoDB.
type mongoDB struct {
// client is the connected driver client; Close disconnects it.
client *mongo.Client
// db is client.Database(cfg.Name), created once in NewMongoDB.
db *mongo.Database
// timeout bounds the Disconnect call in Close. NewMongoDB stores the same
// value it used for its connect and ping contexts.
timeout time.Duration
}
// i implements MongoDB's unexported marker method. It does nothing; it exists
// only so that *mongoDB satisfies the interface.
func (*mongoDB) i() {}
// NewMongoDB connects to the MongoDB deployment described by cfg, verifies the
// connection by pinging the primary, and returns a handle bound to the
// database named cfg.Name.
//
// cfg.Timeout is treated as a number of seconds (it is multiplied by
// time.Second). The resulting timeout separately bounds the connect step and
// the ping step, and is stored for Close. If the result is not positive (zero,
// negative, or overflowed), a default is used instead: a non-positive deadline
// would make every call fail immediately.
//
// Credentials are added only when both cfg.User and cfg.Pass are non-empty.
// They are percent-encoded so that reserved characters such as '@', ':', '/'
// or '%' in a password cannot corrupt the connection string. Values must
// therefore be given raw, not pre-escaped.
//
// If the ping fails, the client is disconnected before returning, so a failed
// construction does not leak connections or background goroutines.
func NewMongoDB(cfg MongoDBConfig) (MongoDB, error) {
	const defaultTimeout = 10 * time.Second

	timeout := cfg.Timeout * time.Second
	if timeout <= 0 {
		timeout = defaultTimeout
	}

	var auth string
	if len(cfg.User) > 0 && len(cfg.Pass) > 0 {
		// url.Userinfo.String escapes '@', ':', '/', '?', '%' and other
		// characters the connection-string parser would otherwise misread.
		auth = url.UserPassword(cfg.User, cfg.Pass).String() + "@"
	}
	uri := fmt.Sprintf("mongodb://%s%s", auth, cfg.Addr)

	connectCtx, connectCancel := context.WithTimeout(context.Background(), timeout)
	defer connectCancel()
	client, err := mongo.Connect(connectCtx, options.Client().ApplyURI(uri))
	if err != nil {
		// The URI is deliberately not included: it may carry the password.
		return nil, fmt.Errorf("connecting to mongodb: %w", err)
	}

	pingCtx, pingCancel := context.WithTimeout(context.Background(), timeout)
	defer pingCancel()
	if err := client.Ping(pingCtx, readpref.Primary()); err != nil {
		// Connect already started the client; release it before bailing out.
		disconnectCtx, disconnectCancel := context.WithTimeout(context.Background(), timeout)
		defer disconnectCancel()
		// Best effort: the ping failure is the error worth reporting.
		_ = client.Disconnect(disconnectCtx)
		return nil, fmt.Errorf("pinging mongodb primary: %w", err)
	}

	return &mongoDB{
		client:  client,
		db:      client.Database(cfg.Name),
		timeout: timeout,
	}, nil
}
// GetDB returns the *mongo.Database that NewMongoDB obtained from the client
// for MongoDBConfig.Name.
func (m *mongoDB) GetDB() *mongo.Database {
return m.db
}
// Close disconnects the underlying client. The disconnect is bounded by the
// timeout recorded in NewMongoDB, and any error from Disconnect is returned
// unchanged.
func (m *mongoDB) Close() error {
	ctx, cancel := context.WithTimeout(context.Background(), m.timeout)
	defer cancel()
	return m.client.Disconnect(ctx)
}