refactor: Flatten directory structure

Move project files from uzdb/ subdirectory to root directory for cleaner project structure.

Changes:
- Move frontend/ to root
- Move internal/ to root
- Move build/ to root
- Move all config files (go.mod, wails.json, etc.) to root
- Remove redundant uzdb/ subdirectory nesting

Project structure is now:
├── frontend/        # React application
├── internal/        # Go backend
├── build/           # Wails build assets
├── doc/             # Design documentation
├── main.go          # Entry point
└── ...

🤖 Generated with Qoder
This commit is contained in:
loveuer
2026-04-04 07:14:00 -07:00
parent 5a83e86bc9
commit 9874561410
83 changed files with 0 additions and 46 deletions

324
internal/app/app.go Normal file
View File

@@ -0,0 +1,324 @@
package app
import (
"context"
"time"
"go.uber.org/zap"
"uzdb/internal/config"
"uzdb/internal/database"
"uzdb/internal/handler"
"uzdb/internal/models"
"uzdb/internal/services"
)
// App is the main application structure for Wails bindings
type App struct {
ctx context.Context
config *config.Config
connectionSvc *services.ConnectionService
querySvc *services.QueryService
httpServer *handler.HTTPServer
shutdownFunc context.CancelFunc
}
// NewApp creates a new App instance
func NewApp() *App {
return &App{}
}
// Initialize initializes the application with all services
func (a *App) Initialize(
cfg *config.Config,
connectionSvc *services.ConnectionService,
querySvc *services.QueryService,
httpServer *handler.HTTPServer,
) {
a.config = cfg
a.connectionSvc = connectionSvc
a.querySvc = querySvc
a.httpServer = httpServer
}
// OnStartup is called when the app starts (public method for Wails)
func (a *App) OnStartup(ctx context.Context) {
a.ctx = ctx
config.GetLogger().Info("Wails application started")
}
// GetConnections returns all user connections
// Wails binding: frontend can call window.go.app.GetConnections()
func (a *App) GetConnections() []models.UserConnection {
if a.connectionSvc == nil {
return []models.UserConnection{}
}
connections, err := a.connectionSvc.GetAllConnections(a.ctx)
if err != nil {
config.GetLogger().Error("failed to get connections", zap.Error(err))
return []models.UserConnection{}
}
return connections
}
// CreateConnection creates a new database connection
// Returns error message or empty string on success
func (a *App) CreateConnection(conn models.CreateConnectionRequest) string {
if a.connectionSvc == nil {
return "Service not initialized"
}
_, err := a.connectionSvc.CreateConnection(a.ctx, &conn)
if err != nil {
config.GetLogger().Error("failed to create connection", zap.Error(err))
return err.Error()
}
return ""
}
// UpdateConnection updates an existing database connection
// Returns error message or empty string on success
func (a *App) UpdateConnection(conn models.UserConnection) string {
if a.connectionSvc == nil {
return "Service not initialized"
}
req := &models.UpdateConnectionRequest{
Name: conn.Name,
Type: conn.Type,
Host: conn.Host,
Port: conn.Port,
Username: conn.Username,
Password: conn.Password,
Database: conn.Database,
SSLMode: conn.SSLMode,
Timeout: conn.Timeout,
}
_, err := a.connectionSvc.UpdateConnection(a.ctx, conn.ID, req)
if err != nil {
config.GetLogger().Error("failed to update connection", zap.Error(err))
return err.Error()
}
return ""
}
// DeleteConnection deletes a database connection
// Returns error message or empty string on success
func (a *App) DeleteConnection(id string) string {
if a.connectionSvc == nil {
return "Service not initialized"
}
err := a.connectionSvc.DeleteConnection(a.ctx, id)
if err != nil {
config.GetLogger().Error("failed to delete connection", zap.Error(err))
return err.Error()
}
return ""
}
// TestConnection tests a database connection
// Returns (success, error_message)
func (a *App) TestConnection(id string) (bool, string) {
if a.connectionSvc == nil {
return false, "Service not initialized"
}
result, err := a.connectionSvc.TestConnection(a.ctx, id)
if err != nil {
config.GetLogger().Error("failed to test connection", zap.Error(err))
return false, err.Error()
}
return result.Success, result.Message
}
// ExecuteQuery executes a SQL query on a connection
// Returns query result or error message
func (a *App) ExecuteQuery(connectionID, sql string) (*models.QueryResult, string) {
if a.connectionSvc == nil {
return nil, "Service not initialized"
}
result, err := a.connectionSvc.ExecuteQuery(a.ctx, connectionID, sql)
if err != nil {
config.GetLogger().Error("failed to execute query",
zap.String("connection_id", connectionID),
zap.String("sql", sql),
zap.Error(err))
return nil, err.Error()
}
return result, ""
}
// GetTables returns all tables for a connection
func (a *App) GetTables(connectionID string) ([]models.Table, string) {
if a.connectionSvc == nil {
return []models.Table{}, "Service not initialized"
}
tables, err := a.connectionSvc.GetTables(a.ctx, connectionID)
if err != nil {
config.GetLogger().Error("failed to get tables",
zap.String("connection_id", connectionID),
zap.Error(err))
return []models.Table{}, err.Error()
}
return tables, ""
}
// GetTableData returns data from a table
func (a *App) GetTableData(connectionID, tableName string, limit, offset int) (*models.QueryResult, string) {
if a.connectionSvc == nil {
return nil, "Service not initialized"
}
result, err := a.connectionSvc.GetTableData(a.ctx, connectionID, tableName, limit, offset)
if err != nil {
config.GetLogger().Error("failed to get table data",
zap.String("connection_id", connectionID),
zap.String("table", tableName),
zap.Error(err))
return nil, err.Error()
}
return result, ""
}
// GetTableStructure returns the structure of a table
func (a *App) GetTableStructure(connectionID, tableName string) (*models.TableStructure, string) {
if a.connectionSvc == nil {
return nil, "Service not initialized"
}
structure, err := a.connectionSvc.GetTableStructure(a.ctx, connectionID, tableName)
if err != nil {
config.GetLogger().Error("failed to get table structure",
zap.String("connection_id", connectionID),
zap.String("table", tableName),
zap.Error(err))
return nil, err.Error()
}
return structure, ""
}
// GetQueryHistory returns query history with pagination
func (a *App) GetQueryHistory(connectionID string, page, pageSize int) ([]models.QueryHistory, int64, string) {
if a.querySvc == nil {
return []models.QueryHistory{}, 0, "Service not initialized"
}
history, total, err := a.querySvc.GetQueryHistory(a.ctx, connectionID, page, pageSize)
if err != nil {
config.GetLogger().Error("failed to get query history", zap.Error(err))
return []models.QueryHistory{}, 0, err.Error()
}
return history, total, ""
}
// GetSavedQueries returns all saved queries
func (a *App) GetSavedQueries(connectionID string) ([]models.SavedQuery, string) {
if a.querySvc == nil {
return []models.SavedQuery{}, "Service not initialized"
}
queries, err := a.querySvc.GetSavedQueries(a.ctx, connectionID)
if err != nil {
config.GetLogger().Error("failed to get saved queries", zap.Error(err))
return []models.SavedQuery{}, err.Error()
}
return queries, ""
}
// CreateSavedQuery creates a new saved query
func (a *App) CreateSavedQuery(req models.CreateSavedQueryRequest) (*models.SavedQuery, string) {
if a.querySvc == nil {
return nil, "Service not initialized"
}
query, err := a.querySvc.CreateSavedQuery(a.ctx, &req)
if err != nil {
config.GetLogger().Error("failed to create saved query", zap.Error(err))
return nil, err.Error()
}
return query, ""
}
// UpdateSavedQuery updates a saved query
func (a *App) UpdateSavedQuery(id uint, req models.UpdateSavedQueryRequest) (*models.SavedQuery, string) {
if a.querySvc == nil {
return nil, "Service not initialized"
}
query, err := a.querySvc.UpdateSavedQuery(a.ctx, id, &req)
if err != nil {
config.GetLogger().Error("failed to update saved query", zap.Error(err))
return nil, err.Error()
}
return query, ""
}
// DeleteSavedQuery deletes a saved query
func (a *App) DeleteSavedQuery(id uint) string {
if a.querySvc == nil {
return "Service not initialized"
}
err := a.querySvc.DeleteSavedQuery(a.ctx, id)
if err != nil {
config.GetLogger().Error("failed to delete saved query", zap.Error(err))
return err.Error()
}
return ""
}
// StartHTTPServer starts the HTTP API server in background
func (a *App) StartHTTPServer() string {
if a.httpServer == nil {
return "HTTP server not initialized"
}
go func() {
port := a.config.API.Port
if err := a.httpServer.Start(port); err != nil {
config.GetLogger().Error("HTTP server error", zap.Error(err))
}
}()
return ""
}
// Shutdown gracefully shuts down the application
func (a *App) Shutdown() {
config.GetLogger().Info("shutting down application")
if a.shutdownFunc != nil {
a.shutdownFunc()
}
if a.httpServer != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
a.httpServer.Shutdown(ctx)
}
// Close all database connections
database.CloseSQLite()
config.Sync()
}

261
internal/config/config.go Normal file
View File

