refactor: Flatten directory structure
Move project files from uzdb/ subdirectory to root directory for cleaner project structure.
Changes:
- Move frontend/ to root
- Move internal/ to root
- Move build/ to root
- Move all config files (go.mod, wails.json, etc.) to root
- Remove redundant uzdb/ subdirectory nesting
Project structure is now:
├── frontend/ # React application
├── internal/ # Go backend
├── build/ # Wails build assets
├── doc/ # Design documentation
├── main.go # Entry point
└── ...
🤖 Generated with Qoder
This commit is contained in:
137
internal/database/manager.go
Normal file
137
internal/database/manager.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"uzdb/internal/config"
|
||||
"uzdb/internal/models"
|
||||
)
|
||||
|
||||
// ConnectionManager manages database connections for different database types
|
||||
type ConnectionManager struct {
|
||||
connections map[string]DatabaseConnector
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// DatabaseConnector is the interface for all database connections
|
||||
type DatabaseConnector interface {
|
||||
GetDB() *sql.DB
|
||||
Close() error
|
||||
IsConnected() bool
|
||||
ExecuteQuery(sql string, args ...interface{}) (*models.QueryResult, error)
|
||||
ExecuteStatement(sql string, args ...interface{}) (*models.QueryResult, error)
|
||||
GetTables(schema string) ([]models.Table, error)
|
||||
GetTableStructure(tableName string) (*models.TableStructure, error)
|
||||
GetMetadata() (*models.DBMetadata, error)
|
||||
}
|
||||
|
||||
// NewConnectionManager creates a new connection manager
|
||||
func NewConnectionManager() *ConnectionManager {
|
||||
return &ConnectionManager{
|
||||
connections: make(map[string]DatabaseConnector),
|
||||
}
|
||||
}
|
||||
|
||||
// GetConnection gets or creates a connection for a user connection config
|
||||
func (m *ConnectionManager) GetConnection(conn *models.UserConnection, password string) (DatabaseConnector, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Check if connection already exists
|
||||
if existing, ok := m.connections[conn.ID]; ok {
|
||||
if existing.IsConnected() {
|
||||
return existing, nil
|
||||
}
|
||||
// Connection is dead, remove it
|
||||
existing.Close()
|
||||
delete(m.connections, conn.ID)
|
||||
}
|
||||
|
||||
// Create new connection based on type
|
||||
var connector DatabaseConnector
|
||||
var err error
|
||||
|
||||
switch conn.Type {
|
||||
case models.ConnectionTypeMySQL:
|
||||
connector, err = NewMySQLConnection(conn, password)
|
||||
case models.ConnectionTypePostgreSQL:
|
||||
connector, err = NewPostgreSQLConnection(conn, password)
|
||||
case models.ConnectionTypeSQLite:
|
||||
connector, err = NewSQLiteConnection(conn)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported database type: %s", conn.Type)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.connections[conn.ID] = connector
|
||||
|
||||
config.GetLogger().Info("connection created in manager",
|
||||
zap.String("connection_id", conn.ID),
|
||||
zap.String("type", string(conn.Type)),
|
||||
)
|
||||
|
||||
return connector, nil
|
||||
}
|
||||
|
||||
// RemoveConnection removes a connection from the manager
|
||||
func (m *ConnectionManager) RemoveConnection(connectionID string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if conn, ok := m.connections[connectionID]; ok {
|
||||
if err := conn.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
delete(m.connections, connectionID)
|
||||
config.GetLogger().Info("connection removed from manager",
|
||||
zap.String("connection_id", connectionID),
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllConnections returns all managed connections
|
||||
func (m *ConnectionManager) GetAllConnections() map[string]DatabaseConnector {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
result := make(map[string]DatabaseConnector)
|
||||
for k, v := range m.connections {
|
||||
result[k] = v
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// CloseAll closes all managed connections
|
||||
func (m *ConnectionManager) CloseAll() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
var lastErr error
|
||||
for id, conn := range m.connections {
|
||||
if err := conn.Close(); err != nil {
|
||||
lastErr = err
|
||||
config.GetLogger().Error("failed to close connection",
|
||||
zap.String("connection_id", id),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
m.connections = make(map[string]DatabaseConnector)
|
||||
|
||||
if lastErr != nil {
|
||||
return lastErr
|
||||
}
|
||||
|
||||
config.GetLogger().Info("all connections closed")
|
||||
return nil
|
||||
}
|
||||
414
internal/database/mysql.go
Normal file
414
internal/database/mysql.go
Normal file
@@ -0,0 +1,414 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"uzdb/internal/config"
|
||||
"uzdb/internal/models"
|
||||
)
|
||||
|
||||
// MySQLConnection represents a MySQL connection
|
||||
type MySQLConnection struct {
|
||||
db *sql.DB
|
||||
dsn string
|
||||
host string
|
||||
port int
|
||||
database string
|
||||
username string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// BuildMySQLDSN builds MySQL DSN connection string
|
||||
func BuildMySQLDSN(conn *models.UserConnection, password string) string {
|
||||
// MySQL DSN format: user:pass@tcp(host:port)/dbname?params
|
||||
params := []string{
|
||||
"parseTime=true",
|
||||
"loc=Local",
|
||||
fmt.Sprintf("timeout=%ds", conn.Timeout),
|
||||
"charset=utf8mb4",
|
||||
"collation=utf8mb4_unicode_ci",
|
||||
}
|
||||
|
||||
if conn.SSLMode != "" && conn.SSLMode != "disable" {
|
||||
switch conn.SSLMode {
|
||||
case "require":
|
||||
params = append(params, "tls=required")
|
||||
case "verify-ca":
|
||||
params = append(params, "tls=skip-verify")
|
||||
case "verify-full":
|
||||
params = append(params, "tls=skip-verify")
|
||||
}
|
||||
}
|
||||
|
||||
queryString := strings.Join(params, "&")
|
||||
|
||||
return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?%s",
|
||||
conn.Username,
|
||||
password,
|
||||
conn.Host,
|
||||
conn.Port,
|
||||
conn.Database,
|
||||
queryString,
|
||||
)
|
||||
}
|
||||
|
||||
// NewMySQLConnection creates a new MySQL connection
|
||||
func NewMySQLConnection(conn *models.UserConnection, password string) (*MySQLConnection, error) {
|
||||
dsn := BuildMySQLDSN(conn, password)
|
||||
|
||||
db, err := sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open MySQL connection: %w", err)
|
||||
}
|
||||
|
||||
// Configure connection pool
|
||||
db.SetMaxOpenConns(25)
|
||||
db.SetMaxIdleConns(5)
|
||||
db.SetConnMaxLifetime(5 * time.Minute)
|
||||
|
||||
// Test connection
|
||||
if err := db.Ping(); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("failed to ping MySQL: %w", err)
|
||||
}
|
||||
|
||||
mysqlConn := &MySQLConnection{
|
||||
db: db,
|
||||
dsn: dsn,
|
||||
host: conn.Host,
|
||||
port: conn.Port,
|
||||
database: conn.Database,
|
||||
username: conn.Username,
|
||||
}
|
||||
|
||||
config.GetLogger().Info("MySQL connection established",
|
||||
zap.String("host", conn.Host),
|
||||
zap.Int("port", conn.Port),
|
||||
zap.String("database", conn.Database),
|
||||
)
|
||||
|
||||
return mysqlConn, nil
|
||||
}
|
||||
|
||||
// GetDB returns the underlying sql.DB
|
||||
func (m *MySQLConnection) GetDB() *sql.DB {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.db
|
||||
}
|
||||
|
||||
// Close closes the MySQL connection
|
||||
func (m *MySQLConnection) Close() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.db != nil {
|
||||
if err := m.db.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close MySQL connection: %w", err)
|
||||
}
|
||||
m.db = nil
|
||||
config.GetLogger().Info("MySQL connection closed",
|
||||
zap.String("host", m.host),
|
||||
zap.Int("port", m.port),
|
||||
)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsConnected checks if the connection is alive
|
||||
func (m *MySQLConnection) IsConnected() bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if m.db == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
err := m.db.Ping()
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// ExecuteQuery executes a SQL query and returns results
|
||||
func (m *MySQLConnection) ExecuteQuery(sql string, args ...interface{}) (*models.QueryResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
db := m.GetDB()
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("connection is closed")
|
||||
}
|
||||
|
||||
rows, err := db.Query(sql, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query execution failed: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
result := &models.QueryResult{
|
||||
Success: true,
|
||||
Duration: time.Since(startTime).Milliseconds(),
|
||||
}
|
||||
|
||||
// Get column names
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get columns: %w", err)
|
||||
}
|
||||
result.Columns = columns
|
||||
|
||||
// Scan rows
|
||||
for rows.Next() {
|
||||
values := make([]interface{}, len(columns))
|
||||
valuePtrs := make([]interface{}, len(columns))
|
||||
for i := range values {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan row: %w", err)
|
||||
}
|
||||
|
||||
row := make([]interface{}, len(columns))
|
||||
for i, v := range values {
|
||||
row[i] = convertValue(v)
|
||||
}
|
||||
result.Rows = append(result.Rows, row)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("row iteration error: %w", err)
|
||||
}
|
||||
|
||||
result.RowCount = int64(len(result.Rows))
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ExecuteStatement executes a SQL statement (INSERT, UPDATE, DELETE, etc.)
|
||||
func (m *MySQLConnection) ExecuteStatement(sql string, args ...interface{}) (*models.QueryResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
db := m.GetDB()
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("connection is closed")
|
||||
}
|
||||
|
||||
res, err := db.Exec(sql, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("statement execution failed: %w", err)
|
||||
}
|
||||
|
||||
rowsAffected, _ := res.RowsAffected()
|
||||
lastInsertID, _ := res.LastInsertId()
|
||||
|
||||
result := &models.QueryResult{
|
||||
Success: true,
|
||||
AffectedRows: rowsAffected,
|
||||
Duration: time.Since(startTime).Milliseconds(),
|
||||
Rows: [][]interface{}{{lastInsertID}},
|
||||
Columns: []string{"last_insert_id"},
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetTables returns all tables in the database
|
||||
func (m *MySQLConnection) GetTables(schema string) ([]models.Table, error) {
|
||||
db := m.GetDB()
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("connection is closed")
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT
|
||||
TABLE_NAME as table_name,
|
||||
TABLE_TYPE as table_type,
|
||||
TABLE_ROWS as row_count
|
||||
FROM information_schema.TABLES
|
||||
WHERE TABLE_SCHEMA = ?
|
||||
ORDER BY TABLE_NAME
|
||||
`
|
||||
|
||||
rows, err := db.Query(query, m.database)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query tables: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tables []models.Table
|
||||
for rows.Next() {
|
||||
var t models.Table
|
||||
var rowCount sql.NullInt64
|
||||
if err := rows.Scan(&t.Name, &t.Type, &rowCount); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan table: %w", err)
|
||||
}
|
||||
if rowCount.Valid {
|
||||
t.RowCount = rowCount.Int64
|
||||
}
|
||||
tables = append(tables, t)
|
||||
}
|
||||
|
||||
return tables, nil
|
||||
}
|
||||
|
||||
// GetTableStructure returns the structure of a table
|
||||
func (m *MySQLConnection) GetTableStructure(tableName string) (*models.TableStructure, error) {
|
||||
db := m.GetDB()
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("connection is closed")
|
||||
}
|
||||
|
||||
structure := &models.TableStructure{
|
||||
TableName: tableName,
|
||||
}
|
||||
|
||||
// Get columns
|
||||
query := `
|
||||
SELECT
|
||||
COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_DEFAULT,
|
||||
COLUMN_KEY, EXTRA, CHARACTER_MAXIMUM_LENGTH, NUMERIC_PRECISION,
|
||||
NUMERIC_SCALE, COLUMN_COMMENT
|
||||
FROM information_schema.COLUMNS
|
||||
WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?
|
||||
ORDER BY ORDINAL_POSITION
|
||||
`
|
||||
|
||||
rows, err := db.Query(query, m.database, tableName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query columns: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var col models.TableColumn
|
||||
var nullable, columnKey, extra string
|
||||
var defaultVal, comment sql.NullString
|
||||
var maxLength, precision, scale sql.NullInt32
|
||||
|
||||
if err := rows.Scan(
|
||||
&col.Name, &col.DataType, &nullable, &defaultVal,
|
||||
&columnKey, &extra, &maxLength, &precision,
|
||||
&scale, &comment,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan column: %w", err)
|
||||
}
|
||||
|
||||
col.Nullable = nullable == "YES"
|
||||
col.IsPrimary = columnKey == "PRI"
|
||||
col.IsUnique = columnKey == "UNI"
|
||||
col.AutoIncrement = strings.Contains(extra, "auto_increment")
|
||||
|
||||
if defaultVal.Valid {
|
||||
col.Default = defaultVal.String
|
||||
}
|
||||
if maxLength.Valid {
|
||||
col.Length = int(maxLength.Int32)
|
||||
}
|
||||
if precision.Valid {
|
||||
col.Scale = int(precision.Int32)
|
||||
}
|
||||
if scale.Valid {
|
||||
col.Scale = int(scale.Int32)
|
||||
}
|
||||
if comment.Valid {
|
||||
col.Comment = comment.String
|
||||
}
|
||||
|
||||
structure.Columns = append(structure.Columns, col)
|
||||
}
|
||||
|
||||
// Get indexes
|
||||
indexQuery := `
|
||||
SELECT INDEX_NAME, COLUMN_NAME, NON_UNIQUE, SEQ_IN_INDEX
|
||||
FROM information_schema.STATISTICS
|
||||
WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?
|
||||
ORDER BY INDEX_NAME, SEQ_IN_INDEX
|
||||
`
|
||||
|
||||
idxRows, err := db.Query(indexQuery, m.database, tableName)
|
||||
if err != nil {
|
||||
config.GetLogger().Warn("failed to query indexes", zap.Error(err))
|
||||
} else {
|
||||
defer idxRows.Close()
|
||||
|
||||
indexMap := make(map[string]*models.TableIndex)
|
||||
for idxRows.Next() {
|
||||
var indexName, columnName string
|
||||
var nonUnique bool
|
||||
var seqInIndex int
|
||||
|
||||
if err := idxRows.Scan(&indexName, &columnName, &nonUnique, &seqInIndex); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
idx, exists := indexMap[indexName]
|
||||
if !exists {
|
||||
idx = &models.TableIndex{
|
||||
Name: indexName,
|
||||
IsUnique: !nonUnique,
|
||||
IsPrimary: indexName == "PRIMARY",
|
||||
Columns: []string{},
|
||||
}
|
||||
indexMap[indexName] = idx
|
||||
}
|
||||
idx.Columns = append(idx.Columns, columnName)
|
||||
}
|
||||
|
||||
for _, idx := range indexMap {
|
||||
structure.Indexes = append(structure.Indexes, *idx)
|
||||
}
|
||||
}
|
||||
|
||||
return structure, nil
|
||||
}
|
||||
|
||||
// GetMetadata returns database metadata
|
||||
func (m *MySQLConnection) GetMetadata() (*models.DBMetadata, error) {
|
||||
db := m.GetDB()
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("connection is closed")
|
||||
}
|
||||
|
||||
metadata := &models.DBMetadata{
|
||||
Database: m.database,
|
||||
User: m.username,
|
||||
Host: m.host,
|
||||
Port: m.port,
|
||||
}
|
||||
|
||||
// Get MySQL version
|
||||
var version string
|
||||
err := db.QueryRow("SELECT VERSION()").Scan(&version)
|
||||
if err == nil {
|
||||
metadata.Version = version
|
||||
}
|
||||
|
||||
// Get server time
|
||||
var serverTime string
|
||||
err = db.QueryRow("SELECT NOW()").Scan(&serverTime)
|
||||
if err == nil {
|
||||
metadata.ServerTime = serverTime
|
||||
}
|
||||
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
// convertValue converts database value to interface{}
|
||||
func convertValue(val interface{}) interface{} {
|
||||
if val == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle []byte (common from MySQL driver)
|
||||
if b, ok := val.([]byte); ok {
|
||||
return string(b)
|
||||
}
|
||||
|
||||
return val
|
||||
}
|
||||
449
internal/database/postgres.go
Normal file
449
internal/database/postgres.go
Normal file
@@ -0,0 +1,449 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"uzdb/internal/config"
|
||||
"uzdb/internal/models"
|
||||
)
|
||||
|
||||
// PostgreSQLConnection represents a PostgreSQL connection
|
||||
type PostgreSQLConnection struct {
|
||||
db *sql.DB
|
||||
dsn string
|
||||
host string
|
||||
port int
|
||||
database string
|
||||
username string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// BuildPostgreSQLDSN builds PostgreSQL DSN connection string
|
||||
func BuildPostgreSQLDSN(conn *models.UserConnection, password string) string {
|
||||
// PostgreSQL DSN format: postgres://user:pass@host:port/dbname?params
|
||||
params := []string{
|
||||
fmt.Sprintf("connect_timeout=%d", conn.Timeout),
|
||||
"sslmode=disable",
|
||||
}
|
||||
|
||||
if conn.SSLMode != "" {
|
||||
switch conn.SSLMode {
|
||||
case "require":
|
||||
params = append(params, "sslmode=require")
|
||||
case "verify-ca":
|
||||
params = append(params, "sslmode=verify-ca")
|
||||
case "verify-full":
|
||||
params = append(params, "sslmode=verify-full")
|
||||
case "disable":
|
||||
params = append(params, "sslmode=disable")
|
||||
default:
|
||||
params = append(params, fmt.Sprintf("sslmode=%s", conn.SSLMode))
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("postgres://%s:%s@%s:%d/%s?%s",
|
||||
conn.Username,
|
||||
password,
|
||||
conn.Host,
|
||||
conn.Port,
|
||||
conn.Database,
|
||||
strings.Join(params, "&"),
|
||||
)
|
||||
}
|
||||
|
||||
// NewPostgreSQLConnection creates a new PostgreSQL connection
|
||||
func NewPostgreSQLConnection(conn *models.UserConnection, password string) (*PostgreSQLConnection, error) {
|
||||
dsn := BuildPostgreSQLDSN(conn, password)
|
||||
|
||||
db, err := sql.Open("postgres", dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open PostgreSQL connection: %w", err)
|
||||
}
|
||||
|
||||
// Configure connection pool
|
||||
db.SetMaxOpenConns(25)
|
||||
db.SetMaxIdleConns(5)
|
||||
db.SetConnMaxLifetime(5 * time.Minute)
|
||||
|
||||
// Test connection
|
||||
if err := db.Ping(); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("failed to ping PostgreSQL: %w", err)
|
||||
}
|
||||
|
||||
pgConn := &PostgreSQLConnection{
|
||||
db: db,
|
||||
dsn: dsn,
|
||||
host: conn.Host,
|
||||
port: conn.Port,
|
||||
database: conn.Database,
|
||||
username: conn.Username,
|
||||
}
|
||||
|
||||
config.GetLogger().Info("PostgreSQL connection established",
|
||||
zap.String("host", conn.Host),
|
||||
zap.Int("port", conn.Port),
|
||||
zap.String("database", conn.Database),
|
||||
)
|
||||
|
||||
return pgConn, nil
|
||||
}
|
||||
|
||||
// GetDB returns the underlying sql.DB
|
||||
func (p *PostgreSQLConnection) GetDB() *sql.DB {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
return p.db
|
||||
}
|
||||
|
||||
// Close closes the PostgreSQL connection
|
||||
func (p *PostgreSQLConnection) Close() error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.db != nil {
|
||||
if err := p.db.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close PostgreSQL connection: %w", err)
|
||||
}
|
||||
p.db = nil
|
||||
config.GetLogger().Info("PostgreSQL connection closed",
|
||||
zap.String("host", p.host),
|
||||
zap.Int("port", p.port),
|
||||
)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsConnected checks if the connection is alive
|
||||
func (p *PostgreSQLConnection) IsConnected() bool {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
if p.db == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
err := p.db.Ping()
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// ExecuteQuery executes a SQL query and returns results
|
||||
func (p *PostgreSQLConnection) ExecuteQuery(sql string, args ...interface{}) (*models.QueryResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
db := p.GetDB()
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("connection is closed")
|
||||
}
|
||||
|
||||
rows, err := db.Query(sql, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query execution failed: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
result := &models.QueryResult{
|
||||
Success: true,
|
||||
Duration: time.Since(startTime).Milliseconds(),
|
||||
}
|
||||
|
||||
// Get column names
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get columns: %w", err)
|
||||
}
|
||||
result.Columns = columns
|
||||
|
||||
// Scan rows
|
||||
for rows.Next() {
|
||||
values := make([]interface{}, len(columns))
|
||||
valuePtrs := make([]interface{}, len(columns))
|
||||
for i := range values {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan row: %w", err)
|
||||
}
|
||||
|
||||
row := make([]interface{}, len(columns))
|
||||
for i, v := range values {
|
||||
row[i] = convertValue(v)
|
||||
}
|
||||
result.Rows = append(result.Rows, row)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("row iteration error: %w", err)
|
||||
}
|
||||
|
||||
result.RowCount = int64(len(result.Rows))
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ExecuteStatement executes a SQL statement (INSERT, UPDATE, DELETE, etc.)
|
||||
func (p *PostgreSQLConnection) ExecuteStatement(sql string, args ...interface{}) (*models.QueryResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
db := p.GetDB()
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("connection is closed")
|
||||
}
|
||||
|
||||
res, err := db.Exec(sql, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("statement execution failed: %w", err)
|
||||
}
|
||||
|
||||
rowsAffected, _ := res.RowsAffected()
|
||||
|
||||
result := &models.QueryResult{
|
||||
Success: true,
|
||||
AffectedRows: rowsAffected,
|
||||
Duration: time.Since(startTime).Milliseconds(),
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetTables returns all tables in the database
|
||||
func (p *PostgreSQLConnection) GetTables(schema string) ([]models.Table, error) {
|
||||
db := p.GetDB()
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("connection is closed")
|
||||
}
|
||||
|
||||
if schema == "" {
|
||||
schema = "public"
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT
|
||||
table_name,
|
||||
table_type,
|
||||
(SELECT COUNT(*) FROM information_schema.columns c
|
||||
WHERE c.table_name = t.table_name AND c.table_schema = t.table_schema) as column_count
|
||||
FROM information_schema.tables t
|
||||
WHERE table_schema = $1
|
||||
ORDER BY table_name
|
||||
`
|
||||
|
||||
rows, err := db.Query(query, schema)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query tables: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tables []models.Table
|
||||
for rows.Next() {
|
||||
var t models.Table
|
||||
var columnCount int
|
||||
t.Schema = schema
|
||||
if err := rows.Scan(&t.Name, &t.Type, &columnCount); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan table: %w", err)
|
||||
}
|
||||
tables = append(tables, t)
|
||||
}
|
||||
|
||||
return tables, nil
|
||||
}
|
||||
|
||||
// GetTableStructure returns the structure of a table
|
||||
func (p *PostgreSQLConnection) GetTableStructure(tableName string) (*models.TableStructure, error) {
|
||||
db := p.GetDB()
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("connection is closed")
|
||||
}
|
||||
|
||||
structure := &models.TableStructure{
|
||||
TableName: tableName,
|
||||
Schema: "public",
|
||||
}
|
||||
|
||||
// Get columns
|
||||
query := `
|
||||
SELECT
|
||||
c.column_name,
|
||||
c.data_type,
|
||||
c.is_nullable,
|
||||
c.column_default,
|
||||
c.character_maximum_length,
|
||||
c.numeric_precision,
|
||||
c.numeric_scale,
|
||||
pg_catalog.col_description(c.oid, c.ordinal_position) as comment
|
||||
FROM information_schema.columns c
|
||||
JOIN pg_class cl ON cl.relname = c.table_name
|
||||
WHERE c.table_schema = 'public' AND c.table_name = $1
|
||||
ORDER BY c.ordinal_position
|
||||
`
|
||||
|
||||
rows, err := db.Query(query, tableName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query columns: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var col models.TableColumn
|
||||
var nullable string
|
||||
var defaultVal, comment sql.NullString
|
||||
var maxLength, precision, scale sql.NullInt32
|
||||
|
||||
if err := rows.Scan(
|
||||
&col.Name, &col.DataType, &nullable, &defaultVal,
|
||||
&maxLength, &precision, &scale, &comment,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan column: %w", err)
|
||||
}
|
||||
|
||||
col.Nullable = nullable == "YES"
|
||||
|
||||
if defaultVal.Valid {
|
||||
col.Default = defaultVal.String
|
||||
}
|
||||
if maxLength.Valid && maxLength.Int32 > 0 {
|
||||
col.Length = int(maxLength.Int32)
|
||||
}
|
||||
if precision.Valid {
|
||||
col.Scale = int(precision.Int32)
|
||||
}
|
||||
if scale.Valid {
|
||||
col.Scale = int(scale.Int32)
|
||||
}
|
||||
if comment.Valid {
|
||||
col.Comment = comment.String
|
||||
}
|
||||
|
||||
structure.Columns = append(structure.Columns, col)
|
||||
}
|
||||
|
||||
// Get primary key information
|
||||
pkQuery := `
|
||||
SELECT a.attname
|
||||
FROM pg_index i
|
||||
JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
|
||||
WHERE i.indrelid = $1::regclass AND i.indisprimary
|
||||
`
|
||||
|
||||
pkRows, err := db.Query(pkQuery, tableName)
|
||||
if err != nil {
|
||||
config.GetLogger().Warn("failed to query primary keys", zap.Error(err))
|
||||
} else {
|
||||
defer pkRows.Close()
|
||||
|
||||
pkColumns := make(map[string]bool)
|
||||
for pkRows.Next() {
|
||||
var colName string
|
||||
if err := pkRows.Scan(&colName); err != nil {
|
||||
continue
|
||||
}
|
||||
pkColumns[colName] = true
|
||||
}
|
||||
|
||||
// Mark primary key columns
|
||||
for i := range structure.Columns {
|
||||
if pkColumns[structure.Columns[i].Name] {
|
||||
structure.Columns[i].IsPrimary = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get indexes
|
||||
indexQuery := `
|
||||
SELECT
|
||||
i.indexname,
|
||||
i.indexdef,
|
||||
i.indisunique
|
||||
FROM pg_indexes i
|
||||
WHERE i.schemaname = 'public' AND i.tablename = $1
|
||||
`
|
||||
|
||||
idxRows, err := db.Query(indexQuery, tableName)
|
||||
if err != nil {
|
||||
config.GetLogger().Warn("failed to query indexes", zap.Error(err))
|
||||
} else {
|
||||
defer idxRows.Close()
|
||||
|
||||
for idxRows.Next() {
|
||||
var idx models.TableIndex
|
||||
var indexDef string
|
||||
if err := idxRows.Scan(&idx.Name, &indexDef, &idx.IsUnique); err != nil {
|
||||
continue
|
||||
}
|
||||
idx.IsPrimary = idx.Name == (tableName + "_pkey")
|
||||
|
||||
// Extract column names from index definition
|
||||
idx.Columns = extractColumnsFromIndexDef(indexDef)
|
||||
idx.Type = "btree" // Default for PostgreSQL
|
||||
|
||||
structure.Indexes = append(structure.Indexes, idx)
|
||||
}
|
||||
}
|
||||
|
||||
return structure, nil
|
||||
}
|
||||
|
||||
// GetMetadata returns database metadata
|
||||
func (p *PostgreSQLConnection) GetMetadata() (*models.DBMetadata, error) {
|
||||
db := p.GetDB()
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("connection is closed")
|
||||
}
|
||||
|
||||
metadata := &models.DBMetadata{
|
||||
Database: p.database,
|
||||
User: p.username,
|
||||
Host: p.host,
|
||||
Port: p.port,
|
||||
}
|
||||
|
||||
// Get PostgreSQL version
|
||||
var version string
|
||||
err := db.QueryRow("SELECT version()").Scan(&version)
|
||||
if err == nil {
|
||||
metadata.Version = version
|
||||
}
|
||||
|
||||
// Get server time
|
||||
var serverTime string
|
||||
err = db.QueryRow("SELECT NOW()").Scan(&serverTime)
|
||||
if err == nil {
|
||||
metadata.ServerTime = serverTime
|
||||
}
|
||||
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
// extractColumnsFromIndexDef extracts column names from PostgreSQL index definition
|
||||
func extractColumnsFromIndexDef(indexDef string) []string {
|
||||
var columns []string
|
||||
|
||||
// Simple extraction - look for content between parentheses
|
||||
start := strings.Index(indexDef, "(")
|
||||
end := strings.LastIndex(indexDef, ")")
|
||||
|
||||
if start != -1 && end != -1 && end > start {
|
||||
content := indexDef[start+1 : end]
|
||||
parts := strings.Split(content, ",")
|
||||
for _, part := range parts {
|
||||
col := strings.TrimSpace(part)
|
||||
// Remove any expressions like "LOWER(column)"
|
||||
if parenIdx := strings.Index(col, "("); parenIdx != -1 {
|
||||
continue
|
||||
}
|
||||
if col != "" {
|
||||
columns = append(columns, col)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return columns
|
||||
}
|
||||
128
internal/database/sqlite.go
Normal file
128
internal/database/sqlite.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"uzdb/internal/config"
|
||||
"uzdb/internal/models"
|
||||
)
|
||||
|
||||
var (
|
||||
sqliteDB *gorm.DB
|
||||
sqliteMu sync.RWMutex
|
||||
)
|
||||
|
||||
// InitSQLite initializes the SQLite database for application data
|
||||
func InitSQLite(dbPath string, cfg *config.DatabaseConfig) (*gorm.DB, error) {
|
||||
sqliteMu.Lock()
|
||||
defer sqliteMu.Unlock()
|
||||
|
||||
// Ensure directory exists
|
||||
dir := filepath.Dir(dbPath)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create database directory: %w", err)
|
||||
}
|
||||
|
||||
// Configure GORM logger
|
||||
var gormLogger logger.Interface
|
||||
if config.Get().IsDevelopment() {
|
||||
gormLogger = logger.Default.LogMode(logger.Info)
|
||||
} else {
|
||||
gormLogger = logger.Default.LogMode(logger.Silent)
|
||||
}
|
||||
|
||||
// Open SQLite connection
|
||||
db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{
|
||||
Logger: gormLogger,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open SQLite database: %w", err)
|
||||
}
|
||||
|
||||
// Set connection pool settings
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get underlying DB: %w", err)
|
||||
}
|
||||
|
||||
sqlDB.SetMaxOpenConns(cfg.MaxOpenConns)
|
||||
sqlDB.SetMaxIdleConns(cfg.MaxIdleConns)
|
||||
sqlDB.SetConnMaxLifetime(time.Duration(cfg.MaxLifetime) * time.Minute)
|
||||
|
||||
// Enable WAL mode for better concurrency
|
||||
if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil {
|
||||
config.GetLogger().Warn("failed to enable WAL mode", zap.Error(err))
|
||||
}
|
||||
|
||||
// Run migrations
|
||||
if err := runMigrations(db); err != nil {
|
||||
return nil, fmt.Errorf("migration failed: %w", err)
|
||||
}
|
||||
|
||||
sqliteDB = db
|
||||
|
||||
config.GetLogger().Info("SQLite database initialized",
|
||||
zap.String("path", dbPath),
|
||||
zap.Int("max_open_conns", cfg.MaxOpenConns),
|
||||
zap.Int("max_idle_conns", cfg.MaxIdleConns),
|
||||
)
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// runMigrations runs database migrations
|
||||
func runMigrations(db *gorm.DB) error {
|
||||
migrations := []interface{}{
|
||||
&models.UserConnection{},
|
||||
&models.QueryHistory{},
|
||||
&models.SavedQuery{},
|
||||
}
|
||||
|
||||
for _, model := range migrations {
|
||||
if err := db.AutoMigrate(model); err != nil {
|
||||
return fmt.Errorf("failed to migrate %T: %w", model, err)
|
||||
}
|
||||
}
|
||||
|
||||
config.GetLogger().Info("database migrations completed")
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSQLiteDB returns the SQLite database instance
|
||||
func GetSQLiteDB() *gorm.DB {
|
||||
sqliteMu.RLock()
|
||||
defer sqliteMu.RUnlock()
|
||||
return sqliteDB
|
||||
}
|
||||
|
||||
// CloseSQLite closes the SQLite database connection
|
||||
func CloseSQLite() error {
|
||||
sqliteMu.Lock()
|
||||
defer sqliteMu.Unlock()
|
||||
|
||||
if sqliteDB == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sqlDB, err := sqliteDB.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := sqlDB.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close SQLite database: %w", err)
|
||||
}
|
||||
|
||||
sqliteDB = nil
|
||||
config.GetLogger().Info("SQLite database closed")
|
||||
return nil
|
||||
}
|
||||
369
internal/database/sqlite_driver.go
Normal file
369
internal/database/sqlite_driver.go
Normal file
@@ -0,0 +1,369 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"uzdb/internal/config"
|
||||
"uzdb/internal/models"
|
||||
)
|
||||
|
||||
// SQLiteConnection represents a SQLite connection
|
||||
type SQLiteConnection struct {
|
||||
db *sql.DB
|
||||
filePath string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewSQLiteConnection creates a new SQLite connection
|
||||
func NewSQLiteConnection(conn *models.UserConnection) (*SQLiteConnection, error) {
|
||||
filePath := conn.Database
|
||||
|
||||
// Ensure file exists for new connections
|
||||
db, err := sql.Open("sqlite", filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open SQLite connection: %w", err)
|
||||
}
|
||||
|
||||
// Configure connection pool
|
||||
db.SetMaxOpenConns(1) // SQLite only supports one writer
|
||||
db.SetMaxIdleConns(1)
|
||||
db.SetConnMaxLifetime(5 * time.Minute)
|
||||
|
||||
// Enable WAL mode
|
||||
if _, err := db.Exec("PRAGMA journal_mode=WAL"); err != nil {
|
||||
config.GetLogger().Warn("failed to enable WAL mode", zap.Error(err))
|
||||
}
|
||||
|
||||
// Test connection
|
||||
if err := db.Ping(); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("failed to ping SQLite: %w", err)
|
||||
}
|
||||
|
||||
sqliteConn := &SQLiteConnection{
|
||||
db: db,
|
||||
filePath: filePath,
|
||||
}
|
||||
|
||||
config.GetLogger().Info("SQLite connection established",
|
||||
zap.String("path", filePath),
|
||||
)
|
||||
|
||||
return sqliteConn, nil
|
||||
}
|
||||
|
||||
// GetDB returns the underlying sql.DB
|
||||
func (s *SQLiteConnection) GetDB() *sql.DB {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.db
|
||||
}
|
||||
|
||||
// Close closes the SQLite connection
|
||||
func (s *SQLiteConnection) Close() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.db != nil {
|
||||
if err := s.db.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close SQLite connection: %w", err)
|
||||
}
|
||||
s.db = nil
|
||||
config.GetLogger().Info("SQLite connection closed",
|
||||
zap.String("path", s.filePath),
|
||||
)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsConnected checks if the connection is alive
|
||||
func (s *SQLiteConnection) IsConnected() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if s.db == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
err := s.db.Ping()
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// ExecuteQuery executes a SQL query and returns results
|
||||
func (s *SQLiteConnection) ExecuteQuery(query string, args ...interface{}) (*models.QueryResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
db := s.GetDB()
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("connection is closed")
|
||||
}
|
||||
|
||||
rows, err := db.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query execution failed: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
result := &models.QueryResult{
|
||||
Success: true,
|
||||
Duration: time.Since(startTime).Milliseconds(),
|
||||
}
|
||||
|
||||
// Get column names
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get columns: %w", err)
|
||||
}
|
||||
result.Columns = columns
|
||||
|
||||
// Scan rows
|
||||
for rows.Next() {
|
||||
values := make([]interface{}, len(columns))
|
||||
valuePtrs := make([]interface{}, len(columns))
|
||||
for i := range values {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan row: %w", err)
|
||||
}
|
||||
|
||||
row := make([]interface{}, len(columns))
|
||||
for i, v := range values {
|
||||
row[i] = convertValue(v)
|
||||
}
|
||||
result.Rows = append(result.Rows, row)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("row iteration error: %w", err)
|
||||
}
|
||||
|
||||
result.RowCount = int64(len(result.Rows))
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ExecuteStatement executes a SQL statement (INSERT, UPDATE, DELETE, etc.)
|
||||
func (s *SQLiteConnection) ExecuteStatement(stmt string, args ...interface{}) (*models.QueryResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
db := s.GetDB()
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("connection is closed")
|
||||
}
|
||||
|
||||
res, err := db.Exec(stmt, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("statement execution failed: %w", err)
|
||||
}
|
||||
|
||||
rowsAffected, _ := res.RowsAffected()
|
||||
lastInsertID, _ := res.LastInsertId()
|
||||
|
||||
result := &models.QueryResult{
|
||||
Success: true,
|
||||
AffectedRows: rowsAffected,
|
||||
Duration: time.Since(startTime).Milliseconds(),
|
||||
Rows: [][]interface{}{{lastInsertID}},
|
||||
Columns: []string{"last_insert_id"},
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetTables returns all tables in the database
|
||||
func (s *SQLiteConnection) GetTables(schema string) ([]models.Table, error) {
|
||||
db := s.GetDB()
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("connection is closed")
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT name, type
|
||||
FROM sqlite_master
|
||||
WHERE type IN ('table', 'view') AND name NOT LIKE 'sqlite_%'
|
||||
ORDER BY name
|
||||
`
|
||||
|
||||
rows, err := db.Query(query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query tables: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tables []models.Table
|
||||
for rows.Next() {
|
||||
var t models.Table
|
||||
if err := rows.Scan(&t.Name, &t.Type); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan table: %w", err)
|
||||
}
|
||||
|
||||
// Get row count for tables
|
||||
if t.Type == "table" {
|
||||
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM \"%s\"", t.Name)
|
||||
var rowCount int64
|
||||
if err := db.QueryRow(countQuery).Scan(&rowCount); err == nil {
|
||||
t.RowCount = rowCount
|
||||
}
|
||||
}
|
||||
|
||||
tables = append(tables, t)
|
||||
}
|
||||
|
||||
return tables, nil
|
||||
}
|
||||
|
||||
// GetTableStructure returns the structure of a table
|
||||
func (s *SQLiteConnection) GetTableStructure(tableName string) (*models.TableStructure, error) {
|
||||
db := s.GetDB()
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("connection is closed")
|
||||
}
|
||||
|
||||
structure := &models.TableStructure{
|
||||
TableName: tableName,
|
||||
}
|
||||
|
||||
// Get table info using PRAGMA
|
||||
query := fmt.Sprintf("PRAGMA table_info(\"%s\")", tableName)
|
||||
rows, err := db.Query(query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query table info: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var col models.TableColumn
|
||||
var pk int
|
||||
|
||||
if err := rows.Scan(&col.Name, &col.DataType, &col.Default, &col.Nullable, &pk, &col.IsPrimary); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan column: %w", err)
|
||||
}
|
||||
|
||||
col.IsPrimary = pk > 0
|
||||
col.Nullable = col.Nullable || !col.IsPrimary
|
||||
|
||||
structure.Columns = append(structure.Columns, col)
|
||||
}
|
||||
|
||||
// Get indexes
|
||||
idxQuery := fmt.Sprintf("PRAGMA index_list(\"%s\")", tableName)
|
||||
idxRows, err := db.Query(idxQuery)
|
||||
if err != nil {
|
||||
config.GetLogger().Warn("failed to query indexes", zap.Error(err))
|
||||
} else {
|
||||
defer idxRows.Close()
|
||||
|
||||
for idxRows.Next() {
|
||||
var idx models.TableIndex
|
||||
var origin string
|
||||
|
||||
if err := idxRows.Scan(&idx.Name, &idx.IsUnique, &origin); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
idx.IsPrimary = idx.Name == "sqlite_autoindex_"+tableName+"_1"
|
||||
idx.Type = origin
|
||||
|
||||
// Get index columns
|
||||
colQuery := fmt.Sprintf("PRAGMA index_info(\"%s\")", idx.Name)
|
||||
colRows, err := db.Query(colQuery)
|
||||
if err == nil {
|
||||
defer colRows.Close()
|
||||
for colRows.Next() {
|
||||
var seqno int
|
||||
var colName string
|
||||
if err := colRows.Scan(&seqno, &colName, &colName); err != nil {
|
||||
continue
|
||||
}
|
||||
idx.Columns = append(idx.Columns, colName)
|
||||
}
|
||||
}
|
||||
|
||||
structure.Indexes = append(structure.Indexes, idx)
|
||||
}
|
||||
}
|
||||
|
||||
// Get foreign keys
|
||||
fkQuery := fmt.Sprintf("PRAGMA foreign_key_list(\"%s\")", tableName)
|
||||
fkRows, err := db.Query(fkQuery)
|
||||
if err != nil {
|
||||
config.GetLogger().Warn("failed to query foreign keys", zap.Error(err))
|
||||
} else {
|
||||
defer fkRows.Close()
|
||||
|
||||
fkMap := make(map[string]*models.ForeignKey)
|
||||
for fkRows.Next() {
|
||||
var fk models.ForeignKey
|
||||
var id, seq int
|
||||
|
||||
if err := fkRows.Scan(&id, &seq, &fk.Name, &fk.ReferencedTable,
|
||||
&fk.Columns, &fk.ReferencedColumns, &fk.OnUpdate, &fk.OnDelete, ""); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle array fields properly
|
||||
fk.Columns = []string{}
|
||||
fk.ReferencedColumns = []string{}
|
||||
|
||||
var fromCol, toCol string
|
||||
if err := fkRows.Scan(&id, &seq, &fk.Name, &fk.ReferencedTable,
|
||||
&fromCol, &toCol, &fk.OnUpdate, &fk.OnDelete, ""); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
existingFk, exists := fkMap[fk.Name]
|
||||
if !exists {
|
||||
existingFk = &models.ForeignKey{
|
||||
Name: fk.Name,
|
||||
ReferencedTable: fk.ReferencedTable,
|
||||
OnDelete: fk.OnDelete,
|
||||
OnUpdate: fk.OnUpdate,
|
||||
Columns: []string{},
|
||||
ReferencedColumns: []string{},
|
||||
}
|
||||
fkMap[fk.Name] = existingFk
|
||||
}
|
||||
existingFk.Columns = append(existingFk.Columns, fromCol)
|
||||
existingFk.ReferencedColumns = append(existingFk.ReferencedColumns, toCol)
|
||||
}
|
||||
|
||||
for _, fk := range fkMap {
|
||||
structure.ForeignKeys = append(structure.ForeignKeys, *fk)
|
||||
}
|
||||
}
|
||||
|
||||
return structure, nil
|
||||
}
|
||||
|
||||
// GetMetadata returns database metadata
|
||||
func (s *SQLiteConnection) GetMetadata() (*models.DBMetadata, error) {
|
||||
db := s.GetDB()
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("connection is closed")
|
||||
}
|
||||
|
||||
metadata := &models.DBMetadata{
|
||||
Database: s.filePath,
|
||||
Version: "3", // SQLite version 3
|
||||
}
|
||||
|
||||
// Get SQLite version
|
||||
var version string
|
||||
err := db.QueryRow("SELECT sqlite_version()").Scan(&version)
|
||||
if err == nil {
|
||||
metadata.Version = version
|
||||
}
|
||||
|
||||
// Get current time
|
||||
metadata.ServerTime = time.Now().Format(time.RFC3339)
|
||||
|
||||
return metadata, nil
|
||||
}
|
||||
Reference in New Issue
Block a user