package database import ( "database/sql" "fmt" "strings" "sync" "time" _ "github.com/go-sql-driver/mysql" "go.uber.org/zap" "uzdb/internal/config" "uzdb/internal/models" ) // MySQLConnection represents a MySQL connection type MySQLConnection struct { db *sql.DB dsn string host string port int database string username string mu sync.RWMutex } // BuildMySQLDSN builds MySQL DSN connection string func BuildMySQLDSN(conn *models.UserConnection, password string) string { // MySQL DSN format: user:pass@tcp(host:port)/dbname?params params := []string{ "parseTime=true", "loc=Local", fmt.Sprintf("timeout=%ds", conn.Timeout), "charset=utf8mb4", "collation=utf8mb4_unicode_ci", } if conn.SSLMode != "" && conn.SSLMode != "disable" { switch conn.SSLMode { case "require": params = append(params, "tls=required") case "verify-ca": params = append(params, "tls=skip-verify") case "verify-full": params = append(params, "tls=skip-verify") } } queryString := strings.Join(params, "&") return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?%s", conn.Username, password, conn.Host, conn.Port, conn.Database, queryString, ) } // NewMySQLConnection creates a new MySQL connection func NewMySQLConnection(conn *models.UserConnection, password string) (*MySQLConnection, error) { dsn := BuildMySQLDSN(conn, password) db, err := sql.Open("mysql", dsn) if err != nil { return nil, fmt.Errorf("failed to open MySQL connection: %w", err) } // Configure connection pool db.SetMaxOpenConns(25) db.SetMaxIdleConns(5) db.SetConnMaxLifetime(5 * time.Minute) // Test connection if err := db.Ping(); err != nil { db.Close() return nil, fmt.Errorf("failed to ping MySQL: %w", err) } mysqlConn := &MySQLConnection{ db: db, dsn: dsn, host: conn.Host, port: conn.Port, database: conn.Database, username: conn.Username, } config.GetLogger().Info("MySQL connection established", zap.String("host", conn.Host), zap.Int("port", conn.Port), zap.String("database", conn.Database), ) return mysqlConn, nil } // GetDB returns the underlying sql.DB func (m *MySQLConnection) GetDB() *sql.DB { m.mu.RLock() defer m.mu.RUnlock() return m.db } // Close closes the MySQL connection func (m *MySQLConnection) Close() error { m.mu.Lock() defer m.mu.Unlock() if m.db != nil { if err := m.db.Close(); err != nil { return fmt.Errorf("failed to close MySQL connection: %w", err) } m.db = nil config.GetLogger().Info("MySQL connection closed", zap.String("host", m.host), zap.Int("port", m.port), ) } return nil } // IsConnected checks if the connection is alive func (m *MySQLConnection) IsConnected() bool { m.mu.RLock() defer m.mu.RUnlock() if m.db == nil { return false } err := m.db.Ping() return err == nil } // ExecuteQuery executes a SQL query and returns results func (m *MySQLConnection) ExecuteQuery(sql string, args ...interface{}) (*models.QueryResult, error) { startTime := time.Now() db := m.GetDB() if db == nil { return nil, fmt.Errorf("connection is closed") } rows, err := db.Query(sql, args...) if err != nil { return nil, fmt.Errorf("query execution failed: %w", err) } defer rows.Close() result := &models.QueryResult{ Success: true, Duration: time.Since(startTime).Milliseconds(), } // Get column names columns, err := rows.Columns() if err != nil { return nil, fmt.Errorf("failed to get columns: %w", err) } result.Columns = columns // Scan rows for rows.Next() { values := make([]interface{}, len(columns)) valuePtrs := make([]interface{}, len(columns)) for i := range values { valuePtrs[i] = &values[i] } if err := rows.Scan(valuePtrs...); err != nil { return nil, fmt.Errorf("failed to scan row: %w", err) } row := make([]interface{}, len(columns)) for i, v := range values { row[i] = convertValue(v) } result.Rows = append(result.Rows, row) } if err := rows.Err(); err != nil { return nil, fmt.Errorf("row iteration error: %w", err) } result.RowCount = int64(len(result.Rows)) return result, nil } // ExecuteStatement executes a SQL statement (INSERT, UPDATE, DELETE, etc.) func (m *MySQLConnection) ExecuteStatement(sql string, args ...interface{}) (*models.QueryResult, error) { startTime := time.Now() db := m.GetDB() if db == nil { return nil, fmt.Errorf("connection is closed") } res, err := db.Exec(sql, args...) if err != nil { return nil, fmt.Errorf("statement execution failed: %w", err) } rowsAffected, _ := res.RowsAffected() lastInsertID, _ := res.LastInsertId() result := &models.QueryResult{ Success: true, AffectedRows: rowsAffected, Duration: time.Since(startTime).Milliseconds(), Rows: [][]interface{}{{lastInsertID}}, Columns: []string{"last_insert_id"}, } return result, nil } // GetTables returns all tables in the database func (m *MySQLConnection) GetTables(schema string) ([]models.Table, error) { db := m.GetDB() if db == nil { return nil, fmt.Errorf("connection is closed") } query := ` SELECT TABLE_NAME as table_name, TABLE_TYPE as table_type, TABLE_ROWS as row_count FROM information_schema.TABLES WHERE TABLE_SCHEMA = ? ORDER BY TABLE_NAME ` rows, err := db.Query(query, m.database) if err != nil { return nil, fmt.Errorf("failed to query tables: %w", err) } defer rows.Close() var tables []models.Table for rows.Next() { var t models.Table var rowCount sql.NullInt64 if err := rows.Scan(&t.Name, &t.Type, &rowCount); err != nil { return nil, fmt.Errorf("failed to scan table: %w", err) } if rowCount.Valid { t.RowCount = rowCount.Int64 } tables = append(tables, t) } return tables, nil } // GetTableStructure returns the structure of a table func (m *MySQLConnection) GetTableStructure(tableName string) (*models.TableStructure, error) { db := m.GetDB() if db == nil { return nil, fmt.Errorf("connection is closed") } structure := &models.TableStructure{ TableName: tableName, } // Get columns query := ` SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_DEFAULT, COLUMN_KEY, EXTRA, CHARACTER_MAXIMUM_LENGTH, NUMERIC_PRECISION, NUMERIC_SCALE, COLUMN_COMMENT FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? ORDER BY ORDINAL_POSITION ` rows, err := db.Query(query, m.database, tableName) if err != nil { return nil, fmt.Errorf("failed to query columns: %w", err) } defer rows.Close() for rows.Next() { var col models.TableColumn var nullable, columnKey, extra string var defaultVal, comment sql.NullString var maxLength, precision, scale sql.NullInt32 if err := rows.Scan( &col.Name, &col.DataType, &nullable, &defaultVal, &columnKey, &extra, &maxLength, &precision, &scale, &comment, ); err != nil { return nil, fmt.Errorf("failed to scan column: %w", err) } col.Nullable = nullable == "YES" col.IsPrimary = columnKey == "PRI" col.IsUnique = columnKey == "UNI" col.AutoIncrement = strings.Contains(extra, "auto_increment") if defaultVal.Valid { col.Default = defaultVal.String } if maxLength.Valid { col.Length = int(maxLength.Int32) } if precision.Valid { col.Scale = int(precision.Int32) } if scale.Valid { col.Scale = int(scale.Int32) } if comment.Valid { col.Comment = comment.String } structure.Columns = append(structure.Columns, col) } // Get indexes indexQuery := ` SELECT INDEX_NAME, COLUMN_NAME, NON_UNIQUE, SEQ_IN_INDEX FROM information_schema.STATISTICS WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? ORDER BY INDEX_NAME, SEQ_IN_INDEX ` idxRows, err := db.Query(indexQuery, m.database, tableName) if err != nil { config.GetLogger().Warn("failed to query indexes", zap.Error(err)) } else { defer idxRows.Close() indexMap := make(map[string]*models.TableIndex) for idxRows.Next() { var indexName, columnName string var nonUnique bool var seqInIndex int if err := idxRows.Scan(&indexName, &columnName, &nonUnique, &seqInIndex); err != nil { continue } idx, exists := indexMap[indexName] if !exists { idx = &models.TableIndex{ Name: indexName, IsUnique: !nonUnique, IsPrimary: indexName == "PRIMARY", Columns: []string{}, } indexMap[indexName] = idx } idx.Columns = append(idx.Columns, columnName) } for _, idx := range indexMap { structure.Indexes = append(structure.Indexes, *idx) } } return structure, nil } // GetMetadata returns database metadata func (m *MySQLConnection) GetMetadata() (*models.DBMetadata, error) { db := m.GetDB() if db == nil { return nil, fmt.Errorf("connection is closed") } metadata := &models.DBMetadata{ Database: m.database, User: m.username, Host: m.host, Port: m.port, } // Get MySQL version var version string err := db.QueryRow("SELECT VERSION()").Scan(&version) if err == nil { metadata.Version = version } // Get server time var serverTime string err = db.QueryRow("SELECT NOW()").Scan(&serverTime) if err == nil { metadata.ServerTime = serverTime } return metadata, nil } // convertValue converts database value to interface{} func convertValue(val interface{}) interface{} { if val == nil { return nil } // Handle []byte (common from MySQL driver) if b, ok := val.([]byte); ok { return string(b) } return val }