@@ -0,0 +1,261 @@
package config
import (
"encoding/json"
"os"
"path/filepath"
"sync"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
// Environment represents the application environment type
type Environment string
const (
// EnvDevelopment represents development environment
EnvDevelopment Environment = "development"
// EnvProduction represents production environment
EnvProduction Environment = "production"
)
// Config holds all configuration for the application
type Config struct {
// App settings
AppName string `json:"app_name"`
Version string `json:"version"`
Environment Environment `json:"environment"`
// Database settings (SQLite for app data)
Database DatabaseConfig `json:"database"`
// Encryption settings
Encryption EncryptionConfig `json:"encryption"`
// Logger settings
Logger LoggerConfig `json:"logger"`
// API settings (for debug HTTP server)
API APIConfig `json:"api"`
// File paths
DataDir string `json:"-"`
}
// DatabaseConfig holds database configuration
type DatabaseConfig struct {
// SQLite database file path for app data
SQLitePath string `json:"sqlite_path"`
// Max open connections
MaxOpenConns int `json:"max_open_conns"`
// Max idle connections
MaxIdleConns int `json:"max_idle_conns"`
// Connection max lifetime in minutes
MaxLifetime int `json:"max_lifetime"`
}
// EncryptionConfig holds encryption configuration
type EncryptionConfig struct {
// Key for encrypting sensitive data (passwords, etc.)
// In production, this should be loaded from secure storage
Key string `json:"-"`
// KeyFile path to load encryption key from
KeyFile string `json:"key_file"`
}
// LoggerConfig holds logger configuration
type LoggerConfig struct {
// Log level: debug, info, warn, error
Level string `json:"level"`
// Log format: json, console
Format string `json:"format"`
// Output file path (empty for stdout)
OutputPath string `json:"output_path"`
}
// APIConfig holds HTTP API configuration
type APIConfig struct {
// Enable HTTP API server (for debugging)
Enabled bool `json:"enabled"`
// Port for HTTP API server
Port string `json:"port"`
}
var (
instance *Config
once sync.Once
logger *zap.Logger
)
// Get returns the singleton config instance
func Get() *Config {
return instance
}
// GetLogger returns the zap logger
func GetLogger() *zap.Logger {
return logger
}
// Init initializes the configuration
// If config file doesn't exist, creates default config
func Init(dataDir string) (*Config, error) {
var err error
once.Do(func() {
instance = &Config{
DataDir: dataDir,
}
err = instance.load(dataDir)
})
return instance, err
}
// load loads configuration from file or creates default
func (c *Config) load(dataDir string) error {
configPath := filepath.Join(dataDir, "config.json")
// Try to load existing config
if _, err := os.Stat(configPath); err == nil {
data, err := os.ReadFile(configPath)
if err != nil {
return err
}
if err := json.Unmarshal(data, c); err != nil {
return err
}
} else {
// Create default config
c.setDefaults()
if err := c.save(configPath); err != nil {
return err
}
}
// Override with environment variables
c.loadEnv()
// Initialize logger
if err := c.initLogger(); err != nil {
return err
}
logger.Info("configuration loaded",
zap.String("environment", string(c.Environment)),
zap.String("data_dir", c.DataDir),
)
return nil
}
// setDefaults sets default configuration values
func (c *Config) setDefaults() {
c.AppName = "uzdb"
c.Version = "1.0.0"
c.Environment = EnvDevelopment
c.Database = DatabaseConfig{
SQLitePath: filepath.Join(c.DataDir, "uzdb.db"),
MaxOpenConns: 25,
MaxIdleConns: 5,
MaxLifetime: 5,
}
c.Encryption = EncryptionConfig{
Key: "", // Will be generated if empty
KeyFile: filepath.Join(c.DataDir, "encryption.key"),
}
c.Logger = LoggerConfig{
Level: "debug",
Format: "console",
OutputPath: "",
}
c.API = APIConfig{
Enabled: true,
Port: "8080",
}
}
// loadEnv loads configuration from environment variables
func (c *Config) loadEnv() {
if env := os.Getenv("UZDB_ENV"); env != "" {
c.Environment = Environment(env)
}
if port := os.Getenv("UZDB_API_PORT"); port != "" {
c.API.Port = port
}
if logLevel := os.Getenv("UZDB_LOG_LEVEL"); logLevel != "" {
c.Logger.Level = logLevel
}
if dbPath := os.Getenv("UZDB_DB_PATH"); dbPath != "" {
c.Database.SQLitePath = dbPath
}
}
// save saves configuration to file
func (c *Config) save(path string) error {
data, err := json.MarshalIndent(c, "", " ")
if err != nil {
return err
}
return os.WriteFile(path, data, 0600)
}
// initLogger initializes the zap logger
func (c *Config) initLogger() error {
var cfg zap.Config
switch c.Logger.Format {
case "json":
cfg = zap.NewProductionConfig()
default:
cfg = zap.NewDevelopmentConfig()
}
// Set log level
level, parseErr := zapcore.ParseLevel(c.Logger.Level)
if parseErr != nil {
level = zapcore.InfoLevel
}
cfg.Level.SetLevel(level)
// Configure output
if c.Logger.OutputPath != "" {
cfg.OutputPaths = []string{c.Logger.OutputPath}
cfg.ErrorOutputPaths = []string{c.Logger.OutputPath}
}
var buildErr error
logger, buildErr = cfg.Build(
zap.AddCaller(),
zap.AddStacktrace(zapcore.ErrorLevel),
)
if buildErr != nil {
return buildErr
}
return nil
}
// IsDevelopment returns true if running in development mode
func (c *Config) IsDevelopment() bool {
return c.Environment == EnvDevelopment
}
// IsProduction returns true if running in production mode
func (c *Config) IsProduction() bool {
return c.Environment == EnvProduction
}
// Sync flushes any buffered log entries
func Sync() error {
if logger != nil {
return logger.Sync()
}
return nil
}

View File

@@ -0,0 +1,137 @@
package database
import (
"database/sql"
"fmt"
"sync"
"go.uber.org/zap"
"uzdb/internal/config"
"uzdb/internal/models"
)
// ConnectionManager manages database connections for different database types
type ConnectionManager struct {
connections map[string]DatabaseConnector
mu sync.RWMutex
}
// DatabaseConnector is the interface for all database connections
type DatabaseConnector interface {
GetDB() *sql.DB
Close() error
IsConnected() bool
ExecuteQuery(sql string, args ...interface{}) (*models.QueryResult, error)
ExecuteStatement(sql string, args ...interface{}) (*models.QueryResult, error)
GetTables(schema string) ([]models.Table, error)
GetTableStructure(tableName string) (*models.TableStructure, error)
GetMetadata() (*models.DBMetadata, error)
}
// NewConnectionManager creates a new connection manager
func NewConnectionManager() *ConnectionManager {
return &ConnectionManager{
connections: make(map[string]DatabaseConnector),
}
}
// GetConnection gets or creates a connection for a user connection config
func (m *ConnectionManager) GetConnection(conn *models.UserConnection, password string) (DatabaseConnector, error) {
m.mu.Lock()
defer m.mu.Unlock()
// Check if connection already exists
if existing, ok := m.connections[conn.ID]; ok {
if existing.IsConnected() {
return existing, nil
}
// Connection is dead, remove it
existing.Close()
delete(m.connections, conn.ID)
}
// Create new connection based on type
var connector DatabaseConnector
var err error
switch conn.Type {
case models.ConnectionTypeMySQL:
connector, err = NewMySQLConnection(conn, password)
case models.ConnectionTypePostgreSQL:
connector, err = NewPostgreSQLConnection(conn, password)
case models.ConnectionTypeSQLite:
connector, err = NewSQLiteConnection(conn)
default:
return nil, fmt.Errorf("unsupported database type: %s", conn.Type)
}
if err != nil {
return nil, err
}
m.connections[conn.ID] = connector
config.GetLogger().Info("connection created in manager",
zap.String("connection_id", conn.ID),
zap.String("type", string(conn.Type)),
)
return connector, nil
}
// RemoveConnection removes a connection from the manager
func (m *ConnectionManager) RemoveConnection(connectionID string) error {
m.mu.Lock()
defer m.mu.Unlock()
if conn, ok := m.connections[connectionID]; ok {
if err := conn.Close(); err != nil {
return err
}
delete(m.connections, connectionID)
config.GetLogger().Info("connection removed from manager",
zap.String("connection_id", connectionID),
)
}
return nil
}
// GetAllConnections returns all managed connections
func (m *ConnectionManager) GetAllConnections() map[string]DatabaseConnector {
m.mu.RLock()
defer m.mu.RUnlock()
result := make(map[string]DatabaseConnector)
for k, v := range m.connections {
result[k] = v
}
return result
}
// CloseAll closes all managed connections
func (m *ConnectionManager) CloseAll() error {
m.mu.Lock()
defer m.mu.Unlock()
var lastErr error
for id, conn := range m.connections {
if err := conn.Close(); err != nil {
lastErr = err
config.GetLogger().Error("failed to close connection",
zap.String("connection_id", id),
zap.Error(err),
)
}
}
m.connections = make(map[string]DatabaseConnector)
if lastErr != nil {
return lastErr
}
config.GetLogger().Info("all connections closed")
return nil
}

414
internal/database/mysql.go Normal file
View File

@@ -0,0 +1,414 @@
package database
import (
"database/sql"
"fmt"
"strings"
"sync"
"time"
_ "github.com/go-sql-driver/mysql"
"go.uber.org/zap"
"uzdb/internal/config"
"uzdb/internal/models"
)
// MySQLConnection represents a MySQL connection
type MySQLConnection struct {
db *sql.DB
dsn string
host string
port int
database string
username string
mu sync.RWMutex
}
// BuildMySQLDSN builds MySQL DSN connection string
func BuildMySQLDSN(conn *models.UserConnection, password string) string {
// MySQL DSN format: user:pass@tcp(host:port)/dbname?params
params := []string{
"parseTime=true",
"loc=Local",
fmt.Sprintf("timeout=%ds", conn.Timeout),
"charset=utf8mb4",
"collation=utf8mb4_unicode_ci",
}
if conn.SSLMode != "" && conn.SSLMode != "disable" {
switch conn.SSLMode {
case "require":
params = append(params, "tls=required")
case "verify-ca":
params = append(params, "tls=skip-verify")
case "verify-full":
params = append(params, "tls=skip-verify")
}
}
queryString := strings.Join(params, "&")
return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?%s",
conn.Username,
password,
conn.Host,
conn.Port,
conn.Database,
queryString,
)
}
// NewMySQLConnection creates a new MySQL connection
func NewMySQLConnection(conn *models.UserConnection, password string) (*MySQLConnection, error) {
dsn := BuildMySQLDSN(conn, password)
db, err := sql.Open("mysql", dsn)
if err != nil {
return nil, fmt.Errorf("failed to open MySQL connection: %w", err)
}
// Configure connection pool
db.SetMaxOpenConns(25)
db.SetMaxIdleConns(5)
db.SetConnMaxLifetime(5 * time.Minute)
// Test connection
if err := db.Ping(); err != nil {
db.Close()
return nil, fmt.Errorf("failed to ping MySQL: %w", err)
}
mysqlConn := &MySQLConnection{
db: db,
dsn: dsn,
host: conn.Host,
port: conn.Port,
database: conn.Database,
username: conn.Username,
}
config.GetLogger().Info("MySQL connection established",
zap.String("host", conn.Host),
zap.Int("port", conn.Port),
zap.String("database", conn.Database),
)
return mysqlConn, nil
}
// GetDB returns the underlying sql.DB
func (m *MySQLConnection) GetDB() *sql.DB {
m.mu.RLock()
defer m.mu.RUnlock()
return m.db
}
// Close closes the MySQL connection
func (m *MySQLConnection) Close() error {
m.mu.Lock()
defer m.mu.Unlock()
if m.db != nil {
if err := m.db.Close(); err != nil {
return fmt.Errorf("failed to close MySQL connection: %w", err)
}
m.db = nil
config.GetLogger().Info("MySQL connection closed",
zap.String("host", m.host),
zap.Int("port", m.port),
)
}
return nil
}
// IsConnected checks if the connection is alive
func (m *MySQLConnection) IsConnected() bool {
m.mu.RLock()
defer m.mu.RUnlock()
if m.db == nil {
return false
}
err := m.db.Ping()
return err == nil
}
// ExecuteQuery executes a SQL query and returns results
func (m *MySQLConnection) ExecuteQuery(sql string, args ...interface{}) (*models.QueryResult, error) {
startTime := time.Now()
db := m.GetDB()
if db == nil {
return nil, fmt.Errorf("connection is closed")
}
rows, err := db.Query(sql, args...)
if err != nil {
return nil, fmt.Errorf("query execution failed: %w", err)
}
defer rows.Close()
result := &models.QueryResult{
Success: true,
Duration: time.Since(startTime).Milliseconds(),
}
// Get column names
columns, err := rows.Columns()
if err != nil {
return nil, fmt.Errorf("failed to get columns: %w", err)
}
result.Columns = columns
// Scan rows
for rows.Next() {
values := make([]interface{}, len(columns))
valuePtrs := make([]interface{}, len(columns))
for i := range values {
valuePtrs[i] = &values[i]
}
if err := rows.Scan(valuePtrs...); err != nil {
return nil, fmt.Errorf("failed to scan row: %w", err)
}
row := make([]interface{}, len(columns))
for i, v := range values {
row[i] = convertValue(v)
}
result.Rows = append(result.Rows, row)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("row iteration error: %w", err)
}
result.RowCount = int64(len(result.Rows))
return result, nil
}
// ExecuteStatement executes a SQL statement (INSERT, UPDATE, DELETE, etc.)
func (m *MySQLConnection) ExecuteStatement(sql string, args ...interface{}) (*models.QueryResult, error) {
startTime := time.Now()
db := m.GetDB()
if db == nil {
return nil, fmt.Errorf("connection is closed")
}
res, err := db.Exec(sql, args...)
if err != nil {
return nil, fmt.Errorf("statement execution failed: %w", err)
}
rowsAffected, _ := res.RowsAffected()
lastInsertID, _ := res.LastInsertId()
result := &models.QueryResult{
Success: true,
AffectedRows: rowsAffected,
Duration: time.Since(startTime).Milliseconds(),
Rows: [][]interface{}{{lastInsertID}},
Columns: []string{"last_insert_id"},
}
return result, nil
}
// GetTables returns all tables in the database
func (m *MySQLConnection) GetTables(schema string) ([]models.Table, error) {
db := m.GetDB()
if db == nil {
return nil, fmt.Errorf("connection is closed")
}
query := `
SELECT
TABLE_NAME as table_name,
TABLE_TYPE as table_type,
TABLE_ROWS as row_count
FROM information_schema.TABLES
WHERE TABLE_SCHEMA = ?
ORDER BY TABLE_NAME
`
rows, err := db.Query(query, m.database)
if err != nil {
return nil, fmt.Errorf("failed to query tables: %w", err)
}
defer rows.Close()
var tables []models.Table
for rows.Next() {
var t models.Table
var rowCount sql.NullInt64
if err := rows.Scan(&t.Name, &t.Type, &rowCount); err != nil {
return nil, fmt.Errorf("failed to scan table: %w", err)
}
if rowCount.Valid {
t.RowCount = rowCount.Int64
}
tables = append(tables, t)
}
return tables, nil
}
// GetTableStructure returns the structure of a table
func (m *MySQLConnection) GetTableStructure(tableName string) (*models.TableStructure, error) {
db := m.GetDB()
if db == nil {
return nil, fmt.Errorf("connection is closed")
}
structure := &models.TableStructure{
TableName: tableName,
}
// Get columns
query := `
SELECT
COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_DEFAULT,
COLUMN_KEY, EXTRA, CHARACTER_MAXIMUM_LENGTH, NUMERIC_PRECISION,
NUMERIC_SCALE, COLUMN_COMMENT
FROM information_schema.COLUMNS
WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?
ORDER BY ORDINAL_POSITION
`
rows, err := db.Query(query, m.database, tableName)
if err != nil {
return nil, fmt.Errorf("failed to query columns: %w", err)
}
defer rows.Close()
for rows.Next() {
var col models.TableColumn
var nullable, columnKey, extra string
var defaultVal, comment sql.NullString
var maxLength, precision, scale sql.NullInt32
if err := rows.Scan(
&col.Name, &col.DataType, &nullable, &defaultVal,
&columnKey, &extra, &maxLength, &precision,
&scale, &comment,
); err != nil {
return nil, fmt.Errorf("failed to scan column: %w", err)
}
col.Nullable = nullable == "YES"
col.IsPrimary = columnKey == "PRI"
col.IsUnique = columnKey == "UNI"
col.AutoIncrement = strings.Contains(extra, "auto_increment")
if defaultVal.Valid {
col.Default = defaultVal.String
}
if maxLength.Valid {
col.Length = int(maxLength.Int32)
}
if precision.Valid {
col.Scale = int(precision.Int32)
}
if scale.Valid {
col.Scale = int(scale.Int32)
}
if comment.Valid {
col.Comment = comment.String
}
structure.Columns = append(structure.Columns, col)
}
// Get indexes
indexQuery := `
SELECT INDEX_NAME, COLUMN_NAME, NON_UNIQUE, SEQ_IN_INDEX
FROM information_schema.STATISTICS
WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?
ORDER BY INDEX_NAME, SEQ_IN_INDEX
`
idxRows, err := db.Query(indexQuery, m.database, tableName)
if err != nil {
config.GetLogger().Warn("failed to query indexes", zap.Error(err))
} else {
defer idxRows.Close()
indexMap := make(map[string]*models.TableIndex)
for idxRows.Next() {
var indexName, columnName string
var nonUnique bool
var seqInIndex int
if err := idxRows.Scan(&indexName, &columnName, &nonUnique, &seqInIndex); err != nil {
continue
}
idx, exists := indexMap[indexName]
if !exists {
idx = &models.TableIndex{
Name: indexName,
IsUnique: !nonUnique,
IsPrimary: indexName == "PRIMARY",
Columns: []string{},
}
indexMap[indexName] = idx
}
idx.Columns = append(idx.Columns, columnName)
}
for _, idx := range indexMap {
structure.Indexes = append(structure.Indexes, *idx)
}
}
return structure, nil
}
// GetMetadata returns database metadata
func (m *MySQLConnection) GetMetadata() (*models.DBMetadata, error) {
db := m.GetDB()
if db == nil {
return nil, fmt.Errorf("connection is closed")
}
metadata := &models.DBMetadata{
Database: m.database,
User: m.username,
Host: m.host,
Port: m.port,
}
// Get MySQL version
var version string
err := db.QueryRow("SELECT VERSION()").Scan(&version)
if err == nil {
metadata.Version = version
}
// Get server time
var serverTime string
err = db.QueryRow("SELECT NOW()").Scan(&serverTime)
if err == nil {
metadata.ServerTime = serverTime
}
return metadata, nil
}
// convertValue converts database value to interface{}
func convertValue(val interface{}) interface{} {
if val == nil {
return nil
}
// Handle []byte (common from MySQL driver)
if b, ok := val.([]byte); ok {
return string(b)
}
return val
}

