115 lines
2.4 KiB
Go
115 lines
2.4 KiB
Go
package database
|
|
|
|
import (
|
|
"fmt"
|
|
"log"
|
|
"os"
|
|
"path/filepath"
|
|
"time"
|
|
|
|
"gorm.io/driver/sqlite"
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/logger"
|
|
)
|
|
|
|
// DB 数据库连接
|
|
var DB *gorm.DB
|
|
|
|
// Init 初始化数据库
|
|
func Init(dataDir string) error {
|
|
// 确保数据目录存在
|
|
if err := os.MkdirAll(dataDir, 0755); err != nil {
|
|
return fmt.Errorf("failed to create data directory %s: %w", dataDir, err)
|
|
}
|
|
|
|
dbPath := filepath.Join(dataDir, "cluster.db")
|
|
|
|
// 检查数据库文件是否存在
|
|
dbExists := false
|
|
if _, err := os.Stat(dbPath); err == nil {
|
|
dbExists = true
|
|
log.Printf("Database file already exists: %s", dbPath)
|
|
} else if !os.IsNotExist(err) {
|
|
return fmt.Errorf("failed to check database file: %w", err)
|
|
} else {
|
|
log.Printf("Creating new database file: %s", dbPath)
|
|
}
|
|
|
|
// 配置 GORM logger
|
|
gormLogger := logger.Default
|
|
if os.Getenv("GORM_LOG_LEVEL") == "silent" {
|
|
gormLogger = gormLogger.LogMode(logger.Silent)
|
|
}
|
|
|
|
// 打开数据库连接
|
|
var err error
|
|
DB, err = gorm.Open(sqlite.Open(dbPath), &gorm.Config{
|
|
Logger: gormLogger,
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to open database: %w", err)
|
|
}
|
|
|
|
// 获取底层 sql.DB 以设置连接池
|
|
sqlDB, err := DB.DB()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get database instance: %w", err)
|
|
}
|
|
|
|
// 设置连接池参数
|
|
sqlDB.SetMaxIdleConns(10)
|
|
sqlDB.SetMaxOpenConns(100)
|
|
sqlDB.SetConnMaxLifetime(time.Hour)
|
|
|
|
// 启用外键约束
|
|
if err := DB.Exec("PRAGMA foreign_keys = ON").Error; err != nil {
|
|
return fmt.Errorf("failed to enable foreign keys: %w", err)
|
|
}
|
|
|
|
// 自动迁移表结构
|
|
if err := autoMigrate(); err != nil {
|
|
return fmt.Errorf("failed to migrate tables: %w", err)
|
|
}
|
|
|
|
if !dbExists {
|
|
log.Printf("Database initialized successfully: %s", dbPath)
|
|
} else {
|
|
log.Printf("Database tables verified: %s", dbPath)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Close 关闭数据库连接
|
|
func Close() error {
|
|
if DB != nil {
|
|
sqlDB, err := DB.DB()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return sqlDB.Close()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// autoMigrate 自动迁移表结构
|
|
func autoMigrate() error {
|
|
models := []interface{}{
|
|
&Repository{},
|
|
&Manifest{},
|
|
&Blob{},
|
|
&Tag{},
|
|
&BlobUpload{},
|
|
&ManifestBlob{},
|
|
}
|
|
|
|
for _, model := range models {
|
|
if err := DB.AutoMigrate(model); err != nil {
|
|
return fmt.Errorf("failed to migrate model %T: %w", model, err)
|
|
}
|
|
}
|
|
|
|
log.Printf("Database migration completed successfully")
|
|
return nil
|
|
}
|