161 lines
3.6 KiB
Go
161 lines
3.6 KiB
Go
package database
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"time"
|
||
|
||
"gorm.io/driver/postgres"
|
||
"gorm.io/gorm"
|
||
"gorm.io/gorm/logger"
|
||
"gorm.io/gorm/schema"
|
||
)
|
||
|
||
// Config 数据库配置
|
||
type Config struct {
|
||
Host string
|
||
Port string
|
||
User string
|
||
Password string
|
||
Name string
|
||
SSLMode string
|
||
Timezone string
|
||
MaxOpenConns int
|
||
MaxIdleConns int
|
||
ConnMaxLifetime time.Duration
|
||
}
|
||
|
||
// DB 数据库包装器
|
||
type DB struct {
|
||
*gorm.DB
|
||
config Config
|
||
}
|
||
|
||
// NewConnection 创建新的数据库连接
|
||
func NewConnection(config Config) (*DB, error) {
|
||
// 构建DSN
|
||
dsn := buildDSN(config)
|
||
|
||
// 配置GORM
|
||
gormConfig := &gorm.Config{
|
||
Logger: logger.Default.LogMode(logger.Info),
|
||
NamingStrategy: schema.NamingStrategy{
|
||
SingularTable: true, // 使用单数表名
|
||
},
|
||
DisableForeignKeyConstraintWhenMigrating: true,
|
||
NowFunc: func() time.Time {
|
||
return time.Now().In(time.FixedZone("CST", 8*3600)) // 强制使用北京时间
|
||
},
|
||
PrepareStmt: true,
|
||
DisableAutomaticPing: false,
|
||
}
|
||
|
||
// 连接数据库
|
||
db, err := gorm.Open(postgres.Open(dsn), gormConfig)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("连接数据库失败: %w", err)
|
||
}
|
||
|
||
// 获取底层sql.DB
|
||
sqlDB, err := db.DB()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("获取数据库实例失败: %w", err)
|
||
}
|
||
|
||
// 配置连接池
|
||
sqlDB.SetMaxOpenConns(config.MaxOpenConns)
|
||
sqlDB.SetMaxIdleConns(config.MaxIdleConns)
|
||
sqlDB.SetConnMaxLifetime(config.ConnMaxLifetime)
|
||
|
||
// 测试连接
|
||
if err := sqlDB.Ping(); err != nil {
|
||
return nil, fmt.Errorf("数据库连接测试失败: %w", err)
|
||
}
|
||
|
||
return &DB{
|
||
DB: db,
|
||
config: config,
|
||
}, nil
|
||
}
|
||
|
||
// buildDSN 构建数据库连接字符串
|
||
func buildDSN(config Config) string {
|
||
return fmt.Sprintf(
|
||
"host=%s user=%s password=%s dbname=%s port=%s sslmode=%s TimeZone=%s options='-c timezone=%s'",
|
||
config.Host,
|
||
config.User,
|
||
config.Password,
|
||
config.Name,
|
||
config.Port,
|
||
config.SSLMode,
|
||
config.Timezone,
|
||
config.Timezone,
|
||
)
|
||
}
|
||
|
||
// Close 关闭数据库连接
|
||
func (db *DB) Close() error {
|
||
sqlDB, err := db.DB.DB()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
return sqlDB.Close()
|
||
}
|
||
|
||
// Ping 检查数据库连接
|
||
func (db *DB) Ping() error {
|
||
sqlDB, err := db.DB.DB()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
return sqlDB.Ping()
|
||
}
|
||
|
||
// GetStats 获取连接池统计信息
|
||
func (db *DB) GetStats() (map[string]interface{}, error) {
|
||
sqlDB, err := db.DB.DB()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
stats := sqlDB.Stats()
|
||
return map[string]interface{}{
|
||
"max_open_connections": stats.MaxOpenConnections,
|
||
"open_connections": stats.OpenConnections,
|
||
"in_use": stats.InUse,
|
||
"idle": stats.Idle,
|
||
"wait_count": stats.WaitCount,
|
||
"wait_duration": stats.WaitDuration,
|
||
"max_idle_closed": stats.MaxIdleClosed,
|
||
"max_idle_time_closed": stats.MaxIdleTimeClosed,
|
||
"max_lifetime_closed": stats.MaxLifetimeClosed,
|
||
}, nil
|
||
}
|
||
|
||
// BeginTx 开始事务(已废弃,请使用shared/database.TransactionManager)
|
||
// @deprecated 请使用 shared/database.TransactionManager
|
||
func (db *DB) BeginTx() *gorm.DB {
|
||
return db.DB.Begin()
|
||
}
|
||
|
||
// Migrate 执行数据库迁移
|
||
func (db *DB) Migrate(models ...interface{}) error {
|
||
return db.DB.AutoMigrate(models...)
|
||
}
|
||
|
||
// IsHealthy 检查数据库健康状态
|
||
func (db *DB) IsHealthy() bool {
|
||
return db.Ping() == nil
|
||
}
|
||
|
||
// WithContext 返回带上下文的数据库实例
|
||
func (db *DB) WithContext(ctx interface{}) *gorm.DB {
|
||
if c, ok := ctx.(context.Context); ok {
|
||
return db.DB.WithContext(c)
|
||
}
|
||
return db.DB
|
||
}
|
||
|
||
// 注意:事务相关功能已迁移到 shared/database.TransactionManager
|
||
// 请使用 TransactionManager 进行事务管理
|