196 lines
3.9 KiB
Go
196 lines
3.9 KiB
Go
package database
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"time"
|
|
|
|
"gorm.io/driver/postgres"
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/logger"
|
|
"gorm.io/gorm/schema"
|
|
)
|
|
|
|
// Config 数据库配置
|
|
type Config struct {
|
|
Host string
|
|
Port string
|
|
User string
|
|
Password string
|
|
Name string
|
|
SSLMode string
|
|
Timezone string
|
|
MaxOpenConns int
|
|
MaxIdleConns int
|
|
ConnMaxLifetime time.Duration
|
|
}
|
|
|
|
// DB 数据库包装器
|
|
type DB struct {
|
|
*gorm.DB
|
|
config Config
|
|
}
|
|
|
|
// NewConnection 创建新的数据库连接
|
|
func NewConnection(config Config) (*DB, error) {
|
|
// 构建DSN
|
|
dsn := buildDSN(config)
|
|
|
|
// 配置GORM
|
|
gormConfig := &gorm.Config{
|
|
Logger: logger.Default.LogMode(logger.Info),
|
|
NamingStrategy: schema.NamingStrategy{
|
|
SingularTable: true, // 使用单数表名
|
|
},
|
|
DisableForeignKeyConstraintWhenMigrating: true,
|
|
}
|
|
|
|
// 连接数据库
|
|
db, err := gorm.Open(postgres.Open(dsn), gormConfig)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("连接数据库失败: %w", err)
|
|
}
|
|
|
|
// 获取底层sql.DB
|
|
sqlDB, err := db.DB()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("获取数据库实例失败: %w", err)
|
|
}
|
|
|
|
// 配置连接池
|
|
sqlDB.SetMaxOpenConns(config.MaxOpenConns)
|
|
sqlDB.SetMaxIdleConns(config.MaxIdleConns)
|
|
sqlDB.SetConnMaxLifetime(config.ConnMaxLifetime)
|
|
|
|
// 测试连接
|
|
if err := sqlDB.Ping(); err != nil {
|
|
return nil, fmt.Errorf("数据库连接测试失败: %w", err)
|
|
}
|
|
|
|
return &DB{
|
|
DB: db,
|
|
config: config,
|
|
}, nil
|
|
}
|
|
|
|
// buildDSN 构建数据库连接字符串
|
|
func buildDSN(config Config) string {
|
|
return fmt.Sprintf(
|
|
"host=%s user=%s password=%s dbname=%s port=%s sslmode=%s TimeZone=%s",
|
|
config.Host,
|
|
config.User,
|
|
config.Password,
|
|
config.Name,
|
|
config.Port,
|
|
config.SSLMode,
|
|
config.Timezone,
|
|
)
|
|
}
|
|
|
|
// Close 关闭数据库连接
|
|
func (db *DB) Close() error {
|
|
sqlDB, err := db.DB.DB()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return sqlDB.Close()
|
|
}
|
|
|
|
// Ping 检查数据库连接
|
|
func (db *DB) Ping() error {
|
|
sqlDB, err := db.DB.DB()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return sqlDB.Ping()
|
|
}
|
|
|
|
// GetStats 获取连接池统计信息
|
|
func (db *DB) GetStats() (map[string]interface{}, error) {
|
|
sqlDB, err := db.DB.DB()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
stats := sqlDB.Stats()
|
|
return map[string]interface{}{
|
|
"max_open_connections": stats.MaxOpenConnections,
|
|
"open_connections": stats.OpenConnections,
|
|
"in_use": stats.InUse,
|
|
"idle": stats.Idle,
|
|
"wait_count": stats.WaitCount,
|
|
"wait_duration": stats.WaitDuration,
|
|
"max_idle_closed": stats.MaxIdleClosed,
|
|
"max_idle_time_closed": stats.MaxIdleTimeClosed,
|
|
"max_lifetime_closed": stats.MaxLifetimeClosed,
|
|
}, nil
|
|
}
|
|
|
|
// BeginTx 开始事务
|
|
func (db *DB) BeginTx() *gorm.DB {
|
|
return db.DB.Begin()
|
|
}
|
|
|
|
// Migrate 执行数据库迁移
|
|
func (db *DB) Migrate(models ...interface{}) error {
|
|
return db.DB.AutoMigrate(models...)
|
|
}
|
|
|
|
// IsHealthy 检查数据库健康状态
|
|
func (db *DB) IsHealthy() bool {
|
|
return db.Ping() == nil
|
|
}
|
|
|
|
// WithContext 返回带上下文的数据库实例
|
|
func (db *DB) WithContext(ctx interface{}) *gorm.DB {
|
|
if c, ok := ctx.(context.Context); ok {
|
|
return db.DB.WithContext(c)
|
|
}
|
|
return db.DB
|
|
}
|
|
|
|
// 事务包装器
|
|
type TxWrapper struct {
|
|
tx *gorm.DB
|
|
}
|
|
|
|
// NewTxWrapper 创建事务包装器
|
|
func (db *DB) NewTxWrapper() *TxWrapper {
|
|
return &TxWrapper{
|
|
tx: db.BeginTx(),
|
|
}
|
|
}
|
|
|
|
// Commit 提交事务
|
|
func (tx *TxWrapper) Commit() error {
|
|
return tx.tx.Commit().Error
|
|
}
|
|
|
|
// Rollback 回滚事务
|
|
func (tx *TxWrapper) Rollback() error {
|
|
return tx.tx.Rollback().Error
|
|
}
|
|
|
|
// GetDB 获取事务数据库实例
|
|
func (tx *TxWrapper) GetDB() *gorm.DB {
|
|
return tx.tx
|
|
}
|
|
|
|
// WithTx 在事务中执行函数
|
|
func (db *DB) WithTx(fn func(*gorm.DB) error) error {
|
|
tx := db.BeginTx()
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
tx.Rollback()
|
|
panic(r)
|
|
}
|
|
}()
|
|
|
|
if err := fn(tx); err != nil {
|
|
tx.Rollback()
|
|
return err
|
|
}
|
|
|
|
return tx.Commit().Error
|
|
}
|