package database import ( "context" "fmt" "time" "gorm.io/driver/postgres" "gorm.io/gorm" "gorm.io/gorm/logger" "gorm.io/gorm/schema" ) // Config 数据库配置 type Config struct { Host string Port string User string Password string Name string SSLMode string Timezone string MaxOpenConns int MaxIdleConns int ConnMaxLifetime time.Duration } // DB 数据库包装器 type DB struct { *gorm.DB config Config } // NewConnection 创建新的数据库连接 func NewConnection(config Config) (*DB, error) { // 构建DSN dsn := buildDSN(config) // 配置GORM gormConfig := &gorm.Config{ Logger: logger.Default.LogMode(logger.Info), NamingStrategy: schema.NamingStrategy{ SingularTable: true, // 使用单数表名 }, DisableForeignKeyConstraintWhenMigrating: true, NowFunc: func() time.Time { return time.Now().In(time.FixedZone("CST", 8*3600)) // 强制使用北京时间 }, } // 连接数据库 db, err := gorm.Open(postgres.Open(dsn), gormConfig) if err != nil { return nil, fmt.Errorf("连接数据库失败: %w", err) } // 获取底层sql.DB sqlDB, err := db.DB() if err != nil { return nil, fmt.Errorf("获取数据库实例失败: %w", err) } // 配置连接池 sqlDB.SetMaxOpenConns(config.MaxOpenConns) sqlDB.SetMaxIdleConns(config.MaxIdleConns) sqlDB.SetConnMaxLifetime(config.ConnMaxLifetime) // 测试连接 if err := sqlDB.Ping(); err != nil { return nil, fmt.Errorf("数据库连接测试失败: %w", err) } return &DB{ DB: db, config: config, }, nil } // buildDSN 构建数据库连接字符串 func buildDSN(config Config) string { return fmt.Sprintf( "host=%s user=%s password=%s dbname=%s port=%s sslmode=%s TimeZone=%s options='-c timezone=%s'", config.Host, config.User, config.Password, config.Name, config.Port, config.SSLMode, config.Timezone, config.Timezone, ) } // Close 关闭数据库连接 func (db *DB) Close() error { sqlDB, err := db.DB.DB() if err != nil { return err } return sqlDB.Close() } // Ping 检查数据库连接 func (db *DB) Ping() error { sqlDB, err := db.DB.DB() if err != nil { return err } return sqlDB.Ping() } // GetStats 获取连接池统计信息 func (db *DB) GetStats() (map[string]interface{}, error) { sqlDB, err := db.DB.DB() if err != nil { return nil, err } stats := sqlDB.Stats() return map[string]interface{}{ "max_open_connections": stats.MaxOpenConnections, "open_connections": stats.OpenConnections, "in_use": stats.InUse, "idle": stats.Idle, "wait_count": stats.WaitCount, "wait_duration": stats.WaitDuration, "max_idle_closed": stats.MaxIdleClosed, "max_idle_time_closed": stats.MaxIdleTimeClosed, "max_lifetime_closed": stats.MaxLifetimeClosed, }, nil } // BeginTx 开始事务 func (db *DB) BeginTx() *gorm.DB { return db.DB.Begin() } // Migrate 执行数据库迁移 func (db *DB) Migrate(models ...interface{}) error { return db.DB.AutoMigrate(models...) } // IsHealthy 检查数据库健康状态 func (db *DB) IsHealthy() bool { return db.Ping() == nil } // WithContext 返回带上下文的数据库实例 func (db *DB) WithContext(ctx interface{}) *gorm.DB { if c, ok := ctx.(context.Context); ok { return db.DB.WithContext(c) } return db.DB } // 事务包装器 type TxWrapper struct { tx *gorm.DB } // NewTxWrapper 创建事务包装器 func (db *DB) NewTxWrapper() *TxWrapper { return &TxWrapper{ tx: db.BeginTx(), } } // Commit 提交事务 func (tx *TxWrapper) Commit() error { return tx.tx.Commit().Error } // Rollback 回滚事务 func (tx *TxWrapper) Rollback() error { return tx.tx.Rollback().Error } // GetDB 获取事务数据库实例 func (tx *TxWrapper) GetDB() *gorm.DB { return tx.tx } // WithTx 在事务中执行函数 func (db *DB) WithTx(fn func(*gorm.DB) error) error { tx := db.BeginTx() defer func() { if r := recover(); r != nil { tx.Rollback() panic(r) } }() if err := fn(tx); err != nil { tx.Rollback() return err } return tx.Commit().Error }