View File

@@ -0,0 +1,449 @@
package database
import (
"database/sql"
"fmt"
"strings"
"sync"
"time"
_ "github.com/lib/pq"
"go.uber.org/zap"
"uzdb/internal/config"
"uzdb/internal/models"
)
// PostgreSQLConnection represents a PostgreSQL connection
type PostgreSQLConnection struct {
db *sql.DB
dsn string
host string
port int
database string
username string
mu sync.RWMutex
}
// BuildPostgreSQLDSN builds PostgreSQL DSN connection string
func BuildPostgreSQLDSN(conn *models.UserConnection, password string) string {
// PostgreSQL DSN format: postgres://user:pass@host:port/dbname?params
params := []string{
fmt.Sprintf("connect_timeout=%d", conn.Timeout),
"sslmode=disable",
}
if conn.SSLMode != "" {
switch conn.SSLMode {
case "require":
params = append(params, "sslmode=require")
case "verify-ca":
params = append(params, "sslmode=verify-ca")
case "verify-full":
params = append(params, "sslmode=verify-full")
case "disable":
params = append(params, "sslmode=disable")
default:
params = append(params, fmt.Sprintf("sslmode=%s", conn.SSLMode))
}
}
return fmt.Sprintf("postgres://%s:%s@%s:%d/%s?%s",
conn.Username,
password,
conn.Host,
conn.Port,
conn.Database,
strings.Join(params, "&"),
)
}
// NewPostgreSQLConnection creates a new PostgreSQL connection
func NewPostgreSQLConnection(conn *models.UserConnection, password string) (*PostgreSQLConnection, error) {
dsn := BuildPostgreSQLDSN(conn, password)
db, err := sql.Open("postgres", dsn)
if err != nil {
return nil, fmt.Errorf("failed to open PostgreSQL connection: %w", err)
}
// Configure connection pool
db.SetMaxOpenConns(25)
db.SetMaxIdleConns(5)
db.SetConnMaxLifetime(5 * time.Minute)
// Test connection
if err := db.Ping(); err != nil {
db.Close()
return nil, fmt.Errorf("failed to ping PostgreSQL: %w", err)
}
pgConn := &PostgreSQLConnection{
db: db,
dsn: dsn,
host: conn.Host,
port: conn.Port,
database: conn.Database,
username: conn.Username,
}
config.GetLogger().Info("PostgreSQL connection established",
zap.String("host", conn.Host),
zap.Int("port", conn.Port),
zap.String("database", conn.Database),
)
return pgConn, nil
}
// GetDB returns the underlying sql.DB
func (p *PostgreSQLConnection) GetDB() *sql.DB {
p.mu.RLock()
defer p.mu.RUnlock()
return p.db
}
// Close closes the PostgreSQL connection
func (p *PostgreSQLConnection) Close() error {
p.mu.Lock()
defer p.mu.Unlock()
if p.db != nil {
if err := p.db.Close(); err != nil {
return fmt.Errorf("failed to close PostgreSQL connection: %w", err)
}
p.db = nil
config.GetLogger().Info("PostgreSQL connection closed",
zap.String("host", p.host),
zap.Int("port", p.port),
)
}
return nil
}
// IsConnected checks if the connection is alive
func (p *PostgreSQLConnection) IsConnected() bool {
p.mu.RLock()
defer p.mu.RUnlock()
if p.db == nil {
return false
}
err := p.db.Ping()
return err == nil
}
// ExecuteQuery executes a SQL query and returns results
func (p *PostgreSQLConnection) ExecuteQuery(sql string, args ...interface{}) (*models.QueryResult, error) {
startTime := time.Now()
db := p.GetDB()
if db == nil {
return nil, fmt.Errorf("connection is closed")
}
rows, err := db.Query(sql, args...)
if err != nil {
return nil, fmt.Errorf("query execution failed: %w", err)
}
defer rows.Close()
result := &models.QueryResult{
Success: true,
Duration: time.Since(startTime).Milliseconds(),
}
// Get column names
columns, err := rows.Columns()
if err != nil {
return nil, fmt.Errorf("failed to get columns: %w", err)
}
result.Columns = columns
// Scan rows
for rows.Next() {
values := make([]interface{}, len(columns))
valuePtrs := make([]interface{}, len(columns))
for i := range values {
valuePtrs[i] = &values[i]
}
if err := rows.Scan(valuePtrs...); err != nil {
return nil, fmt.Errorf("failed to scan row: %w", err)
}
row := make([]interface{}, len(columns))
for i, v := range values {
row[i] = convertValue(v)
}
result.Rows = append(result.Rows, row)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("row iteration error: %w", err)
}
result.RowCount = int64(len(result.Rows))
return result, nil
}
// ExecuteStatement executes a SQL statement (INSERT, UPDATE, DELETE, etc.)
func (p *PostgreSQLConnection) ExecuteStatement(sql string, args ...interface{}) (*models.QueryResult, error) {
startTime := time.Now()
db := p.GetDB()
if db == nil {
return nil, fmt.Errorf("connection is closed")
}
res, err := db.Exec(sql, args...)
if err != nil {
return nil, fmt.Errorf("statement execution failed: %w", err)
}
rowsAffected, _ := res.RowsAffected()
result := &models.QueryResult{
Success: true,
AffectedRows: rowsAffected,
Duration: time.Since(startTime).Milliseconds(),
}
return result, nil
}
// GetTables returns all tables in the database
func (p *PostgreSQLConnection) GetTables(schema string) ([]models.Table, error) {
db := p.GetDB()
if db == nil {
return nil, fmt.Errorf("connection is closed")
}
if schema == "" {
schema = "public"
}
query := `
SELECT
table_name,
table_type,
(SELECT COUNT(*) FROM information_schema.columns c
WHERE c.table_name = t.table_name AND c.table_schema = t.table_schema) as column_count
FROM information_schema.tables t
WHERE table_schema = $1
ORDER BY table_name
`
rows, err := db.Query(query, schema)
if err != nil {
return nil, fmt.Errorf("failed to query tables: %w", err)
}
defer rows.Close()
var tables []models.Table
for rows.Next() {
var t models.Table
var columnCount int
t.Schema = schema
if err := rows.Scan(&t.Name, &t.Type, &columnCount); err != nil {
return nil, fmt.Errorf("failed to scan table: %w", err)
}
tables = append(tables, t)
}
return tables, nil
}
// GetTableStructure returns the structure of a table
func (p *PostgreSQLConnection) GetTableStructure(tableName string) (*models.TableStructure, error) {
db := p.GetDB()
if db == nil {
return nil, fmt.Errorf("connection is closed")
}
structure := &models.TableStructure{
TableName: tableName,
Schema: "public",
}
// Get columns
query := `
SELECT
c.column_name,
c.data_type,
c.is_nullable,
c.column_default,
c.character_maximum_length,
c.numeric_precision,
c.numeric_scale,
pg_catalog.col_description(c.oid, c.ordinal_position) as comment
FROM information_schema.columns c
JOIN pg_class cl ON cl.relname = c.table_name
WHERE c.table_schema = 'public' AND c.table_name = $1
ORDER BY c.ordinal_position
`
rows, err := db.Query(query, tableName)
if err != nil {
return nil, fmt.Errorf("failed to query columns: %w", err)
}
defer rows.Close()
for rows.Next() {
var col models.TableColumn
var nullable string
var defaultVal, comment sql.NullString
var maxLength, precision, scale sql.NullInt32
if err := rows.Scan(
&col.Name, &col.DataType, &nullable, &defaultVal,
&maxLength, &precision, &scale, &comment,
); err != nil {
return nil, fmt.Errorf("failed to scan column: %w", err)
}
col.Nullable = nullable == "YES"
if defaultVal.Valid {
col.Default = defaultVal.String
}
if maxLength.Valid && maxLength.Int32 > 0 {
col.Length = int(maxLength.Int32)
}
if precision.Valid {
col.Scale = int(precision.Int32)
}
if scale.Valid {
col.Scale = int(scale.Int32)
}
if comment.Valid {
col.Comment = comment.String
}
structure.Columns = append(structure.Columns, col)
}
// Get primary key information
pkQuery := `
SELECT a.attname
FROM pg_index i
JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
WHERE i.indrelid = $1::regclass AND i.indisprimary
`
pkRows, err := db.Query(pkQuery, tableName)
if err != nil {
config.GetLogger().Warn("failed to query primary keys", zap.Error(err))
} else {
defer pkRows.Close()
pkColumns := make(map[string]bool)
for pkRows.Next() {
var colName string
if err := pkRows.Scan(&colName); err != nil {
continue
}
pkColumns[colName] = true
}
// Mark primary key columns
for i := range structure.Columns {
if pkColumns[structure.Columns[i].Name] {
structure.Columns[i].IsPrimary = true
}
}
}
// Get indexes
indexQuery := `
SELECT
i.indexname,
i.indexdef,
i.indisunique
FROM pg_indexes i
WHERE i.schemaname = 'public' AND i.tablename = $1
`
idxRows, err := db.Query(indexQuery, tableName)
if err != nil {
config.GetLogger().Warn("failed to query indexes", zap.Error(err))
} else {
defer idxRows.Close()
for idxRows.Next() {
var idx models.TableIndex
var indexDef string
if err := idxRows.Scan(&idx.Name, &indexDef, &idx.IsUnique); err != nil {
continue
}
idx.IsPrimary = idx.Name == (tableName + "_pkey")
// Extract column names from index definition
idx.Columns = extractColumnsFromIndexDef(indexDef)
idx.Type = "btree" // Default for PostgreSQL
structure.Indexes = append(structure.Indexes, idx)
}
}
return structure, nil
}
// GetMetadata returns database metadata
func (p *PostgreSQLConnection) GetMetadata() (*models.DBMetadata, error) {
db := p.GetDB()
if db == nil {
return nil, fmt.Errorf("connection is closed")
}
metadata := &models.DBMetadata{
Database: p.database,
User: p.username,
Host: p.host,
Port: p.port,
}
// Get PostgreSQL version
var version string
err := db.QueryRow("SELECT version()").Scan(&version)
if err == nil {
metadata.Version = version
}
// Get server time
var serverTime string
err = db.QueryRow("SELECT NOW()").Scan(&serverTime)
if err == nil {
metadata.ServerTime = serverTime
}
return metadata, nil
}
// extractColumnsFromIndexDef extracts column names from PostgreSQL index definition
func extractColumnsFromIndexDef(indexDef string) []string {
var columns []string
// Simple extraction - look for content between parentheses
start := strings.Index(indexDef, "(")
end := strings.LastIndex(indexDef, ")")
if start != -1 && end != -1 && end > start {
content := indexDef[start+1 : end]
parts := strings.Split(content, ",")
for _, part := range parts {
col := strings.TrimSpace(part)
// Remove any expressions like "LOWER(column)"
if parenIdx := strings.Index(col, "("); parenIdx != -1 {
continue
}
if col != "" {
columns = append(columns, col)
}
}
}
return columns
}

128
internal/database/sqlite.go Normal file
View File

@@ -0,0 +1,128 @@
package database
import (
"fmt"
"os"
"path/filepath"
"sync"
"time"
"go.uber.org/zap"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"uzdb/internal/config"
"uzdb/internal/models"
)
var (
sqliteDB *gorm.DB
sqliteMu sync.RWMutex
)
// InitSQLite initializes the SQLite database for application data
func InitSQLite(dbPath string, cfg *config.DatabaseConfig) (*gorm.DB, error) {
sqliteMu.Lock()
defer sqliteMu.Unlock()
// Ensure directory exists
dir := filepath.Dir(dbPath)
if err := os.MkdirAll(dir, 0755); err != nil {
return nil, fmt.Errorf("failed to create database directory: %w", err)
}
// Configure GORM logger
var gormLogger logger.Interface
if config.Get().IsDevelopment() {
gormLogger = logger.Default.LogMode(logger.Info)
} else {
gormLogger = logger.Default.LogMode(logger.Silent)
}
// Open SQLite connection
db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{
Logger: gormLogger,
})
if err != nil {
return nil, fmt.Errorf("failed to open SQLite database: %w", err)
}
// Set connection pool settings
sqlDB, err := db.DB()
if err != nil {
return nil, fmt.Errorf("failed to get underlying DB: %w", err)
}
sqlDB.SetMaxOpenConns(cfg.MaxOpenConns)
sqlDB.SetMaxIdleConns(cfg.MaxIdleConns)
sqlDB.SetConnMaxLifetime(time.Duration(cfg.MaxLifetime) * time.Minute)
// Enable WAL mode for better concurrency
if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil {
config.GetLogger().Warn("failed to enable WAL mode", zap.Error(err))
}
// Run migrations
if err := runMigrations(db); err != nil {
return nil, fmt.Errorf("migration failed: %w", err)
}
sqliteDB = db
config.GetLogger().Info("SQLite database initialized",
zap.String("path", dbPath),
zap.Int("max_open_conns", cfg.MaxOpenConns),
zap.Int("max_idle_conns", cfg.MaxIdleConns),
)
return db, nil
}
// runMigrations runs database migrations
func runMigrations(db *gorm.DB) error {
migrations := []interface{}{
&models.UserConnection{},
&models.QueryHistory{},
&models.SavedQuery{},
}
for _, model := range migrations {
if err := db.AutoMigrate(model); err != nil {
return fmt.Errorf("failed to migrate %T: %w", model, err)
}
}
config.GetLogger().Info("database migrations completed")
return nil
}
// GetSQLiteDB returns the SQLite database instance
func GetSQLiteDB() *gorm.DB {
sqliteMu.RLock()
defer sqliteMu.RUnlock()
return sqliteDB
}
// CloseSQLite closes the SQLite database connection
func CloseSQLite() error {
sqliteMu.Lock()
defer sqliteMu.Unlock()
if sqliteDB == nil {
return nil
}
sqlDB, err := sqliteDB.DB()
if err != nil {
return err
}
if err := sqlDB.Close(); err != nil {
return fmt.Errorf("failed to close SQLite database: %w", err)
}
sqliteDB = nil
config.GetLogger().Info("SQLite database closed")
return nil
}

