package database import ( "database/sql" "fmt" "strings" "sync" "time" _ "github.com/lib/pq" "go.uber.org/zap" "uzdb/internal/config" "uzdb/internal/models" ) // PostgreSQLConnection represents a PostgreSQL connection type PostgreSQLConnection struct { db *sql.DB dsn string host string port int database string username string mu sync.RWMutex } // BuildPostgreSQLDSN builds PostgreSQL DSN connection string func BuildPostgreSQLDSN(conn *models.UserConnection, password string) string { // PostgreSQL DSN format: postgres://user:pass@host:port/dbname?params params := []string{ fmt.Sprintf("connect_timeout=%d", conn.Timeout), "sslmode=disable", } if conn.SSLMode != "" { switch conn.SSLMode { case "require": params = append(params, "sslmode=require") case "verify-ca": params = append(params, "sslmode=verify-ca") case "verify-full": params = append(params, "sslmode=verify-full") case "disable": params = append(params, "sslmode=disable") default: params = append(params, fmt.Sprintf("sslmode=%s", conn.SSLMode)) } } return fmt.Sprintf("postgres://%s:%s@%s:%d/%s?%s", conn.Username, password, conn.Host, conn.Port, conn.Database, strings.Join(params, "&"), ) } // NewPostgreSQLConnection creates a new PostgreSQL connection func NewPostgreSQLConnection(conn *models.UserConnection, password string) (*PostgreSQLConnection, error) { dsn := BuildPostgreSQLDSN(conn, password) db, err := sql.Open("postgres", dsn) if err != nil { return nil, fmt.Errorf("failed to open PostgreSQL connection: %w", err) } // Configure connection pool db.SetMaxOpenConns(25) db.SetMaxIdleConns(5) db.SetConnMaxLifetime(5 * time.Minute) // Test connection if err := db.Ping(); err != nil { db.Close() return nil, fmt.Errorf("failed to ping PostgreSQL: %w", err) } pgConn := &PostgreSQLConnection{ db: db, dsn: dsn, host: conn.Host, port: conn.Port, database: conn.Database, username: conn.Username, } config.GetLogger().Info("PostgreSQL connection established", zap.String("host", conn.Host), zap.Int("port", conn.Port), zap.String("database", conn.Database), ) return pgConn, nil } // GetDB returns the underlying sql.DB func (p *PostgreSQLConnection) GetDB() *sql.DB { p.mu.RLock() defer p.mu.RUnlock() return p.db } // Close closes the PostgreSQL connection func (p *PostgreSQLConnection) Close() error { p.mu.Lock() defer p.mu.Unlock() if p.db != nil { if err := p.db.Close(); err != nil { return fmt.Errorf("failed to close PostgreSQL connection: %w", err) } p.db = nil config.GetLogger().Info("PostgreSQL connection closed", zap.String("host", p.host), zap.Int("port", p.port), ) } return nil } // IsConnected checks if the connection is alive func (p *PostgreSQLConnection) IsConnected() bool { p.mu.RLock() defer p.mu.RUnlock() if p.db == nil { return false } err := p.db.Ping() return err == nil } // ExecuteQuery executes a SQL query and returns results func (p *PostgreSQLConnection) ExecuteQuery(sql string, args ...interface{}) (*models.QueryResult, error) { startTime := time.Now() db := p.GetDB() if db == nil { return nil, fmt.Errorf("connection is closed") } rows, err := db.Query(sql, args...) if err != nil { return nil, fmt.Errorf("query execution failed: %w", err) } defer rows.Close() result := &models.QueryResult{ Success: true, Duration: time.Since(startTime).Milliseconds(), } // Get column names columns, err := rows.Columns() if err != nil { return nil, fmt.Errorf("failed to get columns: %w", err) } result.Columns = columns // Scan rows for rows.Next() { values := make([]interface{}, len(columns)) valuePtrs := make([]interface{}, len(columns)) for i := range values { valuePtrs[i] = &values[i] } if err := rows.Scan(valuePtrs...); err != nil { return nil, fmt.Errorf("failed to scan row: %w", err) } row := make([]interface{}, len(columns)) for i, v := range values { row[i] = convertValue(v) } result.Rows = append(result.Rows, row) } if err := rows.Err(); err != nil { return nil, fmt.Errorf("row iteration error: %w", err) } result.RowCount = int64(len(result.Rows)) return result, nil } // ExecuteStatement executes a SQL statement (INSERT, UPDATE, DELETE, etc.) func (p *PostgreSQLConnection) ExecuteStatement(sql string, args ...interface{}) (*models.QueryResult, error) { startTime := time.Now() db := p.GetDB() if db == nil { return nil, fmt.Errorf("connection is closed") } res, err := db.Exec(sql, args...) if err != nil { return nil, fmt.Errorf("statement execution failed: %w", err) } rowsAffected, _ := res.RowsAffected() result := &models.QueryResult{ Success: true, AffectedRows: rowsAffected, Duration: time.Since(startTime).Milliseconds(), } return result, nil } // GetTables returns all tables in the database func (p *PostgreSQLConnection) GetTables(schema string) ([]models.Table, error) { db := p.GetDB() if db == nil { return nil, fmt.Errorf("connection is closed") } if schema == "" { schema = "public" } query := ` SELECT table_name, table_type, (SELECT COUNT(*) FROM information_schema.columns c WHERE c.table_name = t.table_name AND c.table_schema = t.table_schema) as column_count FROM information_schema.tables t WHERE table_schema = $1 ORDER BY table_name ` rows, err := db.Query(query, schema) if err != nil { return nil, fmt.Errorf("failed to query tables: %w", err) } defer rows.Close() var tables []models.Table for rows.Next() { var t models.Table var columnCount int t.Schema = schema if err := rows.Scan(&t.Name, &t.Type, &columnCount); err != nil { return nil, fmt.Errorf("failed to scan table: %w", err) } tables = append(tables, t) } return tables, nil } // GetTableStructure returns the structure of a table func (p *PostgreSQLConnection) GetTableStructure(tableName string) (*models.TableStructure, error) { db := p.GetDB() if db == nil { return nil, fmt.Errorf("connection is closed") } structure := &models.TableStructure{ TableName: tableName, Schema: "public", } // Get columns query := ` SELECT c.column_name, c.data_type, c.is_nullable, c.column_default, c.character_maximum_length, c.numeric_precision, c.numeric_scale, pg_catalog.col_description(c.oid, c.ordinal_position) as comment FROM information_schema.columns c JOIN pg_class cl ON cl.relname = c.table_name WHERE c.table_schema = 'public' AND c.table_name = $1 ORDER BY c.ordinal_position ` rows, err := db.Query(query, tableName) if err != nil { return nil, fmt.Errorf("failed to query columns: %w", err) } defer rows.Close() for rows.Next() { var col models.TableColumn var nullable string var defaultVal, comment sql.NullString var maxLength, precision, scale sql.NullInt32 if err := rows.Scan( &col.Name, &col.DataType, &nullable, &defaultVal, &maxLength, &precision, &scale, &comment, ); err != nil { return nil, fmt.Errorf("failed to scan column: %w", err) } col.Nullable = nullable == "YES" if defaultVal.Valid { col.Default = defaultVal.String } if maxLength.Valid && maxLength.Int32 > 0 { col.Length = int(maxLength.Int32) } if precision.Valid { col.Scale = int(precision.Int32) } if scale.Valid { col.Scale = int(scale.Int32) } if comment.Valid { col.Comment = comment.String } structure.Columns = append(structure.Columns, col) } // Get primary key information pkQuery := ` SELECT a.attname FROM pg_index i JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey) WHERE i.indrelid = $1::regclass AND i.indisprimary ` pkRows, err := db.Query(pkQuery, tableName) if err != nil { config.GetLogger().Warn("failed to query primary keys", zap.Error(err)) } else { defer pkRows.Close() pkColumns := make(map[string]bool) for pkRows.Next() { var colName string if err := pkRows.Scan(&colName); err != nil { continue } pkColumns[colName] = true } // Mark primary key columns for i := range structure.Columns { if pkColumns[structure.Columns[i].Name] { structure.Columns[i].IsPrimary = true } } } // Get indexes indexQuery := ` SELECT i.indexname, i.indexdef, i.indisunique FROM pg_indexes i WHERE i.schemaname = 'public' AND i.tablename = $1 ` idxRows, err := db.Query(indexQuery, tableName) if err != nil { config.GetLogger().Warn("failed to query indexes", zap.Error(err)) } else { defer idxRows.Close() for idxRows.Next() { var idx models.TableIndex var indexDef string if err := idxRows.Scan(&idx.Name, &indexDef, &idx.IsUnique); err != nil { continue } idx.IsPrimary = idx.Name == (tableName + "_pkey") // Extract column names from index definition idx.Columns = extractColumnsFromIndexDef(indexDef) idx.Type = "btree" // Default for PostgreSQL structure.Indexes = append(structure.Indexes, idx) } } return structure, nil } // GetMetadata returns database metadata func (p *PostgreSQLConnection) GetMetadata() (*models.DBMetadata, error) { db := p.GetDB() if db == nil { return nil, fmt.Errorf("connection is closed") } metadata := &models.DBMetadata{ Database: p.database, User: p.username, Host: p.host, Port: p.port, } // Get PostgreSQL version var version string err := db.QueryRow("SELECT version()").Scan(&version) if err == nil { metadata.Version = version } // Get server time var serverTime string err = db.QueryRow("SELECT NOW()").Scan(&serverTime) if err == nil { metadata.ServerTime = serverTime } return metadata, nil } // extractColumnsFromIndexDef extracts column names from PostgreSQL index definition func extractColumnsFromIndexDef(indexDef string) []string { var columns []string // Simple extraction - look for content between parentheses start := strings.Index(indexDef, "(") end := strings.LastIndex(indexDef, ")") if start != -1 && end != -1 && end > start { content := indexDef[start+1 : end] parts := strings.Split(content, ",") for _, part := range parts { col := strings.TrimSpace(part) // Remove any expressions like "LOWER(column)" if parenIdx := strings.Index(col, "("); parenIdx != -1 { continue } if col != "" { columns = append(columns, col) } } } return columns }