Files
tyapi-server/scripts/update_cost_price.go

161 lines
4.0 KiB
Go
Raw Normal View History

2025-11-19 13:41:41 +08:00
package main
import (
"context"
"encoding/csv"
"fmt"
"os"
"strconv"
"strings"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
)
func main() {
// 连接数据库
db, err := connectDB()
if err != nil {
fmt.Fprintf(os.Stderr, "连接数据库失败: %v\n", err)
os.Exit(1)
}
// 读取 CSV 文件
csvFile, err := os.Open("成本价.csv")
if err != nil {
fmt.Fprintf(os.Stderr, "打开 CSV 文件失败: %v\n", err)
os.Exit(1)
}
defer csvFile.Close()
reader := csv.NewReader(csvFile)
reader.LazyQuotes = true
reader.TrimLeadingSpace = true
// 读取所有记录
records, err := reader.ReadAll()
if err != nil {
fmt.Fprintf(os.Stderr, "读取 CSV 文件失败: %v\n", err)
os.Exit(1)
}
if len(records) < 2 {
fmt.Fprintf(os.Stderr, "CSV 文件数据不足(需要至少包含表头和数据行)\n")
os.Exit(1)
}
ctx := context.Background()
successCount := 0
failCount := 0
skipCount := 0
fmt.Printf("开始更新成本价...\n")
fmt.Printf("共 %d 条记录(包含表头)\n\n", len(records))
// 从第二行开始处理(跳过表头)
for i := 1; i < len(records); i++ {
record := records[i]
if len(record) < 7 {
fmt.Printf("第 %d 行数据列数不足,跳过\n", i+1)
skipCount++
continue
}
productCode := strings.TrimSpace(record[0])
costPriceStr := strings.TrimSpace(record[6]) // 成本价在第7列索引6
// 跳过产品编号为空的行
if productCode == "" {
fmt.Printf("第 %d 行产品编号为空,跳过\n", i+1)
skipCount++
continue
}
// 如果成本价为空,跳过(不更新)
if costPriceStr == "" {
fmt.Printf("产品 %s: 成本价为空,跳过\n", productCode)
skipCount++
continue
}
// 解析成本价为浮点数
costPrice, err := strconv.ParseFloat(costPriceStr, 64)
if err != nil {
fmt.Printf("产品 %s: 成本价格式错误 (%s),跳过: %v\n", productCode, costPriceStr, err)
skipCount++
continue
}
// 更新数据库
result := db.WithContext(ctx).
Table("product").
Where("code = ? AND deleted_at IS NULL", productCode).
Update("cost_price", costPrice)
if result.Error != nil {
// 如果单数表名失败,尝试复数表名
if strings.Contains(result.Error.Error(), "does not exist") {
result = db.WithContext(ctx).
Table("products").
Where("code = ? AND deleted_at IS NULL", productCode).
Update("cost_price", costPrice)
}
if result.Error != nil {
fmt.Printf("产品 %s: 更新失败 - %v\n", productCode, result.Error)
failCount++
continue
}
}
if result.RowsAffected == 0 {
fmt.Printf("产品 %s: 未找到匹配的记录\n", productCode)
failCount++
} else {
fmt.Printf("产品 %s: 成功更新成本价为 %.2f (影响 %d 行)\n", productCode, costPrice, result.RowsAffected)
successCount++
}
}
fmt.Printf("\n=== 更新完成 ===\n")
fmt.Printf("成功更新: %d 条\n", successCount)
fmt.Printf("更新失败: %d 条\n", failCount)
fmt.Printf("跳过记录: %d 条\n", skipCount)
fmt.Printf("总计处理: %d 条\n", len(records)-1)
}
// connectDB 连接数据库
func connectDB() (*gorm.DB, error) {
// 数据库连接配置
dsn := "host=1.117.67.95 user=tyapi_user password=Pg9mX4kL8nW2rT5y dbname=tyapi port=25010 sslmode=disable TimeZone=Asia/Shanghai"
// 配置GORM使用单数表名与项目配置一致
gormConfig := &gorm.Config{
NamingStrategy: schema.NamingStrategy{
SingularTable: true, // 使用单数表名
},
Logger: logger.Default.LogMode(logger.Info), // 显示 SQL 日志
}
db, err := gorm.Open(postgres.Open(dsn), gormConfig)
if err != nil {
return nil, fmt.Errorf("连接数据库失败: %w", err)
}
// 测试连接
sqlDB, err := db.DB()
if err != nil {
return nil, fmt.Errorf("获取数据库实例失败: %w", err)
}
if err := sqlDB.Ping(); err != nil {
return nil, fmt.Errorf("数据库连接测试失败: %w", err)
}
fmt.Println("数据库连接成功")
return db, nil
}