View File

@@ -0,0 +1,369 @@
package database
import (
"database/sql"
"fmt"
"sync"
"time"
_ "modernc.org/sqlite"
"go.uber.org/zap"
"uzdb/internal/config"
"uzdb/internal/models"
)
// SQLiteConnection represents a SQLite connection
type SQLiteConnection struct {
db *sql.DB
filePath string
mu sync.RWMutex
}
// NewSQLiteConnection creates a new SQLite connection
func NewSQLiteConnection(conn *models.UserConnection) (*SQLiteConnection, error) {
filePath := conn.Database
// Ensure file exists for new connections
db, err := sql.Open("sqlite", filePath)
if err != nil {
return nil, fmt.Errorf("failed to open SQLite connection: %w", err)
}
// Configure connection pool
db.SetMaxOpenConns(1) // SQLite only supports one writer
db.SetMaxIdleConns(1)
db.SetConnMaxLifetime(5 * time.Minute)
// Enable WAL mode
if _, err := db.Exec("PRAGMA journal_mode=WAL"); err != nil {
config.GetLogger().Warn("failed to enable WAL mode", zap.Error(err))
}
// Test connection
if err := db.Ping(); err != nil {
db.Close()
return nil, fmt.Errorf("failed to ping SQLite: %w", err)
}
sqliteConn := &SQLiteConnection{
db: db,
filePath: filePath,
}
config.GetLogger().Info("SQLite connection established",
zap.String("path", filePath),
)
return sqliteConn, nil
}
// GetDB returns the underlying sql.DB
func (s *SQLiteConnection) GetDB() *sql.DB {
s.mu.RLock()
defer s.mu.RUnlock()
return s.db
}
// Close closes the SQLite connection
func (s *SQLiteConnection) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.db != nil {
if err := s.db.Close(); err != nil {
return fmt.Errorf("failed to close SQLite connection: %w", err)
}
s.db = nil
config.GetLogger().Info("SQLite connection closed",
zap.String("path", s.filePath),
)
}
return nil
}
// IsConnected checks if the connection is alive
func (s *SQLiteConnection) IsConnected() bool {
s.mu.RLock()
defer s.mu.RUnlock()
if s.db == nil {
return false
}
err := s.db.Ping()
return err == nil
}
// ExecuteQuery executes a SQL query and returns results
func (s *SQLiteConnection) ExecuteQuery(query string, args ...interface{}) (*models.QueryResult, error) {
startTime := time.Now()
db := s.GetDB()
if db == nil {
return nil, fmt.Errorf("connection is closed")
}
rows, err := db.Query(query, args...)
if err != nil {
return nil, fmt.Errorf("query execution failed: %w", err)
}
defer rows.Close()
result := &models.QueryResult{
Success: true,
Duration: time.Since(startTime).Milliseconds(),
}
// Get column names
columns, err := rows.Columns()
if err != nil {
return nil, fmt.Errorf("failed to get columns: %w", err)
}
result.Columns = columns
// Scan rows
for rows.Next() {
values := make([]interface{}, len(columns))
valuePtrs := make([]interface{}, len(columns))
for i := range values {
valuePtrs[i] = &values[i]
}
if err := rows.Scan(valuePtrs...); err != nil {
return nil, fmt.Errorf("failed to scan row: %w", err)
}
row := make([]interface{}, len(columns))
for i, v := range values {
row[i] = convertValue(v)
}
result.Rows = append(result.Rows, row)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("row iteration error: %w", err)
}
result.RowCount = int64(len(result.Rows))
return result, nil
}
// ExecuteStatement executes a SQL statement (INSERT, UPDATE, DELETE, etc.)
func (s *SQLiteConnection) ExecuteStatement(stmt string, args ...interface{}) (*models.QueryResult, error) {
startTime := time.Now()
db := s.GetDB()
if db == nil {
return nil, fmt.Errorf("connection is closed")
}
res, err := db.Exec(stmt, args...)
if err != nil {
return nil, fmt.Errorf("statement execution failed: %w", err)
}
rowsAffected, _ := res.RowsAffected()
lastInsertID, _ := res.LastInsertId()
result := &models.QueryResult{
Success: true,
AffectedRows: rowsAffected,
Duration: time.Since(startTime).Milliseconds(),
Rows: [][]interface{}{{lastInsertID}},
Columns: []string{"last_insert_id"},
}
return result, nil
}
// GetTables returns all tables in the database
func (s *SQLiteConnection) GetTables(schema string) ([]models.Table, error) {
db := s.GetDB()
if db == nil {
return nil, fmt.Errorf("connection is closed")
}
query := `
SELECT name, type
FROM sqlite_master
WHERE type IN ('table', 'view') AND name NOT LIKE 'sqlite_%'
ORDER BY name
`
rows, err := db.Query(query)
if err != nil {
return nil, fmt.Errorf("failed to query tables: %w", err)
}
defer rows.Close()
var tables []models.Table
for rows.Next() {
var t models.Table
if err := rows.Scan(&t.Name, &t.Type); err != nil {
return nil, fmt.Errorf("failed to scan table: %w", err)
}
// Get row count for tables
if t.Type == "table" {
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM \"%s\"", t.Name)
var rowCount int64
if err := db.QueryRow(countQuery).Scan(&rowCount); err == nil {
t.RowCount = rowCount
}
}
tables = append(tables, t)
}
return tables, nil
}
// GetTableStructure returns the structure of a table
func (s *SQLiteConnection) GetTableStructure(tableName string) (*models.TableStructure, error) {
db := s.GetDB()
if db == nil {
return nil, fmt.Errorf("connection is closed")
}
structure := &models.TableStructure{
TableName: tableName,
}
// Get table info using PRAGMA
query := fmt.Sprintf("PRAGMA table_info(\"%s\")", tableName)
rows, err := db.Query(query)
if err != nil {
return nil, fmt.Errorf("failed to query table info: %w", err)
}
defer rows.Close()
for rows.Next() {
var col models.TableColumn
var pk int
if err := rows.Scan(&col.Name, &col.DataType, &col.Default, &col.Nullable, &pk, &col.IsPrimary); err != nil {
return nil, fmt.Errorf("failed to scan column: %w", err)
}
col.IsPrimary = pk > 0
col.Nullable = col.Nullable || !col.IsPrimary
structure.Columns = append(structure.Columns, col)
}
// Get indexes
idxQuery := fmt.Sprintf("PRAGMA index_list(\"%s\")", tableName)
idxRows, err := db.Query(idxQuery)
if err != nil {
config.GetLogger().Warn("failed to query indexes", zap.Error(err))
} else {
defer idxRows.Close()
for idxRows.Next() {
var idx models.TableIndex
var origin string
if err := idxRows.Scan(&idx.Name, &idx.IsUnique, &origin); err != nil {
continue
}
idx.IsPrimary = idx.Name == "sqlite_autoindex_"+tableName+"_1"
idx.Type = origin
// Get index columns
colQuery := fmt.Sprintf("PRAGMA index_info(\"%s\")", idx.Name)
colRows, err := db.Query(colQuery)
if err == nil {
defer colRows.Close()
for colRows.Next() {
var seqno int
var colName string
if err := colRows.Scan(&seqno, &colName, &colName); err != nil {
continue
}
idx.Columns = append(idx.Columns, colName)
}
}
structure.Indexes = append(structure.Indexes, idx)
}
}
// Get foreign keys
fkQuery := fmt.Sprintf("PRAGMA foreign_key_list(\"%s\")", tableName)
fkRows, err := db.Query(fkQuery)
if err != nil {
config.GetLogger().Warn("failed to query foreign keys", zap.Error(err))
} else {
defer fkRows.Close()
fkMap := make(map[string]*models.ForeignKey)
for fkRows.Next() {
var fk models.ForeignKey
var id, seq int
if err := fkRows.Scan(&id, &seq, &fk.Name, &fk.ReferencedTable,
&fk.Columns, &fk.ReferencedColumns, &fk.OnUpdate, &fk.OnDelete, ""); err != nil {
continue
}
// Handle array fields properly
fk.Columns = []string{}
fk.ReferencedColumns = []string{}
var fromCol, toCol string
if err := fkRows.Scan(&id, &seq, &fk.Name, &fk.ReferencedTable,
&fromCol, &toCol, &fk.OnUpdate, &fk.OnDelete, ""); err != nil {
continue
}
existingFk, exists := fkMap[fk.Name]
if !exists {
existingFk = &models.ForeignKey{
Name: fk.Name,
ReferencedTable: fk.ReferencedTable,
OnDelete: fk.OnDelete,
OnUpdate: fk.OnUpdate,
Columns: []string{},
ReferencedColumns: []string{},
}
fkMap[fk.Name] = existingFk
}
existingFk.Columns = append(existingFk.Columns, fromCol)
existingFk.ReferencedColumns = append(existingFk.ReferencedColumns, toCol)
}
for _, fk := range fkMap {
structure.ForeignKeys = append(structure.ForeignKeys, *fk)
}
}
return structure, nil
}
// GetMetadata returns database metadata
func (s *SQLiteConnection) GetMetadata() (*models.DBMetadata, error) {
db := s.GetDB()
if db == nil {
return nil, fmt.Errorf("connection is closed")
}
metadata := &models.DBMetadata{
Database: s.filePath,
Version: "3", // SQLite version 3
}
// Get SQLite version
var version string
err := db.QueryRow("SELECT sqlite_version()").Scan(&version)
if err == nil {
metadata.Version = version
}
// Get current time
metadata.ServerTime = time.Now().Format(time.RFC3339)
return metadata, nil
}

View File

@@ -0,0 +1,195 @@
package handler
import (
"net/http"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"uzdb/internal/config"
"uzdb/internal/models"
"uzdb/internal/services"
"uzdb/internal/utils"
)
// ConnectionHandler handles connection-related HTTP requests
type ConnectionHandler struct {
connectionSvc *services.ConnectionService
}
// NewConnectionHandler creates a new connection handler
func NewConnectionHandler(connectionSvc *services.ConnectionService) *ConnectionHandler {
return &ConnectionHandler{
connectionSvc: connectionSvc,
}
}
// GetAllConnections handles GET /api/connections
func (h *ConnectionHandler) GetAllConnections(c *gin.Context) {
ctx := c.Request.Context()
connections, err := h.connectionSvc.GetAllConnections(ctx)
if err != nil {
utils.ErrorResponse(c, http.StatusInternalServerError, err, "Failed to get connections")
return
}
utils.SuccessResponse(c, gin.H{
"connections": connections,
})
}
// GetConnection handles GET /api/connections/:id
func (h *ConnectionHandler) GetConnection(c *gin.Context) {
ctx := c.Request.Context()
id := c.Param("id")
if id == "" {
utils.ErrorResponse(c, http.StatusBadRequest, models.ErrValidationFailed, "Connection ID is required")
return
}
conn, err := h.connectionSvc.GetConnectionByID(ctx, id)
if err != nil {
if err == models.ErrNotFound {
utils.ErrorResponse(c, http.StatusNotFound, err, "Connection not found")
return
}
utils.ErrorResponse(c, http.StatusInternalServerError, err, "Failed to get connection")
return
}
utils.SuccessResponse(c, gin.H{
"connection": conn,
})
}
// CreateConnection handles POST /api/connections
func (h *ConnectionHandler) CreateConnection(c *gin.Context) {
ctx := c.Request.Context()
var req models.CreateConnectionRequest
if err := c.ShouldBindJSON(&req); err != nil {
utils.ErrorResponse(c, http.StatusBadRequest, models.ErrValidationFailed, "Invalid request body")
return
}
conn, err := h.connectionSvc.CreateConnection(ctx, &req)
if err != nil {
if err == models.ErrValidationFailed {
utils.ErrorResponse(c, http.StatusBadRequest, err, "Validation failed")
return
}
utils.ErrorResponse(c, http.StatusInternalServerError, err, "Failed to create connection")
return
}
config.GetLogger().Info("connection created via API",
zap.String("id", conn.ID),
zap.String("name", conn.Name))
utils.CreatedResponse(c, gin.H{
"connection": conn,
})
}
// UpdateConnection handles PUT /api/connections/:id
func (h *ConnectionHandler) UpdateConnection(c *gin.Context) {
ctx := c.Request.Context()
id := c.Param("id")
if id == "" {
utils.ErrorResponse(c, http.StatusBadRequest, models.ErrValidationFailed, "Connection ID is required")
return
}
var req models.UpdateConnectionRequest
if err := c.ShouldBindJSON(&req); err != nil {
utils.ErrorResponse(c, http.StatusBadRequest, models.ErrValidationFailed, "Invalid request body")
return
}
conn, err := h.connectionSvc.UpdateConnection(ctx, id, &req)
if err != nil {
if err == models.ErrNotFound {
utils.ErrorResponse(c, http.StatusNotFound, err, "Connection not found")
return
}
utils.ErrorResponse(c, http.StatusInternalServerError, err, "Failed to update connection")
return
}
utils.SuccessResponse(c, gin.H{
"connection": conn,
})
}
// DeleteConnection handles DELETE /api/connections/:id
func (h *ConnectionHandler) DeleteConnection(c *gin.Context) {
ctx := c.Request.Context()
id := c.Param("id")
if id == "" {
utils.ErrorResponse(c, http.StatusBadRequest, models.ErrValidationFailed, "Connection ID is required")
return
}
if err := h.connectionSvc.DeleteConnection(ctx, id); err != nil {
if err == models.ErrNotFound {
utils.ErrorResponse(c, http.StatusNotFound, err, "Connection not found")
return
}
utils.ErrorResponse(c, http.StatusInternalServerError, err, "Failed to delete connection")
return
}
utils.SuccessResponse(c, gin.H{
"message": "Connection deleted successfully",
})
}
// TestConnection handles POST /api/connections/:id/test
func (h *ConnectionHandler) TestConnection(c *gin.Context) {
ctx := c.Request.Context()
id := c.Param("id")
if id == "" {
utils.ErrorResponse(c, http.StatusBadRequest, models.ErrValidationFailed, "Connection ID is required")
return
}
result, err := h.connectionSvc.TestConnection(ctx, id)
if err != nil {
if err == models.ErrNotFound {
utils.ErrorResponse(c, http.StatusNotFound, err, "Connection not found")
return
}
utils.ErrorResponse(c, http.StatusInternalServerError, err, "Failed to test connection")
return
}
statusCode := http.StatusOK
if !result.Success {
statusCode = http.StatusBadGateway
}
c.JSON(statusCode, gin.H{
"success": result.Success,
"message": result.Message,
"duration_ms": result.Duration,
"metadata": result.Metadata,
})
}
// RegisterRoutes registers connection routes
func (h *ConnectionHandler) RegisterRoutes(router *gin.RouterGroup) {
connections := router.Group("/connections")
{
connections.GET("", h.GetAllConnections)
connections.GET("/:id", h.GetConnection)
connections.POST("", h.CreateConnection)
connections.PUT("/:id", h.UpdateConnection)
connections.DELETE("/:id", h.DeleteConnection)
connections.POST("/:id/test", h.TestConnection)
}
}

287
internal/handler/query.go Normal file
View File

@@ -0,0 +1,287 @@
package handler
import (
"net/http"
"strconv"
"github.com/gin-gonic/gin"
"uzdb/internal/models"
"uzdb/internal/services"
"uzdb/internal/utils"
)
// QueryHandler handles query-related HTTP requests
type QueryHandler struct {
connectionSvc *services.ConnectionService
querySvc *services.QueryService
}
// NewQueryHandler creates a new query handler
func NewQueryHandler(
connectionSvc *services.ConnectionService,
querySvc *services.QueryService,
) *QueryHandler {
return &QueryHandler{
connectionSvc: connectionSvc,
querySvc: querySvc,
}
}
// ExecuteQuery handles POST /api/query
func (h *QueryHandler) ExecuteQuery(c *gin.Context) {
ctx := c.Request.Context()
var req struct {
ConnectionID string `json:"connection_id" binding:"required"`
SQL string `json:"sql" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
utils.ErrorResponse(c, http.StatusBadRequest, models.ErrValidationFailed, "Invalid request body")
return
}
result, err := h.connectionSvc.ExecuteQuery(ctx, req.ConnectionID, req.SQL)
if err != nil {
utils.ErrorResponse(c, http.StatusInternalServerError, err, "Query execution failed")
return
}
utils.SuccessResponse(c, gin.H{
"result": result,
})
}
// GetTables handles GET /api/connections/:id/tables
func (h *QueryHandler) GetTables(c *gin.Context) {
ctx := c.Request.Context()
connectionID := c.Param("id")
if connectionID == "" {
utils.ErrorResponse(c, http.StatusBadRequest, models.ErrValidationFailed, "Connection ID is required")
return
}
tables, err := h.connectionSvc.GetTables(ctx, connectionID)
if err != nil {
if err == models.ErrNotFound {
utils.ErrorResponse(c, http.StatusNotFound, err, "Connection not found")
return
}
utils.ErrorResponse(c, http.StatusInternalServerError, err, "Failed to get tables")
return
}
utils.SuccessResponse(c, gin.H{
"tables": tables,
})
}
// GetTableData handles GET /api/connections/:id/tables/:name/data
func (h *QueryHandler) GetTableData(c *gin.Context) {
ctx := c.Request.Context()
connectionID := c.Param("id")
tableName := c.Param("name")
if connectionID == "" || tableName == "" {
utils.ErrorResponse(c, http.StatusBadRequest, models.ErrValidationFailed, "Connection ID and table name are required")
return
}
// Parse limit and offset
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100"))
offset, _ := strconv.Atoi(c.DefaultQuery("offset", "0"))
result, err := h.connectionSvc.GetTableData(ctx, connectionID, tableName, limit, offset)
if err != nil {
if err == models.ErrNotFound {
utils.ErrorResponse(c, http.StatusNotFound, err, "Connection not found")
return
}
utils.ErrorResponse(c, http.StatusInternalServerError, err, "Failed to get table data")
return
}
utils.SuccessResponse(c, gin.H{
"result": result,
"table": tableName,
"limit": limit,
"offset": offset,
})
}
// GetTableStructure handles GET /api/connections/:id/tables/:name/structure
func (h *QueryHandler) GetTableStructure(c *gin.Context) {
ctx := c.Request.Context()
connectionID := c.Param("id")
tableName := c.Param("name")
if connectionID == "" || tableName == "" {
utils.ErrorResponse(c, http.StatusBadRequest, models.ErrValidationFailed, "Connection ID and table name are required")
return
}
structure, err := h.connectionSvc.GetTableStructure(ctx, connectionID, tableName)
if err != nil {
if err == models.ErrNotFound {
utils.ErrorResponse(c, http.StatusNotFound, err, "Connection not found")
return
}
utils.ErrorResponse(c, http.StatusInternalServerError, err, "Failed to get table structure")
return
}
utils.SuccessResponse(c, gin.H{
"structure": structure,
})
}
// GetQueryHistory handles GET /api/history
func (h *QueryHandler) GetQueryHistory(c *gin.Context) {
ctx := c.Request.Context()
connectionID := c.Query("connection_id")
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
history, total, err := h.querySvc.GetQueryHistory(ctx, connectionID, page, pageSize)
if err != nil {
utils.ErrorResponse(c, http.StatusInternalServerError, err, "Failed to get query history")
return
}
totalPages := int(total) / pageSize
if int(total)%pageSize > 0 {
totalPages++
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"data": gin.H{
"history": history,
"total": total,
"page": page,
"page_size": pageSize,
"total_pages": totalPages,
},
})
}
// GetSavedQueries handles GET /api/saved-queries
func (h *QueryHandler) GetSavedQueries(c *gin.Context) {
ctx := c.Request.Context()
connectionID := c.Query("connection_id")
queries, err := h.querySvc.GetSavedQueries(ctx, connectionID)
if err != nil {
utils.ErrorResponse(c, http.StatusInternalServerError, err, "Failed to get saved queries")
return
}
utils.SuccessResponse(c, gin.H{
"queries": queries,
})
}
// CreateSavedQuery handles POST /api/saved-queries
func (h *QueryHandler) CreateSavedQuery(c *gin.Context) {
ctx := c.Request.Context()
var req models.CreateSavedQueryRequest
if err := c.ShouldBindJSON(&req); err != nil {
utils.ErrorResponse(c, http.StatusBadRequest, models.ErrValidationFailed, "Invalid request body")
return
}
query, err := h.querySvc.CreateSavedQuery(ctx, &req)
if err != nil {
utils.ErrorResponse(c, http.StatusInternalServerError, err, "Failed to save query")
return
}
utils.CreatedResponse(c, gin.H{
"query": query,
})
}
// UpdateSavedQuery handles PUT /api/saved-queries/:id
func (h *QueryHandler) UpdateSavedQuery(c *gin.Context) {
ctx := c.Request.Context()
idStr := c.Param("id")
id, err := strconv.ParseUint(idStr, 10, 64)
if err != nil {
utils.ErrorResponse(c, http.StatusBadRequest, models.ErrValidationFailed, "Invalid query ID")
return
}
var req models.UpdateSavedQueryRequest
if err := c.ShouldBindJSON(&req); err != nil {
utils.ErrorResponse(c, http.StatusBadRequest, models.ErrValidationFailed, "Invalid request body")
return
}
query, err := h.querySvc.UpdateSavedQuery(ctx, uint(id), &req)
if err != nil {
if err == models.ErrNotFound {
utils.ErrorResponse(c, http.StatusNotFound, err, "Saved query not found")
return
}
utils.ErrorResponse(c, http.StatusInternalServerError, err, "Failed to update saved query")
return
}
utils.SuccessResponse(c, gin.H{
"query": query,
})
}
// DeleteSavedQuery handles DELETE /api/saved-queries/:id
func (h *QueryHandler) DeleteSavedQuery(c *gin.Context) {
ctx := c.Request.Context()
idStr := c.Param("id")
id, err := strconv.ParseUint(idStr, 10, 64)
if err != nil {
utils.ErrorResponse(c, http.StatusBadRequest, models.ErrValidationFailed, "Invalid query ID")
return
}
if err := h.querySvc.DeleteSavedQuery(ctx, uint(id)); err != nil {
if err == models.ErrNotFound {
utils.ErrorResponse(c, http.StatusNotFound, err, "Saved query not found")
return
}
utils.ErrorResponse(c, http.StatusInternalServerError, err, "Failed to delete saved query")
return
}
utils.SuccessResponse(c, gin.H{
"message": "Saved query deleted successfully",
})
}
// RegisterRoutes registers query routes
func (h *QueryHandler) RegisterRoutes(router *gin.RouterGroup) {
// Query execution
router.POST("/query", h.ExecuteQuery)
// Table operations
router.GET("/connections/:id/tables", h.GetTables)
router.GET("/connections/:id/tables/:name/data", h.GetTableData)
router.GET("/connections/:id/tables/:name/structure", h.GetTableStructure)
// Query history
router.GET("/history", h.GetQueryHistory)
// Saved queries
savedQueries := router.Group("/saved-queries")
{
savedQueries.GET("", h.GetSavedQueries)
savedQueries.POST("", h.CreateSavedQuery)
savedQueries.PUT("/:id", h.UpdateSavedQuery)
savedQueries.DELETE("/:id", h.DeleteSavedQuery)
}
}

View File

@@ -0,0 +1,95 @@
package handler
import (
"context"
"fmt"
"net/http"
"time"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"uzdb/internal/config"
"uzdb/internal/middleware"
"uzdb/internal/services"
)
// HTTPServer represents the HTTP API server
type HTTPServer struct {
engine *gin.Engine
server *http.Server
}
// NewHTTPServer creates a new HTTP server
func NewHTTPServer(
connectionSvc *services.ConnectionService,
querySvc *services.QueryService,
) *HTTPServer {
// Set Gin mode based on environment
cfg := config.Get()
if cfg.IsProduction() {
gin.SetMode(gin.ReleaseMode)
}
engine := gin.New()
// Create handlers
connectionHandler := NewConnectionHandler(connectionSvc)
queryHandler := NewQueryHandler(connectionSvc, querySvc)
// Setup middleware
engine.Use(middleware.RecoveryMiddleware())
engine.Use(middleware.LoggerMiddleware())
engine.Use(middleware.CORSMiddleware())
engine.Use(middleware.SecureHeadersMiddleware())
// Setup routes
api := engine.Group("/api")
{
connectionHandler.RegisterRoutes(api)
queryHandler.RegisterRoutes(api)
}
// Health check endpoint
engine.GET("/health", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"status": "ok",
"timestamp": time.Now().Format(time.RFC3339),
})
})
return &HTTPServer{
engine: engine,
}
}
// Start starts the HTTP server
func (s *HTTPServer) Start(port string) error {
s.server = &http.Server{
Addr: ":" + port,
Handler: s.engine,
ReadTimeout: 15 * time.Second,
WriteTimeout: 15 * time.Second,
IdleTimeout: 60 * time.Second,
}
config.GetLogger().Info("starting HTTP API server",
zap.String("port", port))
if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
return fmt.Errorf("failed to start HTTP server: %w", err)
}
return nil
}
// Shutdown gracefully shuts down the HTTP server
func (s *HTTPServer) Shutdown(ctx context.Context) error {
if s.server == nil {
return nil
}
config.GetLogger().Info("shutting down HTTP server")
return s.server.Shutdown(ctx)
}

101
internal/middleware/cors.go Normal file
View File

@@ -0,0 +1,101 @@
package middleware
import (
"net/http"
"strings"
"github.com/gin-gonic/gin"
"uzdb/internal/config"
)
// CORSMiddleware returns a CORS middleware
func CORSMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
cfg := config.Get()
// Set CORS headers
c.Header("Access-Control-Allow-Origin", getAllowedOrigin(c, cfg))
c.Header("Access-Control-Allow-Credentials", "true")
c.Header("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-Request-ID")
c.Header("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, PATCH, DELETE")
c.Header("Access-Control-Expose-Headers", "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers, Content-Type, X-Request-ID")
c.Header("Access-Control-Max-Age", "43200") // 12 hours
// Handle preflight requests
if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(http.StatusNoContent)
return
}
c.Next()
}
}
// getAllowedOrigin returns the allowed origin based on configuration
func getAllowedOrigin(c *gin.Context, cfg *config.Config) string {
origin := c.GetHeader("Origin")
// If no origin header, return empty (same-origin)
if origin == "" {
return ""
}
// In development mode, allow all origins
if cfg.IsDevelopment() {
return origin
}
// In production, validate against allowed origins
allowedOrigins := []string{
"http://localhost:3000",
"http://localhost:8080",
"http://127.0.0.1:3000",
"http://127.0.0.1:8080",
}
for _, allowed := range allowedOrigins {
if origin == allowed {
return origin
}
}
// Check for wildcard patterns
for _, allowed := range allowedOrigins {
if strings.HasSuffix(allowed, "*") {
prefix := strings.TrimSuffix(allowed, "*")
if strings.HasPrefix(origin, prefix) {
return origin
}
}
}
// Default to empty (deny) in production if not matched
if cfg.IsProduction() {
return ""
}
return origin
}
// SecureHeadersMiddleware adds security-related headers
func SecureHeadersMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
// Prevent MIME type sniffing
c.Header("X-Content-Type-Options", "nosniff")
// Enable XSS filter
c.Header("X-XSS-Protection", "1; mode=block")
// Prevent clickjacking
c.Header("X-Frame-Options", "DENY")
// Referrer policy
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
// Content Security Policy (adjust as needed)
c.Header("Content-Security-Policy", "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'")
c.Next()
}
}

View File

@@ -0,0 +1,125 @@
package middleware
import (
"errors"
"net/http"
"runtime/debug"
"time"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"uzdb/internal/config"
"uzdb/internal/models"
)
// RecoveryMiddleware returns a recovery middleware that handles panics
func RecoveryMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
defer func() {
if err := recover(); err != nil {
logger := config.GetLogger()
// Log the panic with stack trace
logger.Error("panic recovered",
zap.Any("error", err),
zap.String("stack", string(debug.Stack())),
zap.String("path", c.Request.URL.Path),
zap.String("method", c.Request.Method),
)
// Send error response
c.JSON(http.StatusInternalServerError, models.ErrorResponse{
Error: "INTERNAL_ERROR",
Message: "An unexpected error occurred",
Timestamp: time.Now(),
Path: c.Request.URL.Path,
})
c.Abort()
}
}()
c.Next()
}
}
// ErrorMiddleware returns an error handling middleware
func ErrorMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
c.Next()
// Check if there are any errors
if len(c.Errors) > 0 {
handleErrors(c)
}
}
}
// handleErrors processes and formats errors
func handleErrors(c *gin.Context) {
logger := config.GetLogger()
for _, e := range c.Errors {
err := e.Err
// Log the error
logger.Error("request error",
zap.String("error", err.Error()),
zap.Any("type", e.Type),
zap.String("path", c.Request.URL.Path),
)
// Determine status code
statusCode := http.StatusInternalServerError
// Map specific errors to status codes
switch {
case errors.Is(err, models.ErrNotFound):
statusCode = http.StatusNotFound
case errors.Is(err, models.ErrValidationFailed):
statusCode = http.StatusBadRequest
case errors.Is(err, models.ErrUnauthorized):
statusCode = http.StatusUnauthorized
case errors.Is(err, models.ErrForbidden):
statusCode = http.StatusForbidden
case errors.Is(err, models.ErrConnectionFailed):
statusCode = http.StatusBadGateway
}
// Send response if not already sent
if !c.Writer.Written() {
c.JSON(statusCode, models.ErrorResponse{
Error: getErrorCode(err),
Message: err.Error(),
Timestamp: time.Now(),
Path: c.Request.URL.Path,
})
}
// Only handle first error
break
}
}
// getErrorCode maps errors to error codes
func getErrorCode(err error) string {
switch {
case errors.Is(err, models.ErrNotFound):
return string(models.CodeNotFound)
case errors.Is(err, models.ErrValidationFailed):
return string(models.CodeValidation)
case errors.Is(err, models.ErrUnauthorized):
return string(models.CodeUnauthorized)
case errors.Is(err, models.ErrForbidden):
return string(models.CodeForbidden)
case errors.Is(err, models.ErrConnectionFailed):
return string(models.CodeConnection)
case errors.Is(err, models.ErrQueryFailed):
return string(models.CodeQuery)
case errors.Is(err, models.ErrEncryptionFailed):
return string(models.CodeEncryption)
default:
return string(models.CodeInternal)
}
}

View File

@@ -0,0 +1,98 @@
package middleware
import (
"time"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"uzdb/internal/config"
)
// LoggerMiddleware returns a logging middleware using Zap
func LoggerMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
logger := config.GetLogger()
// Start timer
start := time.Now()
// Process request
c.Next()
// Calculate duration
duration := time.Since(start)
// Get client IP
clientIP := c.ClientIP()
// Get method
method := c.Request.Method
// Get path
path := c.Request.URL.Path
// Get status code
statusCode := c.Writer.Status()
// Get body size
bodySize := c.Writer.Size()
// Determine log level based on status code
var level zapcore.Level
switch {
case statusCode >= 500:
level = zapcore.ErrorLevel
case statusCode >= 400:
level = zapcore.WarnLevel
default:
level = zapcore.InfoLevel
}
// Create fields
fields := []zap.Field{
zap.Int("status", statusCode),
zap.String("method", method),
zap.String("path", path),
zap.String("ip", clientIP),
zap.Duration("duration", duration),
zap.Int("body_size", bodySize),
}
// Add error message if any
if len(c.Errors) > 0 {
fields = append(fields, zap.String("error", c.Errors.String()))
}
// Log the request
logger.Log(level, "HTTP request", fields...)
}
}
// RequestIDMiddleware adds a request ID to each request
func RequestIDMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
// Try to get request ID from header
requestID := c.GetHeader("X-Request-ID")
// Generate if not present
if requestID == "" {
requestID = generateRequestID()
}
// Set request ID in context
c.Set("request_id", requestID)
// Add to response header
c.Header("X-Request-ID", requestID)
c.Next()
}
}
// generateRequestID generates a simple request ID
func generateRequestID() string {
return time.Now().Format("20060102150405") + "-" +
time.Now().Format("000000.000000")[7:]
}

View File

@@ -0,0 +1,85 @@
package models
import (
"time"
)
// ConnectionType represents the type of database connection
type ConnectionType string
const (
// ConnectionTypeMySQL represents MySQL database
ConnectionTypeMySQL ConnectionType = "mysql"
// ConnectionTypePostgreSQL represents PostgreSQL database
ConnectionTypePostgreSQL ConnectionType = "postgres"
// ConnectionTypeSQLite represents SQLite database
ConnectionTypeSQLite ConnectionType = "sqlite"
)
// UserConnection represents a user's database connection configuration
// This model is stored in the local SQLite database
type UserConnection struct {
ID string `gorm:"type:varchar(36);primaryKey" json:"id"`
Name string `gorm:"type:varchar(100);not null" json:"name"`
Type ConnectionType `gorm:"type:varchar(20);not null" json:"type"`
Host string `gorm:"type:varchar(255)" json:"host,omitempty"`
Port int `gorm:"type:integer" json:"port,omitempty"`
Username string `gorm:"type:varchar(100)" json:"username,omitempty"`
Password string `gorm:"type:text" json:"password"` // Encrypted password
Database string `gorm:"type:varchar(255)" json:"database"`
SSLMode string `gorm:"type:varchar(50)" json:"ssl_mode,omitempty"`
Timeout int `gorm:"type:integer;default:30" json:"timeout"`
CreatedAt time.Time `gorm:"autoCreateTime" json:"created_at"`
UpdatedAt time.Time `gorm:"autoUpdateTime" json:"updated_at"`
// For SQLite connections, Database field contains the file path
}
// TableName returns the table name for UserConnection
func (UserConnection) TableName() string {
return "user_connections"
}
// CreateConnectionRequest represents a request to create a new connection
type CreateConnectionRequest struct {
Name string `json:"name" binding:"required"`
Type ConnectionType `json:"type" binding:"required"`
Host string `json:"host"`
Port int `json:"port"`
Username string `json:"username"`
Password string `json:"password"`
Database string `json:"database" binding:"required"`
SSLMode string `json:"ssl_mode"`
Timeout int `json:"timeout"`
}
// UpdateConnectionRequest represents a request to update an existing connection
type UpdateConnectionRequest struct {
Name string `json:"name"`
Type ConnectionType `json:"type"`
Host string `json:"host"`
Port int `json:"port"`
Username string `json:"username"`
Password string `json:"password"`
Database string `json:"database"`
SSLMode string `json:"ssl_mode"`
Timeout int `json:"timeout"`
}
// Validate validates the connection request
func (r *CreateConnectionRequest) Validate() error {
switch r.Type {
case ConnectionTypeMySQL, ConnectionTypePostgreSQL:
if r.Host == "" {
return ErrValidationFailed
}
if r.Port <= 0 || r.Port > 65535 {
return ErrValidationFailed
}
case ConnectionTypeSQLite:
if r.Database == "" {
return ErrValidationFailed
}
}
return nil
}

74
internal/models/errors.go Normal file
View File

@@ -0,0 +1,74 @@
package models
import "errors"
// Application errors
var (
// ErrNotFound resource not found
ErrNotFound = errors.New("resource not found")
// ErrAlreadyExists resource already exists
ErrAlreadyExists = errors.New("resource already exists")
// ErrValidationFailed validation failed
ErrValidationFailed = errors.New("validation failed")
// ErrUnauthorized unauthorized access
ErrUnauthorized = errors.New("unauthorized access")
// ErrForbidden forbidden access
ErrForbidden = errors.New("forbidden access")
// ErrInternalServer internal server error
ErrInternalServer = errors.New("internal server error")
// ErrConnectionFailed connection failed
ErrConnectionFailed = errors.New("connection failed")
// ErrQueryFailed query execution failed
ErrQueryFailed = errors.New("query execution failed")
// ErrEncryptionFailed encryption/decryption failed
ErrEncryptionFailed = errors.New("encryption/decryption failed")
// ErrInvalidConfig invalid configuration
ErrInvalidConfig = errors.New("invalid configuration")
// ErrDatabaseLocked database is locked
ErrDatabaseLocked = errors.New("database is locked")
// ErrTimeout operation timeout
ErrTimeout = errors.New("operation timeout")
)
// ErrorCode represents error codes for API responses
type ErrorCode string
const (
// CodeSuccess successful operation
CodeSuccess ErrorCode = "SUCCESS"
// CodeNotFound resource not found
CodeNotFound ErrorCode = "NOT_FOUND"
// CodeValidation validation error
CodeValidation ErrorCode = "VALIDATION_ERROR"
// CodeUnauthorized unauthorized
CodeUnauthorized ErrorCode = "UNAUTHORIZED"
// CodeForbidden forbidden
CodeForbidden ErrorCode = "FORBIDDEN"
// CodeInternal internal error
CodeInternal ErrorCode = "INTERNAL_ERROR"
// CodeConnection connection error
CodeConnection ErrorCode = "CONNECTION_ERROR"
// CodeQuery query error
CodeQuery ErrorCode = "QUERY_ERROR"
// CodeEncryption encryption error
CodeEncryption ErrorCode = "ENCRYPTION_ERROR"
)

58
internal/models/query.go Normal file
View File

@@ -0,0 +1,58 @@
package models
import (
"time"
)
// QueryHistory represents a record of executed queries
type QueryHistory struct {
ID uint `gorm:"primaryKey" json:"id"`
ConnectionID string `gorm:"type:varchar(36);not null;index" json:"connection_id"`
SQL string `gorm:"type:text;not null" json:"sql"`
Duration int64 `gorm:"type:bigint" json:"duration_ms"` // Duration in milliseconds
ExecutedAt time.Time `gorm:"autoCreateTime;index" json:"executed_at"`
RowsAffected int64 `gorm:"type:bigint" json:"rows_affected"`
Error string `gorm:"type:text" json:"error,omitempty"`
Success bool `gorm:"type:boolean;default:false" json:"success"`
ResultPreview string `gorm:"type:text" json:"result_preview,omitempty"` // JSON preview of first few rows
}
// TableName returns the table name for QueryHistory
func (QueryHistory) TableName() string {
return "query_history"
}
// SavedQuery represents a saved SQL query
type SavedQuery struct {
ID uint `gorm:"primaryKey" json:"id"`
Name string `gorm:"type:varchar(255);not null" json:"name"`
Description string `gorm:"type:text" json:"description"`
SQL string `gorm:"type:text;not null" json:"sql"`
ConnectionID string `gorm:"type:varchar(36);index" json:"connection_id"`
Tags string `gorm:"type:text" json:"tags"` // Comma-separated tags
CreatedAt time.Time `gorm:"autoCreateTime" json:"created_at"`
UpdatedAt time.Time `gorm:"autoUpdateTime" json:"updated_at"`
}
// TableName returns the table name for SavedQuery
func (SavedQuery) TableName() string {
return "saved_queries"
}
// CreateSavedQueryRequest represents a request to save a query
type CreateSavedQueryRequest struct {
Name string `json:"name" binding:"required"`
Description string `json:"description"`
SQL string `json:"sql" binding:"required"`
ConnectionID string `json:"connection_id"`
Tags string `json:"tags"`
}
// UpdateSavedQueryRequest represents a request to update a saved query
type UpdateSavedQueryRequest struct {
Name string `json:"name"`
Description string `json:"description"`
SQL string `json:"sql"`
ConnectionID string `json:"connection_id"`
Tags string `json:"tags"`
}

128
internal/models/response.go Normal file
View File

@@ -0,0 +1,128 @@
package models
import (
"time"
)
// APIResponse represents a standard API response
type APIResponse struct {
Success bool `json:"success"`
Message string `json:"message,omitempty"`
Data interface{} `json:"data,omitempty"`
Error string `json:"error,omitempty"`
}
// QueryResult represents the result of a SQL query execution
type QueryResult struct {
Columns []string `json:"columns"`
Rows [][]interface{} `json:"rows"`
RowCount int64 `json:"row_count"`
AffectedRows int64 `json:"affected_rows"`
Duration int64 `json:"duration_ms"`
Success bool `json:"success"`
Error string `json:"error,omitempty"`
}
// Table represents a database table
type Table struct {
Name string `json:"name"`
Schema string `json:"schema,omitempty"`
Type string `json:"type"` // table, view, etc.
RowCount int64 `json:"row_count,omitempty"`
Description string `json:"description,omitempty"`
}
// TableStructure represents the structure of a database table
type TableStructure struct {
TableName string `json:"table_name"`
Schema string `json:"schema,omitempty"`
Columns []TableColumn `json:"columns"`
Indexes []TableIndex `json:"indexes,omitempty"`
ForeignKeys []ForeignKey `json:"foreign_keys,omitempty"`
}
// TableColumn represents a column in a database table
type TableColumn struct {
Name string `json:"name"`
DataType string `json:"data_type"`
Nullable bool `json:"nullable"`
Default string `json:"default,omitempty"`
IsPrimary bool `json:"is_primary"`
IsUnique bool `json:"is_unique"`
AutoIncrement bool `json:"auto_increment"`
Length int `json:"length,omitempty"`
Scale int `json:"scale,omitempty"`
Comment string `json:"comment,omitempty"`
}
// TableIndex represents an index on a database table
type TableIndex struct {
Name string `json:"name"`
Columns []string `json:"columns"`
IsUnique bool `json:"is_unique"`
IsPrimary bool `json:"is_primary"`
Type string `json:"type,omitempty"`
}
// ForeignKey represents a foreign key constraint
type ForeignKey struct {
Name string `json:"name"`
Columns []string `json:"columns"`
ReferencedTable string `json:"referenced_table"`
ReferencedColumns []string `json:"referenced_columns"`
OnDelete string `json:"on_delete,omitempty"`
OnUpdate string `json:"on_update,omitempty"`
}
// ConnectionTestResult represents the result of a connection test
type ConnectionTestResult struct {
Success bool `json:"success"`
Message string `json:"message"`
Duration int64 `json:"duration_ms"`
Metadata *DBMetadata `json:"metadata,omitempty"`
}
// DBMetadata represents database metadata
type DBMetadata struct {
Version string `json:"version"`
Database string `json:"database"`
User string `json:"user"`
Host string `json:"host"`
Port int `json:"port"`
ServerTime string `json:"server_time"`
}
// ErrorResponse represents an error response
type ErrorResponse struct {
Error string `json:"error"`
Message string `json:"message"`
Code string `json:"code,omitempty"`
Details map[string]interface{} `json:"details,omitempty"`
Timestamp time.Time `json:"timestamp"`
Path string `json:"path,omitempty"`
}
// PaginatedResponse represents a paginated response
type PaginatedResponse struct {
Data interface{} `json:"data"`
Total int64 `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
TotalPages int `json:"total_pages"`
}
// NewAPIResponse creates a new API response
func NewAPIResponse(data interface{}) *APIResponse {
return &APIResponse{
Success: true,
Data: data,
}
}
// NewErrorResponse creates a new error response
func NewErrorResponse(message string) *APIResponse {
return &APIResponse{
Success: false,
Error: message,
}
}

View File

@@ -0,0 +1,382 @@
package services
import (
"context"
"fmt"
"time"
"go.uber.org/zap"
"gorm.io/gorm"
"uzdb/internal/config"
"uzdb/internal/database"
"uzdb/internal/models"
"uzdb/internal/utils"
)
// ConnectionService manages database connections
type ConnectionService struct {
db *gorm.DB
connManager *database.ConnectionManager
encryptSvc *EncryptionService
}
// NewConnectionService creates a new connection service
func NewConnectionService(
db *gorm.DB,
connManager *database.ConnectionManager,
encryptSvc *EncryptionService,
) *ConnectionService {
return &ConnectionService{
db: db,
connManager: connManager,
encryptSvc: encryptSvc,
}
}
// GetAllConnections returns all user connections
func (s *ConnectionService) GetAllConnections(ctx context.Context) ([]models.UserConnection, error) {
var connections []models.UserConnection
result := s.db.WithContext(ctx).Find(&connections)
if result.Error != nil {
return nil, fmt.Errorf("failed to get connections: %w", result.Error)
}
// Mask passwords in response
for i := range connections {
connections[i].Password = s.encryptSvc.MaskPasswordForLogging(connections[i].Password)
}
config.GetLogger().Debug("retrieved all connections",
zap.Int("count", len(connections)))
return connections, nil
}
// GetConnectionByID returns a connection by ID
func (s *ConnectionService) GetConnectionByID(ctx context.Context, id string) (*models.UserConnection, error) {
var conn models.UserConnection
result := s.db.WithContext(ctx).First(&conn, "id = ?", id)
if result.Error != nil {
if result.Error == gorm.ErrRecordNotFound {
return nil, models.ErrNotFound
}
return nil, fmt.Errorf("failed to get connection: %w", result.Error)
}
return &conn, nil
}
// CreateConnection creates a new connection
func (s *ConnectionService) CreateConnection(ctx context.Context, req *models.CreateConnectionRequest) (*models.UserConnection, error) {
// Validate request
if err := req.Validate(); err != nil {
return nil, models.ErrValidationFailed
}
// Encrypt password
encryptedPassword := req.Password
if req.Password != "" {
var err error
encryptedPassword, err = s.encryptSvc.EncryptPassword(req.Password)
if err != nil {
return nil, fmt.Errorf("failed to encrypt password: %w", err)
}
}
conn := &models.UserConnection{
ID: utils.GenerateID(),
Name: req.Name,
Type: req.Type,
Host: req.Host,
Port: req.Port,
Username: req.Username,
Password: encryptedPassword,
Database: req.Database,
SSLMode: req.SSLMode,
Timeout: req.Timeout,
}
if conn.Timeout <= 0 {
conn.Timeout = 30
}
result := s.db.WithContext(ctx).Create(conn)
if result.Error != nil {
return nil, fmt.Errorf("failed to create connection: %w", result.Error)
}
// Mask password in response
conn.Password = s.encryptSvc.MaskPasswordForLogging(conn.Password)
config.GetLogger().Info("connection created",
zap.String("id", conn.ID),
zap.String("name", conn.Name),
zap.String("type", string(conn.Type)))
return conn, nil
}
// UpdateConnection updates an existing connection
func (s *ConnectionService) UpdateConnection(ctx context.Context, id string, req *models.UpdateConnectionRequest) (*models.UserConnection, error) {
// Get existing connection
existing, err := s.GetConnectionByID(ctx, id)
if err != nil {
return nil, err
}
// Update fields
if req.Name != "" {
existing.Name = req.Name
}
if req.Type != "" {
existing.Type = req.Type
}
if req.Host != "" {
existing.Host = req.Host
}
if req.Port > 0 {
existing.Port = req.Port
}
if req.Username != "" {
existing.Username = req.Username
}
if req.Password != "" {
// Encrypt new password
encryptedPassword, err := s.encryptSvc.EncryptPassword(req.Password)
if err != nil {
return nil, fmt.Errorf("failed to encrypt password: %w", err)
}
existing.Password = encryptedPassword
}
if req.Database != "" {
existing.Database = req.Database
}
if req.SSLMode != "" {
existing.SSLMode = req.SSLMode
}
if req.Timeout > 0 {
existing.Timeout = req.Timeout
}
result := s.db.WithContext(ctx).Save(existing)
if result.Error != nil {
return nil, fmt.Errorf("failed to update connection: %w", result.Error)
}
// Remove cached connection if exists
if err := s.connManager.RemoveConnection(id); err != nil {
config.GetLogger().Warn("failed to remove cached connection", zap.Error(err))
}
// Mask password in response
existing.Password = s.encryptSvc.MaskPasswordForLogging(existing.Password)
config.GetLogger().Info("connection updated",
zap.String("id", id),
zap.String("name", existing.Name))
return existing, nil
}
// DeleteConnection deletes a connection
func (s *ConnectionService) DeleteConnection(ctx context.Context, id string) error {
// Check if connection exists
if _, err := s.GetConnectionByID(ctx, id); err != nil {
return err
}
// Remove from connection manager
if err := s.connManager.RemoveConnection(id); err != nil {
config.GetLogger().Warn("failed to remove from connection manager", zap.Error(err))
}
// Delete from database
result := s.db.WithContext(ctx).Delete(&models.UserConnection{}, "id = ?", id)
if result.Error != nil {
return fmt.Errorf("failed to delete connection: %w", result.Error)
}
config.GetLogger().Info("connection deleted", zap.String("id", id))
return nil
}
// TestConnection tests a database connection
func (s *ConnectionService) TestConnection(ctx context.Context, id string) (*models.ConnectionTestResult, error) {
// Get connection config
conn, err := s.GetConnectionByID(ctx, id)
if err != nil {
return nil, err
}
// Decrypt password
password, err := s.encryptSvc.DecryptPassword(conn.Password)
if err != nil {
return nil, fmt.Errorf("failed to decrypt password: %w", err)
}
startTime := time.Now()
// Create temporary connection
tempConn, err := s.connManager.GetConnection(conn, password)
if err != nil {
return &models.ConnectionTestResult{
Success: false,
Message: fmt.Sprintf("Connection failed: %v", err),
Duration: time.Since(startTime).Milliseconds(),
}, nil
}
// Get metadata
metadata, err := tempConn.GetMetadata()
if err != nil {
config.GetLogger().Warn("failed to get metadata", zap.Error(err))
}
return &models.ConnectionTestResult{
Success: true,
Message: "Connection successful",
Duration: time.Since(startTime).Milliseconds(),
Metadata: metadata,
}, nil
}
// ExecuteQuery executes a SQL query on a connection
func (s *ConnectionService) ExecuteQuery(ctx context.Context, connectionID, sql string) (*models.QueryResult, error) {
// Get connection config
conn, err := s.GetConnectionByID(ctx, connectionID)
if err != nil {
return nil, err
}
// Decrypt password
password, err := s.encryptSvc.DecryptPassword(conn.Password)
if err != nil {
return nil, fmt.Errorf("failed to decrypt password: %w", err)
}
// Get or create connection
dbConn, err := s.connManager.GetConnection(conn, password)
if err != nil {
return nil, fmt.Errorf("failed to get connection: %w", err)
}
// Execute query
startTime := time.Now()
var result *models.QueryResult
if utils.IsReadOnlyQuery(sql) {
result, err = dbConn.ExecuteQuery(sql)
} else {
result, err = dbConn.ExecuteStatement(sql)
}
duration := time.Since(startTime)
// Record in history
history := &models.QueryHistory{
ConnectionID: connectionID,
SQL: utils.TruncateString(sql, 10000),
Duration: duration.Milliseconds(),
Success: err == nil,
}
if result != nil {
history.RowsAffected = result.AffectedRows
}
if err != nil {
history.Error = err.Error()
}
s.db.WithContext(ctx).Create(history)
if err != nil {
return nil, fmt.Errorf("query execution failed: %w", err)
}
return result, nil
}
// GetTables returns all tables for a connection
func (s *ConnectionService) GetTables(ctx context.Context, connectionID string) ([]models.Table, error) {
// Get connection config
conn, err := s.GetConnectionByID(ctx, connectionID)
if err != nil {
return nil, err
}
// Decrypt password
password, err := s.encryptSvc.DecryptPassword(conn.Password)
if err != nil {
return nil, fmt.Errorf("failed to decrypt password: %w", err)
}
// Get or create connection
dbConn, err := s.connManager.GetConnection(conn, password)
if err != nil {
return nil, fmt.Errorf("failed to get connection: %w", err)
}
return dbConn.GetTables("")
}
// GetTableData returns data from a table
func (s *ConnectionService) GetTableData(
ctx context.Context,
connectionID, tableName string,
limit, offset int,
) (*models.QueryResult, error) {
// Validate limit
if limit <= 0 || limit > 1000 {
limit = 100
}
if offset < 0 {
offset = 0
}
// Build query based on connection type
conn, err := s.GetConnectionByID(ctx, connectionID)
if err != nil {
return nil, err
}
var query string
switch conn.Type {
case models.ConnectionTypeMySQL:
query = fmt.Sprintf("SELECT * FROM `%s` LIMIT %d OFFSET %d", tableName, limit, offset)
case models.ConnectionTypePostgreSQL:
query = fmt.Sprintf(`SELECT * FROM "%s" LIMIT %d OFFSET %d`, tableName, limit, offset)
case models.ConnectionTypeSQLite:
query = fmt.Sprintf(`SELECT * FROM "%s" LIMIT %d OFFSET %d`, tableName, limit, offset)
default:
return nil, models.ErrValidationFailed
}
return s.ExecuteQuery(ctx, connectionID, query)
}
// GetTableStructure returns the structure of a table
func (s *ConnectionService) GetTableStructure(ctx context.Context, connectionID, tableName string) (*models.TableStructure, error) {
// Get connection config
conn, err := s.GetConnectionByID(ctx, connectionID)
if err != nil {
return nil, err
}
// Decrypt password
password, err := s.encryptSvc.DecryptPassword(conn.Password)
if err != nil {
return nil, fmt.Errorf("failed to decrypt password: %w", err)
}
// Get or create connection
dbConn, err := s.connManager.GetConnection(conn, password)
if err != nil {
return nil, fmt.Errorf("failed to get connection: %w", err)
}
return dbConn.GetTableStructure(tableName)
}

View File

@@ -0,0 +1,199 @@
package services
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"fmt"
"io"
"os"
"sync"
"go.uber.org/zap"
"uzdb/internal/config"
"uzdb/internal/models"
"uzdb/internal/utils"
)
// EncryptionService handles encryption and decryption of sensitive data
type EncryptionService struct {
key []byte
cipher cipher.Block
gcm cipher.AEAD
mu sync.RWMutex
}
var (
encryptionInstance *EncryptionService
encryptionOnce sync.Once
)
// GetEncryptionService returns the singleton encryption service instance
func GetEncryptionService() *EncryptionService {
return encryptionInstance
}
// InitEncryptionService initializes the encryption service
func InitEncryptionService(cfg *config.EncryptionConfig) (*EncryptionService, error) {
var err error
encryptionOnce.Do(func() {
encryptionInstance = &EncryptionService{}
err = encryptionInstance.init(cfg)
})
return encryptionInstance, err
}
// init initializes the encryption service with a key
func (s *EncryptionService) init(cfg *config.EncryptionConfig) error {
s.mu.Lock()
defer s.mu.Unlock()
// Try to load existing key or generate new one
key, err := s.loadOrGenerateKey(cfg)
if err != nil {
return fmt.Errorf("failed to load/generate key: %w", err)
}
s.key = key
// Create AES cipher
block, err := aes.NewCipher(key)
if err != nil {
return fmt.Errorf("failed to create cipher: %w", err)
}
s.cipher = block
// Create GCM mode
gcm, err := cipher.NewGCM(block)
if err != nil {
return fmt.Errorf("failed to create GCM: %w", err)
}
s.gcm = gcm
config.GetLogger().Info("encryption service initialized")
return nil
}
// loadOrGenerateKey loads existing key or generates a new one
func (s *EncryptionService) loadOrGenerateKey(cfg *config.EncryptionConfig) ([]byte, error) {
// Use provided key if available
if cfg.Key != "" {
key := []byte(cfg.Key)
// Ensure key is correct length (32 bytes for AES-256)
if len(key) < 32 {
// Pad key
padded := make([]byte, 32)
copy(padded, key)
key = padded
} else if len(key) > 32 {
key = key[:32]
}
return key, nil
}
// Try to load from file
if cfg.KeyFile != "" {
if data, err := os.ReadFile(cfg.KeyFile); err == nil {
key := []byte(data)
if len(key) >= 32 {
return key[:32], nil
}
}
}
// Generate new key
key := make([]byte, 32) // AES-256
if _, err := io.ReadFull(rand.Reader, key); err != nil {
return nil, fmt.Errorf("failed to generate key: %w", err)
}
// Save key to file if path provided
if cfg.KeyFile != "" {
if err := os.WriteFile(cfg.KeyFile, key, 0600); err != nil {
config.GetLogger().Warn("failed to save encryption key", zap.Error(err))
} else {
config.GetLogger().Info("encryption key generated and saved",
zap.String("path", cfg.KeyFile))
}
}
return key, nil
}
// Encrypt encrypts plaintext using AES-GCM
func (s *EncryptionService) Encrypt(plaintext string) (string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
if s.gcm == nil {
return "", models.ErrEncryptionFailed
}
// Generate nonce
nonce := make([]byte, s.gcm.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return "", fmt.Errorf("failed to generate nonce: %w", err)
}
// Encrypt
ciphertext := s.gcm.Seal(nonce, nonce, []byte(plaintext), nil)
// Encode to base64
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
// Decrypt decrypts ciphertext using AES-GCM
func (s *EncryptionService) Decrypt(ciphertext string) (string, error) {
s.mu.RLock()
defer s.mu.RUnlock()
if s.gcm == nil {
return "", models.ErrEncryptionFailed
}
// Decode from base64
data, err := base64.StdEncoding.DecodeString(ciphertext)
if err != nil {
return "", fmt.Errorf("failed to decode ciphertext: %w", err)
}
// Verify nonce size
nonceSize := s.gcm.NonceSize()
if len(data) < nonceSize {
return "", models.ErrEncryptionFailed
}
// Extract nonce and ciphertext
nonce, ciphertextBytes := data[:nonceSize], data[nonceSize:]
// Decrypt
plaintext, err := s.gcm.Open(nil, nonce, ciphertextBytes, nil)
if err != nil {
return "", fmt.Errorf("failed to decrypt: %w", err)
}
return string(plaintext), nil
}
// EncryptPassword encrypts a password for storage
func (s *EncryptionService) EncryptPassword(password string) (string, error) {
if password == "" {
return "", nil
}
return s.Encrypt(password)
}
// DecryptPassword decrypts a stored password
func (s *EncryptionService) DecryptPassword(encryptedPassword string) (string, error) {
if encryptedPassword == "" {
return "", nil
}
return s.Decrypt(encryptedPassword)
}
// MaskPasswordForLogging masks password for safe logging
func (s *EncryptionService) MaskPasswordForLogging(password string) string {
return utils.MaskPassword(password)
}

236
internal/services/query.go Normal file
View File

@@ -0,0 +1,236 @@
package services
import (
"context"
"fmt"
"time"
"go.uber.org/zap"
"gorm.io/gorm"
"uzdb/internal/config"
"uzdb/internal/models"
)
// QueryService handles query-related operations
type QueryService struct {
db *gorm.DB
}
// NewQueryService creates a new query service
func NewQueryService(db *gorm.DB) *QueryService {
return &QueryService{
db: db,
}
}
// GetQueryHistory returns query history with pagination
func (s *QueryService) GetQueryHistory(
ctx context.Context,
connectionID string,
page, pageSize int,
) ([]models.QueryHistory, int64, error) {
if page <= 0 {
page = 1
}
if pageSize <= 0 || pageSize > 100 {
pageSize = 20
}
var total int64
var history []models.QueryHistory
query := s.db.WithContext(ctx).Model(&models.QueryHistory{})
if connectionID != "" {
query = query.Where("connection_id = ?", connectionID)
}
// Get total count
if err := query.Count(&total).Error; err != nil {
return nil, 0, fmt.Errorf("failed to count history: %w", err)
}
// Get paginated results
offset := (page - 1) * pageSize
if err := query.Order("executed_at DESC").
Offset(offset).
Limit(pageSize).
Find(&history).Error; err != nil {
return nil, 0, fmt.Errorf("failed to get history: %w", err)
}
config.GetLogger().Debug("retrieved query history",
zap.String("connection_id", connectionID),
zap.Int("page", page),
zap.Int("page_size", pageSize),
zap.Int64("total", total))
return history, total, nil
}
// GetSavedQueries returns all saved queries
func (s *QueryService) GetSavedQueries(
ctx context.Context,
connectionID string,
) ([]models.SavedQuery, error) {
var queries []models.SavedQuery
query := s.db.WithContext(ctx)
if connectionID != "" {
query = query.Where("connection_id = ?", connectionID)
}
result := query.Order("name ASC").Find(&queries)
if result.Error != nil {
return nil, fmt.Errorf("failed to get saved queries: %w", result.Error)
}
return queries, nil
}
// GetSavedQueryByID returns a saved query by ID
func (s *QueryService) GetSavedQueryByID(ctx context.Context, id uint) (*models.SavedQuery, error) {
var query models.SavedQuery
result := s.db.WithContext(ctx).First(&query, "id = ?", id)
if result.Error != nil {
if result.Error == gorm.ErrRecordNotFound {
return nil, models.ErrNotFound
}
return nil, fmt.Errorf("failed to get saved query: %w", result.Error)
}
return &query, nil
}
// CreateSavedQuery creates a new saved query
func (s *QueryService) CreateSavedQuery(
ctx context.Context,
req *models.CreateSavedQueryRequest,
) (*models.SavedQuery, error) {
query := &models.SavedQuery{
Name: req.Name,
Description: req.Description,
SQL: req.SQL,
ConnectionID: req.ConnectionID,
Tags: req.Tags,
}
result := s.db.WithContext(ctx).Create(query)
if result.Error != nil {
return nil, fmt.Errorf("failed to create saved query: %w", result.Error)
}
config.GetLogger().Info("saved query created",
zap.Uint("id", query.ID),
zap.String("name", query.Name))
return query, nil
}
// UpdateSavedQuery updates an existing saved query
func (s *QueryService) UpdateSavedQuery(
ctx context.Context,
id uint,
req *models.UpdateSavedQueryRequest,
) (*models.SavedQuery, error) {
// Get existing query
existing, err := s.GetSavedQueryByID(ctx, id)
if err != nil {
return nil, err
}
// Update fields
if req.Name != "" {
existing.Name = req.Name
}
if req.Description != "" {
existing.Description = req.Description
}
if req.SQL != "" {
existing.SQL = req.SQL
}
if req.ConnectionID != "" {
existing.ConnectionID = req.ConnectionID
}
if req.Tags != "" {
existing.Tags = req.Tags
}
result := s.db.WithContext(ctx).Save(existing)
if result.Error != nil {
return nil, fmt.Errorf("failed to update saved query: %w", result.Error)
}
config.GetLogger().Info("saved query updated",
zap.Uint("id", id),
zap.String("name", existing.Name))
return existing, nil
}
// DeleteSavedQuery deletes a saved query
func (s *QueryService) DeleteSavedQuery(ctx context.Context, id uint) error {
// Check if exists
if _, err := s.GetSavedQueryByID(ctx, id); err != nil {
return err
}
result := s.db.WithContext(ctx).Delete(&models.SavedQuery{}, "id = ?", id)
if result.Error != nil {
return fmt.Errorf("failed to delete saved query: %w", result.Error)
}
config.GetLogger().Info("saved query deleted", zap.Uint("id", id))
return nil
}
// ClearOldHistory clears query history older than specified days
func (s *QueryService) ClearOldHistory(ctx context.Context, days int) (int64, error) {
if days <= 0 {
days = 30 // Default to 30 days
}
cutoffTime := time.Now().AddDate(0, 0, -days)
result := s.db.WithContext(ctx).
Where("executed_at < ?", cutoffTime).
Delete(&models.QueryHistory{})
if result.Error != nil {
return 0, fmt.Errorf("failed to clear old history: %w", result.Error)
}
config.GetLogger().Info("cleared old query history",
zap.Int64("deleted_count", result.RowsAffected),
zap.Int("days", days))
return result.RowsAffected, nil
}
// GetRecentQueries returns recent queries for quick access
func (s *QueryService) GetRecentQueries(
ctx context.Context,
connectionID string,
limit int,
) ([]models.QueryHistory, error) {
if limit <= 0 || limit > 50 {
limit = 10
}
var queries []models.QueryHistory
query := s.db.WithContext(ctx).
Where("connection_id = ? AND success = ?", connectionID, true).
Order("executed_at DESC").
Limit(limit).
Find(&queries)
if query.Error != nil {
return nil, fmt.Errorf("failed to get recent queries: %w", query.Error)
}
return queries, nil
}

136
internal/utils/errors.go Normal file
View File

@@ -0,0 +1,136 @@
package utils
import (
"crypto/rand"
"encoding/base64"
"encoding/hex"
"fmt"
"strings"
"time"
)
// GenerateID generates a unique ID (UUID-like)
func GenerateID() string {
b := make([]byte, 16)
rand.Read(b)
return formatUUID(b)
}
// formatUUID formats bytes as UUID string
func formatUUID(b []byte) string {
uuid := make([]byte, 36)
hex.Encode(uuid[0:8], b[0:4])
hex.Encode(uuid[9:13], b[4:6])
hex.Encode(uuid[14:18], b[6:8])
hex.Encode(uuid[19:23], b[8:10])
hex.Encode(uuid[24:], b[10:])
uuid[8] = '-'
uuid[13] = '-'
uuid[18] = '-'
uuid[23] = '-'
return string(uuid)
}
// FormatDuration formats duration in milliseconds
func FormatDuration(d time.Duration) int64 {
return d.Milliseconds()
}
// SanitizeSQL removes potentially dangerous SQL patterns
// Note: This is not a replacement for parameterized queries
func SanitizeSQL(sql string) string {
// Remove multiple semicolons
sql = strings.ReplaceAll(sql, ";;", ";")
// Trim whitespace
sql = strings.TrimSpace(sql)
return sql
}
// TruncateString truncates a string to max length
func TruncateString(s string, maxLen int) string {
if len(s) <= maxLen {
return s
}
return s[:maxLen]
}
// GenerateRandomBytes generates random bytes
func GenerateRandomBytes(n int) ([]byte, error) {
b := make([]byte, n)
_, err := rand.Read(b)
if err != nil {
return nil, fmt.Errorf("failed to generate random bytes: %w", err)
}
return b, nil
}
// EncodeBase64 encodes bytes to base64 string
func EncodeBase64(data []byte) string {
return base64.StdEncoding.EncodeToString(data)
}
// DecodeBase64 decodes base64 string to bytes
func DecodeBase64(s string) ([]byte, error) {
return base64.StdEncoding.DecodeString(s)
}
// MaskPassword masks password for logging
func MaskPassword(password string) string {
if len(password) <= 4 {
return strings.Repeat("*", len(password))
}
return password[:2] + strings.Repeat("*", len(password)-2)
}
// ContainsAny checks if string contains any of the substrings
func ContainsAny(s string, substrings ...string) bool {
for _, sub := range substrings {
if strings.Contains(s, sub) {
return true
}
}
return false
}
// IsReadOnlyQuery checks if SQL query is read-only
func IsReadOnlyQuery(sql string) bool {
sql = strings.TrimSpace(strings.ToUpper(sql))
// Read-only operations
readOnlyPrefixes := []string{
"SELECT",
"SHOW",
"DESCRIBE",
"EXPLAIN",
"WITH",
}
for _, prefix := range readOnlyPrefixes {
if strings.HasPrefix(sql, prefix) {
return true
}
}
return false
}
// IsDDLQuery checks if SQL query is DDL
func IsDDLQuery(sql string) bool {
sql = strings.TrimSpace(strings.ToUpper(sql))
ddlKeywords := []string{
"CREATE",
"ALTER",
"DROP",
"TRUNCATE",
"RENAME",
}
for _, keyword := range ddlKeywords {
if strings.HasPrefix(sql, keyword) {
return true
}
}
return false
}

102
internal/utils/response.go Normal file
View File

@@ -0,0 +1,102 @@
package utils
import (
"errors"
"net/http"
"time"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"uzdb/internal/config"
"uzdb/internal/models"
)
// ErrorResponse sends an error response
func ErrorResponse(c *gin.Context, statusCode int, err error, message string) {
logger := config.GetLogger()
response := models.ErrorResponse{
Error: getErrorCode(err),
Message: message,
Timestamp: time.Now(),
Path: c.Request.URL.Path,
}
if message == "" {
response.Message = err.Error()
}
// Log the error
logger.Error("error response",
zap.Int("status_code", statusCode),
zap.String("error", response.Error),
zap.String("message", response.Message),
zap.String("path", response.Path),
)
c.JSON(statusCode, response)
}
// SuccessResponse sends a success response
func SuccessResponse(c *gin.Context, data interface{}) {
c.JSON(http.StatusOK, models.NewAPIResponse(data))
}
// CreatedResponse sends a created response
func CreatedResponse(c *gin.Context, data interface{}) {
c.JSON(http.StatusCreated, models.NewAPIResponse(data))
}
// getErrorCode maps errors to error codes
func getErrorCode(err error) string {
if err == nil {
return string(models.CodeSuccess)
}
switch {
case errors.Is(err, models.ErrNotFound):
return string(models.CodeNotFound)
case errors.Is(err, models.ErrValidationFailed):
return string(models.CodeValidation)
case errors.Is(err, models.ErrUnauthorized):
return string(models.CodeUnauthorized)
case errors.Is(err, models.ErrForbidden):
return string(models.CodeForbidden)
case errors.Is(err, models.ErrConnectionFailed):
return string(models.CodeConnection)
case errors.Is(err, models.ErrQueryFailed):
return string(models.CodeQuery)
case errors.Is(err, models.ErrEncryptionFailed):
return string(models.CodeEncryption)
default:
return string(models.CodeInternal)
}
}
// WrapError wraps an error with context
func WrapError(err error, message string) error {
if err == nil {
return nil
}
return &wrappedError{
err: err,
message: message,
}
}
type wrappedError struct {
err error
message string
}
func (w *wrappedError) Error() string {
if w.message != "" {
return w.message + ": " + w.err.Error()
}
return w.err.Error()
}
func (w *wrappedError) Unwrap() error {
return w.err
}