refactor: Flatten directory structure
Move project files from uzdb/ subdirectory to root directory for cleaner project structure.
Changes:
- Move frontend/ to root
- Move internal/ to root
- Move build/ to root
- Move all config files (go.mod, wails.json, etc.) to root
- Remove redundant uzdb/ subdirectory nesting
Project structure is now:
├── frontend/ # React application
├── internal/ # Go backend
├── build/ # Wails build assets
├── doc/ # Design documentation
├── main.go # Entry point
└── ...
🤖 Generated with Qoder
This commit is contained in:
324
internal/app/app.go
Normal file
324
internal/app/app.go
Normal file
@@ -0,0 +1,324 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"uzdb/internal/config"
|
||||
"uzdb/internal/database"
|
||||
"uzdb/internal/handler"
|
||||
"uzdb/internal/models"
|
||||
"uzdb/internal/services"
|
||||
)
|
||||
|
||||
// App is the main application structure for Wails bindings
|
||||
type App struct {
|
||||
ctx context.Context
|
||||
config *config.Config
|
||||
connectionSvc *services.ConnectionService
|
||||
querySvc *services.QueryService
|
||||
httpServer *handler.HTTPServer
|
||||
shutdownFunc context.CancelFunc
|
||||
}
|
||||
|
||||
// NewApp creates a new App instance
|
||||
func NewApp() *App {
|
||||
return &App{}
|
||||
}
|
||||
|
||||
// Initialize initializes the application with all services
|
||||
func (a *App) Initialize(
|
||||
cfg *config.Config,
|
||||
connectionSvc *services.ConnectionService,
|
||||
querySvc *services.QueryService,
|
||||
httpServer *handler.HTTPServer,
|
||||
) {
|
||||
a.config = cfg
|
||||
a.connectionSvc = connectionSvc
|
||||
a.querySvc = querySvc
|
||||
a.httpServer = httpServer
|
||||
}
|
||||
|
||||
// OnStartup is called when the app starts (public method for Wails)
|
||||
func (a *App) OnStartup(ctx context.Context) {
|
||||
a.ctx = ctx
|
||||
|
||||
config.GetLogger().Info("Wails application started")
|
||||
}
|
||||
|
||||
// GetConnections returns all user connections
|
||||
// Wails binding: frontend can call window.go.app.GetConnections()
|
||||
func (a *App) GetConnections() []models.UserConnection {
|
||||
if a.connectionSvc == nil {
|
||||
return []models.UserConnection{}
|
||||
}
|
||||
|
||||
connections, err := a.connectionSvc.GetAllConnections(a.ctx)
|
||||
if err != nil {
|
||||
config.GetLogger().Error("failed to get connections", zap.Error(err))
|
||||
return []models.UserConnection{}
|
||||
}
|
||||
|
||||
return connections
|
||||
}
|
||||
|
||||
// CreateConnection creates a new database connection
|
||||
// Returns error message or empty string on success
|
||||
func (a *App) CreateConnection(conn models.CreateConnectionRequest) string {
|
||||
if a.connectionSvc == nil {
|
||||
return "Service not initialized"
|
||||
}
|
||||
|
||||
_, err := a.connectionSvc.CreateConnection(a.ctx, &conn)
|
||||
if err != nil {
|
||||
config.GetLogger().Error("failed to create connection", zap.Error(err))
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// UpdateConnection updates an existing database connection
|
||||
// Returns error message or empty string on success
|
||||
func (a *App) UpdateConnection(conn models.UserConnection) string {
|
||||
if a.connectionSvc == nil {
|
||||
return "Service not initialized"
|
||||
}
|
||||
|
||||
req := &models.UpdateConnectionRequest{
|
||||
Name: conn.Name,
|
||||
Type: conn.Type,
|
||||
Host: conn.Host,
|
||||
Port: conn.Port,
|
||||
Username: conn.Username,
|
||||
Password: conn.Password,
|
||||
Database: conn.Database,
|
||||
SSLMode: conn.SSLMode,
|
||||
Timeout: conn.Timeout,
|
||||
}
|
||||
|
||||
_, err := a.connectionSvc.UpdateConnection(a.ctx, conn.ID, req)
|
||||
if err != nil {
|
||||
config.GetLogger().Error("failed to update connection", zap.Error(err))
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// DeleteConnection deletes a database connection
|
||||
// Returns error message or empty string on success
|
||||
func (a *App) DeleteConnection(id string) string {
|
||||
if a.connectionSvc == nil {
|
||||
return "Service not initialized"
|
||||
}
|
||||
|
||||
err := a.connectionSvc.DeleteConnection(a.ctx, id)
|
||||
if err != nil {
|
||||
config.GetLogger().Error("failed to delete connection", zap.Error(err))
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// TestConnection tests a database connection
|
||||
// Returns (success, error_message)
|
||||
func (a *App) TestConnection(id string) (bool, string) {
|
||||
if a.connectionSvc == nil {
|
||||
return false, "Service not initialized"
|
||||
}
|
||||
|
||||
result, err := a.connectionSvc.TestConnection(a.ctx, id)
|
||||
if err != nil {
|
||||
config.GetLogger().Error("failed to test connection", zap.Error(err))
|
||||
return false, err.Error()
|
||||
}
|
||||
|
||||
return result.Success, result.Message
|
||||
}
|
||||
|
||||
// ExecuteQuery executes a SQL query on a connection
|
||||
// Returns query result or error message
|
||||
func (a *App) ExecuteQuery(connectionID, sql string) (*models.QueryResult, string) {
|
||||
if a.connectionSvc == nil {
|
||||
return nil, "Service not initialized"
|
||||
}
|
||||
|
||||
result, err := a.connectionSvc.ExecuteQuery(a.ctx, connectionID, sql)
|
||||
if err != nil {
|
||||
config.GetLogger().Error("failed to execute query",
|
||||
zap.String("connection_id", connectionID),
|
||||
zap.String("sql", sql),
|
||||
zap.Error(err))
|
||||
return nil, err.Error()
|
||||
}
|
||||
|
||||
return result, ""
|
||||
}
|
||||
|
||||
// GetTables returns all tables for a connection
|
||||
func (a *App) GetTables(connectionID string) ([]models.Table, string) {
|
||||
if a.connectionSvc == nil {
|
||||
return []models.Table{}, "Service not initialized"
|
||||
}
|
||||
|
||||
tables, err := a.connectionSvc.GetTables(a.ctx, connectionID)
|
||||
if err != nil {
|
||||
config.GetLogger().Error("failed to get tables",
|
||||
zap.String("connection_id", connectionID),
|
||||
zap.Error(err))
|
||||
return []models.Table{}, err.Error()
|
||||
}
|
||||
|
||||
return tables, ""
|
||||
}
|
||||
|
||||
// GetTableData returns data from a table
|
||||
func (a *App) GetTableData(connectionID, tableName string, limit, offset int) (*models.QueryResult, string) {
|
||||
if a.connectionSvc == nil {
|
||||
return nil, "Service not initialized"
|
||||
}
|
||||
|
||||
result, err := a.connectionSvc.GetTableData(a.ctx, connectionID, tableName, limit, offset)
|
||||
if err != nil {
|
||||
config.GetLogger().Error("failed to get table data",
|
||||
zap.String("connection_id", connectionID),
|
||||
zap.String("table", tableName),
|
||||
zap.Error(err))
|
||||
return nil, err.Error()
|
||||
}
|
||||
|
||||
return result, ""
|
||||
}
|
||||
|
||||
// GetTableStructure returns the structure of a table
|
||||
func (a *App) GetTableStructure(connectionID, tableName string) (*models.TableStructure, string) {
|
||||
if a.connectionSvc == nil {
|
||||
return nil, "Service not initialized"
|
||||
}
|
||||
|
||||
structure, err := a.connectionSvc.GetTableStructure(a.ctx, connectionID, tableName)
|
||||
if err != nil {
|
||||
config.GetLogger().Error("failed to get table structure",
|
||||
zap.String("connection_id", connectionID),
|
||||
zap.String("table", tableName),
|
||||
zap.Error(err))
|
||||
return nil, err.Error()
|
||||
}
|
||||
|
||||
return structure, ""
|
||||
}
|
||||
|
||||
// GetQueryHistory returns query history with pagination
|
||||
func (a *App) GetQueryHistory(connectionID string, page, pageSize int) ([]models.QueryHistory, int64, string) {
|
||||
if a.querySvc == nil {
|
||||
return []models.QueryHistory{}, 0, "Service not initialized"
|
||||
}
|
||||
|
||||
history, total, err := a.querySvc.GetQueryHistory(a.ctx, connectionID, page, pageSize)
|
||||
if err != nil {
|
||||
config.GetLogger().Error("failed to get query history", zap.Error(err))
|
||||
return []models.QueryHistory{}, 0, err.Error()
|
||||
}
|
||||
|
||||
return history, total, ""
|
||||
}
|
||||
|
||||
// GetSavedQueries returns all saved queries
|
||||
func (a *App) GetSavedQueries(connectionID string) ([]models.SavedQuery, string) {
|
||||
if a.querySvc == nil {
|
||||
return []models.SavedQuery{}, "Service not initialized"
|
||||
}
|
||||
|
||||
queries, err := a.querySvc.GetSavedQueries(a.ctx, connectionID)
|
||||
if err != nil {
|
||||
config.GetLogger().Error("failed to get saved queries", zap.Error(err))
|
||||
return []models.SavedQuery{}, err.Error()
|
||||
}
|
||||
|
||||
return queries, ""
|
||||
}
|
||||
|
||||
// CreateSavedQuery creates a new saved query
|
||||
func (a *App) CreateSavedQuery(req models.CreateSavedQueryRequest) (*models.SavedQuery, string) {
|
||||
if a.querySvc == nil {
|
||||
return nil, "Service not initialized"
|
||||
}
|
||||
|
||||
query, err := a.querySvc.CreateSavedQuery(a.ctx, &req)
|
||||
if err != nil {
|
||||
config.GetLogger().Error("failed to create saved query", zap.Error(err))
|
||||
return nil, err.Error()
|
||||
}
|
||||
|
||||
return query, ""
|
||||
}
|
||||
|
||||
// UpdateSavedQuery updates a saved query
|
||||
func (a *App) UpdateSavedQuery(id uint, req models.UpdateSavedQueryRequest) (*models.SavedQuery, string) {
|
||||
if a.querySvc == nil {
|
||||
return nil, "Service not initialized"
|
||||
}
|
||||
|
||||
query, err := a.querySvc.UpdateSavedQuery(a.ctx, id, &req)
|
||||
if err != nil {
|
||||
config.GetLogger().Error("failed to update saved query", zap.Error(err))
|
||||
return nil, err.Error()
|
||||
}
|
||||
|
||||
return query, ""
|
||||
}
|
||||
|
||||
// DeleteSavedQuery deletes a saved query
|
||||
func (a *App) DeleteSavedQuery(id uint) string {
|
||||
if a.querySvc == nil {
|
||||
return "Service not initialized"
|
||||
}
|
||||
|
||||
err := a.querySvc.DeleteSavedQuery(a.ctx, id)
|
||||
if err != nil {
|
||||
config.GetLogger().Error("failed to delete saved query", zap.Error(err))
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// StartHTTPServer starts the HTTP API server in background
|
||||
func (a *App) StartHTTPServer() string {
|
||||
if a.httpServer == nil {
|
||||
return "HTTP server not initialized"
|
||||
}
|
||||
|
||||
go func() {
|
||||
port := a.config.API.Port
|
||||
if err := a.httpServer.Start(port); err != nil {
|
||||
config.GetLogger().Error("HTTP server error", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the application
|
||||
func (a *App) Shutdown() {
|
||||
config.GetLogger().Info("shutting down application")
|
||||
|
||||
if a.shutdownFunc != nil {
|
||||
a.shutdownFunc()
|
||||
}
|
||||
|
||||
if a.httpServer != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
a.httpServer.Shutdown(ctx)
|
||||
}
|
||||
|
||||
// Close all database connections
|
||||
database.CloseSQLite()
|
||||
|
||||
config.Sync()
|
||||
}
|
||||
261
internal/config/config.go
Normal file
261
internal/config/config.go
Normal file
@@ -0,0 +1,261 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
// Environment represents the application environment type
|
||||
type Environment string
|
||||
|
||||
const (
|
||||
// EnvDevelopment represents development environment
|
||||
EnvDevelopment Environment = "development"
|
||||
// EnvProduction represents production environment
|
||||
EnvProduction Environment = "production"
|
||||
)
|
||||
|
||||
// Config holds all configuration for the application
|
||||
type Config struct {
|
||||
// App settings
|
||||
AppName string `json:"app_name"`
|
||||
Version string `json:"version"`
|
||||
Environment Environment `json:"environment"`
|
||||
|
||||
// Database settings (SQLite for app data)
|
||||
Database DatabaseConfig `json:"database"`
|
||||
|
||||
// Encryption settings
|
||||
Encryption EncryptionConfig `json:"encryption"`
|
||||
|
||||
// Logger settings
|
||||
Logger LoggerConfig `json:"logger"`
|
||||
|
||||
// API settings (for debug HTTP server)
|
||||
API APIConfig `json:"api"`
|
||||
|
||||
// File paths
|
||||
DataDir string `json:"-"`
|
||||
}
|
||||
|
||||
// DatabaseConfig holds database configuration
|
||||
type DatabaseConfig struct {
|
||||
// SQLite database file path for app data
|
||||
SQLitePath string `json:"sqlite_path"`
|
||||
// Max open connections
|
||||
MaxOpenConns int `json:"max_open_conns"`
|
||||
// Max idle connections
|
||||
MaxIdleConns int `json:"max_idle_conns"`
|
||||
// Connection max lifetime in minutes
|
||||
MaxLifetime int `json:"max_lifetime"`
|
||||
}
|
||||
|
||||
// EncryptionConfig holds encryption configuration
|
||||
type EncryptionConfig struct {
|
||||
// Key for encrypting sensitive data (passwords, etc.)
|
||||
// In production, this should be loaded from secure storage
|
||||
Key string `json:"-"`
|
||||
// KeyFile path to load encryption key from
|
||||
KeyFile string `json:"key_file"`
|
||||
}
|
||||
|
||||
// LoggerConfig holds logger configuration
|
||||
type LoggerConfig struct {
|
||||
// Log level: debug, info, warn, error
|
||||
Level string `json:"level"`
|
||||
// Log format: json, console
|
||||
Format string `json:"format"`
|
||||
// Output file path (empty for stdout)
|
||||
OutputPath string `json:"output_path"`
|
||||
}
|
||||
|
||||
// APIConfig holds HTTP API configuration
|
||||
type APIConfig struct {
|
||||
// Enable HTTP API server (for debugging)
|
||||
Enabled bool `json:"enabled"`
|
||||
// Port for HTTP API server
|
||||
Port string `json:"port"`
|
||||
}
|
||||
|
||||
var (
|
||||
instance *Config
|
||||
once sync.Once
|
||||
logger *zap.Logger
|
||||
)
|
||||
|
||||
// Get returns the singleton config instance
|
||||
func Get() *Config {
|
||||
return instance
|
||||
}
|
||||
|
||||
// GetLogger returns the zap logger
|
||||
func GetLogger() *zap.Logger {
|
||||
return logger
|
||||
}
|
||||
|
||||
// Init initializes the configuration
|
||||
// If config file doesn't exist, creates default config
|
||||
func Init(dataDir string) (*Config, error) {
|
||||
var err error
|
||||
once.Do(func() {
|
||||
instance = &Config{
|
||||
DataDir: dataDir,
|
||||
}
|
||||
err = instance.load(dataDir)
|
||||
})
|
||||
return instance, err
|
||||
}
|
||||
|
||||
// load loads configuration from file or creates default
|
||||
func (c *Config) load(dataDir string) error {
|
||||
configPath := filepath.Join(dataDir, "config.json")
|
||||
|
||||
// Try to load existing config
|
||||
if _, err := os.Stat(configPath); err == nil {
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := json.Unmarshal(data, c); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// Create default config
|
||||
c.setDefaults()
|
||||
if err := c.save(configPath); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Override with environment variables
|
||||
c.loadEnv()
|
||||
|
||||
// Initialize logger
|
||||
if err := c.initLogger(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Info("configuration loaded",
|
||||
zap.String("environment", string(c.Environment)),
|
||||
zap.String("data_dir", c.DataDir),
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setDefaults sets default configuration values
|
||||
func (c *Config) setDefaults() {
|
||||
c.AppName = "uzdb"
|
||||
c.Version = "1.0.0"
|
||||
c.Environment = EnvDevelopment
|
||||
|
||||
c.Database = DatabaseConfig{
|
||||
SQLitePath: filepath.Join(c.DataDir, "uzdb.db"),
|
||||
MaxOpenConns: 25,
|
||||
MaxIdleConns: 5,
|
||||
MaxLifetime: 5,
|
||||
}
|
||||
|
||||
c.Encryption = EncryptionConfig{
|
||||
Key: "", // Will be generated if empty
|
||||
KeyFile: filepath.Join(c.DataDir, "encryption.key"),
|
||||
}
|
||||
|
||||
c.Logger = LoggerConfig{
|
||||
Level: "debug",
|
||||
Format: "console",
|
||||
OutputPath: "",
|
||||
}
|
||||
|
||||
c.API = APIConfig{
|
||||
Enabled: true,
|
||||
Port: "8080",
|
||||
}
|
||||
}
|
||||
|
||||
// loadEnv loads configuration from environment variables
|
||||
func (c *Config) loadEnv() {
|
||||
if env := os.Getenv("UZDB_ENV"); env != "" {
|
||||
c.Environment = Environment(env)
|
||||
}
|
||||
|
||||
if port := os.Getenv("UZDB_API_PORT"); port != "" {
|
||||
c.API.Port = port
|
||||
}
|
||||
|
||||
if logLevel := os.Getenv("UZDB_LOG_LEVEL"); logLevel != "" {
|
||||
c.Logger.Level = logLevel
|
||||
}
|
||||
|
||||
if dbPath := os.Getenv("UZDB_DB_PATH"); dbPath != "" {
|
||||
c.Database.SQLitePath = dbPath
|
||||
}
|
||||
}
|
||||
|
||||
// save saves configuration to file
|
||||
func (c *Config) save(path string) error {
|
||||
data, err := json.MarshalIndent(c, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(path, data, 0600)
|
||||
}
|
||||
|
||||
// initLogger initializes the zap logger
|
||||
func (c *Config) initLogger() error {
|
||||
var cfg zap.Config
|
||||
|
||||
switch c.Logger.Format {
|
||||
case "json":
|
||||
cfg = zap.NewProductionConfig()
|
||||
default:
|
||||
cfg = zap.NewDevelopmentConfig()
|
||||
}
|
||||
|
||||
// Set log level
|
||||
level, parseErr := zapcore.ParseLevel(c.Logger.Level)
|
||||
if parseErr != nil {
|
||||
level = zapcore.InfoLevel
|
||||
}
|
||||
cfg.Level.SetLevel(level)
|
||||
|
||||
// Configure output
|
||||
if c.Logger.OutputPath != "" {
|
||||
cfg.OutputPaths = []string{c.Logger.OutputPath}
|
||||
cfg.ErrorOutputPaths = []string{c.Logger.OutputPath}
|
||||
}
|
||||
|
||||
var buildErr error
|
||||
logger, buildErr = cfg.Build(
|
||||
zap.AddCaller(),
|
||||
zap.AddStacktrace(zapcore.ErrorLevel),
|
||||
)
|
||||
if buildErr != nil {
|
||||
return buildErr
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsDevelopment returns true if running in development mode
|
||||
func (c *Config) IsDevelopment() bool {
|
||||
return c.Environment == EnvDevelopment
|
||||
}
|
||||
|
||||
// IsProduction returns true if running in production mode
|
||||
func (c *Config) IsProduction() bool {
|
||||
return c.Environment == EnvProduction
|
||||
}
|
||||
|
||||
// Sync flushes any buffered log entries
|
||||
func Sync() error {
|
||||
if logger != nil {
|
||||
return logger.Sync()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
137
internal/database/manager.go
Normal file
137
internal/database/manager.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"uzdb/internal/config"
|
||||
"uzdb/internal/models"
|
||||
)
|
||||
|
||||
// ConnectionManager manages database connections for different database types
|
||||
type ConnectionManager struct {
|
||||
connections map[string]DatabaseConnector
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// DatabaseConnector is the interface for all database connections
|
||||
type DatabaseConnector interface {
|
||||
GetDB() *sql.DB
|
||||
Close() error
|
||||
IsConnected() bool
|
||||
ExecuteQuery(sql string, args ...interface{}) (*models.QueryResult, error)
|
||||
ExecuteStatement(sql string, args ...interface{}) (*models.QueryResult, error)
|
||||
GetTables(schema string) ([]models.Table, error)
|
||||
GetTableStructure(tableName string) (*models.TableStructure, error)
|
||||
GetMetadata() (*models.DBMetadata, error)
|
||||
}
|
||||
|
||||
// NewConnectionManager creates a new connection manager
|
||||
func NewConnectionManager() *ConnectionManager {
|
||||
return &ConnectionManager{
|
||||
connections: make(map[string]DatabaseConnector),
|
||||
}
|
||||
}
|
||||
|
||||
// GetConnection gets or creates a connection for a user connection config
|
||||
func (m *ConnectionManager) GetConnection(conn *models.UserConnection, password string) (DatabaseConnector, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Check if connection already exists
|
||||
if existing, ok := m.connections[conn.ID]; ok {
|
||||
if existing.IsConnected() {
|
||||
return existing, nil
|
||||
}
|
||||
// Connection is dead, remove it
|
||||
existing.Close()
|
||||
delete(m.connections, conn.ID)
|
||||
}
|
||||
|
||||
// Create new connection based on type
|
||||
var connector DatabaseConnector
|
||||
var err error
|
||||
|
||||
switch conn.Type {
|
||||
case models.ConnectionTypeMySQL:
|
||||
connector, err = NewMySQLConnection(conn, password)
|
||||
case models.ConnectionTypePostgreSQL:
|
||||
connector, err = NewPostgreSQLConnection(conn, password)
|
||||
case models.ConnectionTypeSQLite:
|
||||
connector, err = NewSQLiteConnection(conn)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported database type: %s", conn.Type)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.connections[conn.ID] = connector
|
||||
|
||||
config.GetLogger().Info("connection created in manager",
|
||||
zap.String("connection_id", conn.ID),
|
||||
zap.String("type", string(conn.Type)),
|
||||
)
|
||||
|
||||
return connector, nil
|
||||
}
|
||||
|
||||
// RemoveConnection removes a connection from the manager
|
||||
func (m *ConnectionManager) RemoveConnection(connectionID string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if conn, ok := m.connections[connectionID]; ok {
|
||||
if err := conn.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
delete(m.connections, connectionID)
|
||||
config.GetLogger().Info("connection removed from manager",
|
||||
zap.String("connection_id", connectionID),
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllConnections returns all managed connections
|
||||
func (m *ConnectionManager) GetAllConnections() map[string]DatabaseConnector {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
result := make(map[string]DatabaseConnector)
|
||||
for k, v := range m.connections {
|
||||
result[k] = v
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// CloseAll closes all managed connections
|
||||
func (m *ConnectionManager) CloseAll() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
var lastErr error
|
||||
for id, conn := range m.connections {
|
||||
if err := conn.Close(); err != nil {
|
||||
lastErr = err
|
||||
config.GetLogger().Error("failed to close connection",
|
||||
zap.String("connection_id", id),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
m.connections = make(map[string]DatabaseConnector)
|
||||
|
||||
if lastErr != nil {
|
||||
return lastErr
|
||||
}
|
||||
|
||||
config.GetLogger().Info("all connections closed")
|
||||
return nil
|
||||
}
|
||||
414
internal/database/mysql.go
Normal file
414
internal/database/mysql.go
Normal file
@@ -0,0 +1,414 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"uzdb/internal/config"
|
||||
"uzdb/internal/models"
|
||||
)
|
||||
|
||||
// MySQLConnection represents a MySQL connection
|
||||
type MySQLConnection struct {
|
||||
db *sql.DB
|
||||
dsn string
|
||||
host string
|
||||
port int
|
||||
database string
|
||||
username string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// BuildMySQLDSN builds MySQL DSN connection string
|
||||
func BuildMySQLDSN(conn *models.UserConnection, password string) string {
|
||||
// MySQL DSN format: user:pass@tcp(host:port)/dbname?params
|
||||
params := []string{
|
||||
"parseTime=true",
|
||||
"loc=Local",
|
||||
fmt.Sprintf("timeout=%ds", conn.Timeout),
|
||||
"charset=utf8mb4",
|
||||
"collation=utf8mb4_unicode_ci",
|
||||
}
|
||||
|
||||
if conn.SSLMode != "" && conn.SSLMode != "disable" {
|
||||
switch conn.SSLMode {
|
||||
case "require":
|
||||
params = append(params, "tls=required")
|
||||
case "verify-ca":
|
||||
params = append(params, "tls=skip-verify")
|
||||
case "verify-full":
|
||||
params = append(params, "tls=skip-verify")
|
||||
}
|
||||
}
|
||||
|
||||
queryString := strings.Join(params, "&")
|
||||
|
||||
return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?%s",
|
||||
conn.Username,
|
||||
password,
|
||||
conn.Host,
|
||||
conn.Port,
|
||||
conn.Database,
|
||||
queryString,
|
||||
)
|
||||
}
|
||||
|
||||
// NewMySQLConnection creates a new MySQL connection
|
||||
func NewMySQLConnection(conn *models.UserConnection, password string) (*MySQLConnection, error) {
|
||||
dsn := BuildMySQLDSN(conn, password)
|
||||
|
||||
db, err := sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open MySQL connection: %w", err)
|
||||
}
|
||||
|
||||
// Configure connection pool
|
||||
db.SetMaxOpenConns(25)
|
||||
db.SetMaxIdleConns(5)
|
||||
db.SetConnMaxLifetime(5 * time.Minute)
|
||||
|
||||
// Test connection
|
||||
if err := db.Ping(); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("failed to ping MySQL: %w", err)
|
||||
}
|
||||
|
||||
mysqlConn := &MySQLConnection{
|
||||
db: db,
|
||||
dsn: dsn,
|
||||
host: conn.Host,
|
||||
port: conn.Port,
|
||||
database: conn.Database,
|
||||
username: conn.Username,
|
||||
}
|
||||
|
||||
config.GetLogger().Info("MySQL connection established",
|
||||
zap.String("host", conn.Host),
|
||||
zap.Int("port", conn.Port),
|
||||
zap.String("database", conn.Database),
|
||||
)
|
||||
|
||||
return mysqlConn, nil
|
||||
}
|
||||
|
||||
// GetDB returns the underlying sql.DB
|
||||
func (m *MySQLConnection) GetDB() *sql.DB {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.db
|
||||
}
|
||||
|
||||
// Close closes the MySQL connection
|
||||
func (m *MySQLConnection) Close() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.db != nil {
|
||||
if err := m.db.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close MySQL connection: %w", err)
|
||||
}
|
||||
m.db = nil
|
||||
config.GetLogger().Info("MySQL connection closed",
|
||||
zap.String("host", m.host),
|
||||
zap.Int("port", m.port),
|
||||
)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsConnected checks if the connection is alive
|
||||
func (m *MySQLConnection) IsConnected() bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if m.db == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
err := m.db.Ping()
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// ExecuteQuery executes a SQL query and returns results
|
||||
func (m *MySQLConnection) ExecuteQuery(sql string, args ...interface{}) (*models.QueryResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
db := m.GetDB()
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("connection is closed")
|
||||
}
|
||||
|
||||
rows, err := db.Query(sql, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query execution failed: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
result := &models.QueryResult{
|
||||
Success: true,
|
||||
Duration: time.Since(startTime).Milliseconds(),
|
||||
}
|
||||
|
||||
// Get column names
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get columns: %w", err)
|
||||
}
|
||||
result.Columns = columns
|
||||
|
||||
// Scan rows
|
||||
for rows.Next() {
|
||||
values := make([]interface{}, len(columns))
|
||||
valuePtrs := make([]interface{}, len(columns))
|
||||
for i := range values {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan row: %w", err)
|
||||
}
|
||||
|
||||
row := make([]interface{}, len(columns))
|
||||
for i, v := range values {
|
||||
row[i] = convertValue(v)
|
||||
}
|
||||
result.Rows = append(result.Rows, row)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("row iteration error: %w", err)
|
||||
}
|
||||
|
||||
result.RowCount = int64(len(result.Rows))
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ExecuteStatement executes a SQL statement (INSERT, UPDATE, DELETE, etc.)
|
||||
func (m *MySQLConnection) ExecuteStatement(sql string, args ...interface{}) (*models.QueryResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
db := m.GetDB()
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("connection is closed")
|
||||
}
|
||||
|
||||
res, err := db.Exec(sql, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("statement execution failed: %w", err)
|
||||
}
|
||||
|
||||
rowsAffected, _ := res.RowsAffected()
|
||||
lastInsertID, _ := res.LastInsertId()
|
||||
|
||||
result := &models.QueryResult{
|
||||
Success: true,
|
||||
AffectedRows: rowsAffected,
|
||||
Duration: time.Since(startTime).Milliseconds(),
|
||||
Rows: [][]interface{}{{lastInsertID}},
|
||||
Columns: []string{"last_insert_id"},
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetTables returns all tables in the database
|
||||
func (m *MySQLConnection) GetTables(schema string) ([]models.Table, error) {
|
||||
db := m.GetDB()
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("connection is closed")
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT
|
||||
TABLE_NAME as table_name,
|
||||
TABLE_TYPE as table_type,
|
||||
TABLE_ROWS as row_count
|
||||
FROM information_schema.TABLES
|
||||
WHERE TABLE_SCHEMA = ?
|
||||
ORDER BY TABLE_NAME
|
||||
`
|
||||
|
||||
rows, err := db.Query(query, m.database)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query tables: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tables []models.Table
|
||||
for rows.Next() {
|
||||
var t models.Table
|
||||
var rowCount sql.NullInt64
|
||||
if err := rows.Scan(&t.Name, &t.Type, &rowCount); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan table: %w", err)
|
||||
}
|
||||
if rowCount.Valid {
|
||||
t.RowCount = rowCount.Int64
|
||||
}
|
||||
tables = append(tables, t)
|
||||
}
|
||||
|
||||
return tables, nil
|
||||
}
|
||||
|
||||
// GetTableStructure returns the structure of a table
|
||||
func (m *MySQLConnection) GetTableStructure(tableName string) (*models.TableStructure, error) {
|
||||
db := m.GetDB()
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("connection is closed")
|
||||
}
|
||||
|
||||
structure := &models.TableStructure{
|
||||
TableName: tableName,
|
||||
}
|
||||
|
||||
// Get columns
|
||||
query := `
|
||||
SELECT
|
||||
COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_DEFAULT,
|
||||
COLUMN_KEY, EXTRA, CHARACTER_MAXIMUM_LENGTH, NUMERIC_PRECISION,
|
||||
NUMERIC_SCALE, COLUMN_COMMENT
|
||||
FROM information_schema.COLUMNS
|
||||
WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?
|
||||
ORDER BY ORDINAL_POSITION
|
||||
`
|
||||
|
||||
rows, err := db.Query(query, m.database, tableName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query columns: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var col models.TableColumn
|
||||
var nullable, columnKey, extra string
|
||||
var defaultVal, comment sql.NullString
|
||||
var maxLength, precision, scale sql.NullInt32
|
||||
|
||||
if err := rows.Scan(
|
||||
&col.Name, &col.DataType, &nullable, &defaultVal,
|
||||
&columnKey, &extra, &maxLength, &precision,
|
||||
&scale, &comment,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan column: %w", err)
|
||||
}
|
||||
|
||||
col.Nullable = nullable == "YES"
|
||||
col.IsPrimary = columnKey == "PRI"
|
||||
col.IsUnique = columnKey == "UNI"
|
||||
col.AutoIncrement = strings.Contains(extra, "auto_increment")
|
||||
|
||||
if defaultVal.Valid {
|
||||
col.Default = defaultVal.String
|
||||
}
|
||||
if maxLength.Valid {
|
||||
col.Length = int(maxLength.Int32)
|
||||
}
|
||||
if precision.Valid {
|
||||
col.Scale = int(precision.Int32)
|
||||
}
|
||||
if scale.Valid {
|
||||
col.Scale = int(scale.Int32)
|
||||
}
|
||||
if comment.Valid {
|
||||
col.Comment = comment.String
|
||||
}
|
||||
|
||||
structure.Columns = append(structure.Columns, col)
|
||||
}
|
||||
|
||||
// Get indexes
|
||||
indexQuery := `
|
||||
SELECT INDEX_NAME, COLUMN_NAME, NON_UNIQUE, SEQ_IN_INDEX
|
||||
FROM information_schema.STATISTICS
|
||||
WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?
|
||||
ORDER BY INDEX_NAME, SEQ_IN_INDEX
|
||||
`
|
||||
|
||||
idxRows, err := db.Query(indexQuery, m.database, tableName)
|
||||
if err != nil {
|
||||
config.GetLogger().Warn("failed to query indexes", zap.Error(err))
|
||||
} else {
|
||||
defer idxRows.Close()
|
||||
|
||||
indexMap := make(map[string]*models.TableIndex)
|
||||
for idxRows.Next() {
|
||||
var indexName, columnName string
|
||||
var nonUnique bool
|
||||
var seqInIndex int
|
||||
|
||||
if err := idxRows.Scan(&indexName, &columnName, &nonUnique, &seqInIndex); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
idx, exists := indexMap[indexName]
|
||||
if !exists {
|
||||
idx = &models.TableIndex{
|
||||
Name: indexName,
|
||||
IsUnique: !nonUnique,
|
||||
IsPrimary: indexName == "PRIMARY",
|
||||
Columns: []string{},
|
||||
}
|
||||
indexMap[indexName] = idx
|
||||
}
|
||||
idx.Columns = append(idx.Columns, columnName)
|
||||
}
|
||||
|
||||
for _, idx := range indexMap {
|
||||
structure.Indexes = append(structure.Indexes, *idx)
|
||||
}
|
||||
}
|
||||
|
||||
return structure, nil
|
||||
}
|
||||
|
||||
// GetMetadata returns database metadata
|
||||
func (m *MySQLConnection) GetMetadata() (*models.DBMetadata, error) {
|
||||
db := m.GetDB()
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("connection is closed")
|
||||
}
|
||||
|
||||
metadata := &models.DBMetadata{
|
||||
Database: m.database,
|
||||
User: m.username,
|
||||
Host: m.host,
|
||||
Port: m.port,
|
||||
}
|
||||
|
||||
// Get MySQL version
|
||||
var version string
|
||||
err := db.QueryRow("SELECT VERSION()").Scan(&version)
|
||||
if err == nil {
|
||||
metadata.Version = version
|
||||
}
|
||||
|
||||
// Get server time
|
||||
var serverTime string
|
||||
err = db.QueryRow("SELECT NOW()").Scan(&serverTime)
|
||||
if err == nil {
|
||||
metadata.ServerTime = serverTime
|
||||
}
|
||||
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
// convertValue converts database value to interface{}
|
||||
func convertValue(val interface{}) interface{} {
|
||||
if val == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle []byte (common from MySQL driver)
|
||||
if b, ok := val.([]byte); ok {
|
||||
return string(b)
|
||||
}
|
||||
|
||||
return val
|
||||
}
|
||||
449
internal/database/postgres.go
Normal file
449
internal/database/postgres.go
Normal file
@@ -0,0 +1,449 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"uzdb/internal/config"
|
||||
"uzdb/internal/models"
|
||||
)
|
||||
|
||||
// PostgreSQLConnection represents a PostgreSQL connection
|
||||
type PostgreSQLConnection struct {
|
||||
db *sql.DB
|
||||
dsn string
|
||||
host string
|
||||
port int
|
||||
database string
|
||||
username string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// BuildPostgreSQLDSN builds PostgreSQL DSN connection string
|
||||
func BuildPostgreSQLDSN(conn *models.UserConnection, password string) string {
|
||||
// PostgreSQL DSN format: postgres://user:pass@host:port/dbname?params
|
||||
params := []string{
|
||||
fmt.Sprintf("connect_timeout=%d", conn.Timeout),
|
||||
"sslmode=disable",
|
||||
}
|
||||
|
||||
if conn.SSLMode != "" {
|
||||
switch conn.SSLMode {
|
||||
case "require":
|
||||
params = append(params, "sslmode=require")
|
||||
case "verify-ca":
|
||||
params = append(params, "sslmode=verify-ca")
|
||||
case "verify-full":
|
||||
params = append(params, "sslmode=verify-full")
|
||||
case "disable":
|
||||
params = append(params, "sslmode=disable")
|
||||
default:
|
||||
params = append(params, fmt.Sprintf("sslmode=%s", conn.SSLMode))
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("postgres://%s:%s@%s:%d/%s?%s",
|
||||
conn.Username,
|
||||
password,
|
||||
conn.Host,
|
||||
conn.Port,
|
||||
conn.Database,
|
||||
strings.Join(params, "&"),
|
||||
)
|
||||
}
|
||||
|
||||
// NewPostgreSQLConnection creates a new PostgreSQL connection
|
||||
func NewPostgreSQLConnection(conn *models.UserConnection, password string) (*PostgreSQLConnection, error) {
|
||||
dsn := BuildPostgreSQLDSN(conn, password)
|
||||
|
||||
db, err := sql.Open("postgres", dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open PostgreSQL connection: %w", err)
|
||||
}
|
||||
|
||||
// Configure connection pool
|
||||
db.SetMaxOpenConns(25)
|
||||
db.SetMaxIdleConns(5)
|
||||
db.SetConnMaxLifetime(5 * time.Minute)
|
||||
|
||||
// Test connection
|
||||
if err := db.Ping(); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("failed to ping PostgreSQL: %w", err)
|
||||
}
|
||||
|
||||
pgConn := &PostgreSQLConnection{
|
||||
db: db,
|
||||
dsn: dsn,
|
||||
host: conn.Host,
|
||||
port: conn.Port,
|
||||
database: conn.Database,
|
||||
username: conn.Username,
|
||||
}
|
||||
|
||||
config.GetLogger().Info("PostgreSQL connection established",
|
||||
zap.String("host", conn.Host),
|
||||
zap.Int("port", conn.Port),
|
||||
zap.String("database", conn.Database),
|
||||
)
|
||||
|
||||
return pgConn, nil
|
||||
}
|
||||
|
||||
// GetDB returns the underlying sql.DB
|
||||
func (p *PostgreSQLConnection) GetDB() *sql.DB {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
return p.db
|
||||
}
|
||||
|
||||
// Close closes the PostgreSQL connection
|
||||
func (p *PostgreSQLConnection) Close() error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.db != nil {
|
||||
if err := p.db.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close PostgreSQL connection: %w", err)
|
||||
}
|
||||
p.db = nil
|
||||
config.GetLogger().Info("PostgreSQL connection closed",
|
||||
zap.String("host", p.host),
|
||||
zap.Int("port", p.port),
|
||||
)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsConnected checks if the connection is alive
|
||||
func (p *PostgreSQLConnection) IsConnected() bool {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
if p.db == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
err := p.db.Ping()
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// ExecuteQuery executes a SQL query and returns results
|
||||
func (p *PostgreSQLConnection) ExecuteQuery(sql string, args ...interface{}) (*models.QueryResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
db := p.GetDB()
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("connection is closed")
|
||||
}
|
||||
|
||||
rows, err := db.Query(sql, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query execution failed: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
result := &models.QueryResult{
|
||||
Success: true,
|
||||
Duration: time.Since(startTime).Milliseconds(),
|
||||
}
|
||||
|
||||
// Get column names
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get columns: %w", err)
|
||||
}
|
||||
result.Columns = columns
|
||||
|
||||
// Scan rows
|
||||
for rows.Next() {
|
||||
values := make([]interface{}, len(columns))
|
||||
valuePtrs := make([]interface{}, len(columns))
|
||||
for i := range values {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan row: %w", err)
|
||||
}
|
||||
|
||||
row := make([]interface{}, len(columns))
|
||||
for i, v := range values {
|
||||
row[i] = convertValue(v)
|
||||
}
|
||||
result.Rows = append(result.Rows, row)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("row iteration error: %w", err)
|
||||
}
|
||||
|
||||
result.RowCount = int64(len(result.Rows))
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ExecuteStatement executes a SQL statement (INSERT, UPDATE, DELETE, etc.)
|
||||
func (p *PostgreSQLConnection) ExecuteStatement(sql string, args ...interface{}) (*models.QueryResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
db := p.GetDB()
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("connection is closed")
|
||||
}
|
||||
|
||||
res, err := db.Exec(sql, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("statement execution failed: %w", err)
|
||||
}
|
||||
|
||||
rowsAffected, _ := res.RowsAffected()
|
||||
|
||||
result := &models.QueryResult{
|
||||
Success: true,
|
||||
AffectedRows: rowsAffected,
|
||||
Duration: time.Since(startTime).Milliseconds(),
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetTables returns all tables in the database
|
||||
func (p *PostgreSQLConnection) GetTables(schema string) ([]models.Table, error) {
|
||||
db := p.GetDB()
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("connection is closed")
|
||||
}
|
||||
|
||||
if schema == "" {
|
||||
schema = "public"
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT
|
||||
table_name,
|
||||
table_type,
|
||||
(SELECT COUNT(*) FROM information_schema.columns c
|
||||
WHERE c.table_name = t.table_name AND c.table_schema = t.table_schema) as column_count
|
||||
FROM information_schema.tables t
|
||||
WHERE table_schema = $1
|
||||
ORDER BY table_name
|
||||
`
|
||||
|
||||
rows, err := db.Query(query, schema)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query tables: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tables []models.Table
|
||||
for rows.Next() {
|
||||
var t models.Table
|
||||
var columnCount int
|
||||
t.Schema = schema
|
||||
if err := rows.Scan(&t.Name, &t.Type, &columnCount); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan table: %w", err)
|
||||
}
|
||||
tables = append(tables, t)
|
||||
}
|
||||
|
||||
return tables, nil
|
||||
}
|
||||
|
||||
// GetTableStructure returns the structure of a table
|
||||
func (p *PostgreSQLConnection) GetTableStructure(tableName string) (*models.TableStructure, error) {
|
||||
db := p.GetDB()
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("connection is closed")
|
||||
}
|
||||
|
||||
structure := &models.TableStructure{
|
||||
TableName: tableName,
|
||||
Schema: "public",
|
||||
}
|
||||
|
||||
// Get columns
|
||||
query := `
|
||||
SELECT
|
||||
c.column_name,
|
||||
c.data_type,
|
||||
c.is_nullable,
|
||||
c.column_default,
|
||||
c.character_maximum_length,
|
||||
c.numeric_precision,
|
||||
c.numeric_scale,
|
||||
pg_catalog.col_description(c.oid, c.ordinal_position) as comment
|
||||
FROM information_schema.columns c
|
||||
JOIN pg_class cl ON cl.relname = c.table_name
|
||||
WHERE c.table_schema = 'public' AND c.table_name = $1
|
||||
ORDER BY c.ordinal_position
|
||||
`
|
||||
|
||||
rows, err := db.Query(query, tableName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query columns: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var col models.TableColumn
|
||||
var nullable string
|
||||
var defaultVal, comment sql.NullString
|
||||
var maxLength, precision, scale sql.NullInt32
|
||||
|
||||
if err := rows.Scan(
|
||||
&col.Name, &col.DataType, &nullable, &defaultVal,
|
||||
&maxLength, &precision, &scale, &comment,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan column: %w", err)
|
||||
}
|
||||
|
||||
col.Nullable = nullable == "YES"
|
||||
|
||||
if defaultVal.Valid {
|
||||
col.Default = defaultVal.String
|
||||
}
|
||||
if maxLength.Valid && maxLength.Int32 > 0 {
|
||||
col.Length = int(maxLength.Int32)
|
||||
}
|
||||
if precision.Valid {
|
||||
col.Scale = int(precision.Int32)
|
||||
}
|
||||
if scale.Valid {
|
||||
col.Scale = int(scale.Int32)
|
||||
}
|
||||
if comment.Valid {
|
||||
col.Comment = comment.String
|
||||
}
|
||||
|
||||
structure.Columns = append(structure.Columns, col)
|
||||
}
|
||||
|
||||
// Get primary key information
|
||||
pkQuery := `
|
||||
SELECT a.attname
|
||||
FROM pg_index i
|
||||
JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
|
||||
WHERE i.indrelid = $1::regclass AND i.indisprimary
|
||||
`
|
||||
|
||||
pkRows, err := db.Query(pkQuery, tableName)
|
||||
if err != nil {
|
||||
config.GetLogger().Warn("failed to query primary keys", zap.Error(err))
|
||||
} else {
|
||||
defer pkRows.Close()
|
||||
|
||||
pkColumns := make(map[string]bool)
|
||||
for pkRows.Next() {
|
||||
var colName string
|
||||
if err := pkRows.Scan(&colName); err != nil {
|
||||
continue
|
||||
}
|
||||
pkColumns[colName] = true
|
||||
}
|
||||
|
||||
// Mark primary key columns
|
||||
for i := range structure.Columns {
|
||||
if pkColumns[structure.Columns[i].Name] {
|
||||
structure.Columns[i].IsPrimary = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get indexes
|
||||
indexQuery := `
|
||||
SELECT
|
||||
i.indexname,
|
||||
i.indexdef,
|
||||
i.indisunique
|
||||
FROM pg_indexes i
|
||||
WHERE i.schemaname = 'public' AND i.tablename = $1
|
||||
`
|
||||
|
||||
idxRows, err := db.Query(indexQuery, tableName)
|
||||
if err != nil {
|
||||
config.GetLogger().Warn("failed to query indexes", zap.Error(err))
|
||||
} else {
|
||||
defer idxRows.Close()
|
||||
|
||||
for idxRows.Next() {
|
||||
var idx models.TableIndex
|
||||
var indexDef string
|
||||
if err := idxRows.Scan(&idx.Name, &indexDef, &idx.IsUnique); err != nil {
|
||||
continue
|
||||
}
|
||||
idx.IsPrimary = idx.Name == (tableName + "_pkey")
|
||||
|
||||
// Extract column names from index definition
|
||||
idx.Columns = extractColumnsFromIndexDef(indexDef)
|
||||
idx.Type = "btree" // Default for PostgreSQL
|
||||
|
||||
structure.Indexes = append(structure.Indexes, idx)
|
||||
}
|
||||
}
|
||||
|
||||
return structure, nil
|
||||
}
|
||||
|
||||
// GetMetadata returns database metadata
|
||||
func (p *PostgreSQLConnection) GetMetadata() (*models.DBMetadata, error) {
|
||||
db := p.GetDB()
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("connection is closed")
|
||||
}
|
||||
|
||||
metadata := &models.DBMetadata{
|
||||
Database: p.database,
|
||||
User: p.username,
|
||||
Host: p.host,
|
||||
Port: p.port,
|
||||
}
|
||||
|
||||
// Get PostgreSQL version
|
||||
var version string
|
||||
err := db.QueryRow("SELECT version()").Scan(&version)
|
||||
if err == nil {
|
||||
metadata.Version = version
|
||||
}
|
||||
|
||||
// Get server time
|
||||
var serverTime string
|
||||
err = db.QueryRow("SELECT NOW()").Scan(&serverTime)
|
||||
if err == nil {
|
||||
metadata.ServerTime = serverTime
|
||||
}
|
||||
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
// extractColumnsFromIndexDef extracts column names from PostgreSQL index definition
|
||||
func extractColumnsFromIndexDef(indexDef string) []string {
|
||||
var columns []string
|
||||
|
||||
// Simple extraction - look for content between parentheses
|
||||
start := strings.Index(indexDef, "(")
|
||||
end := strings.LastIndex(indexDef, ")")
|
||||
|
||||
if start != -1 && end != -1 && end > start {
|
||||
content := indexDef[start+1 : end]
|
||||
parts := strings.Split(content, ",")
|
||||
for _, part := range parts {
|
||||
col := strings.TrimSpace(part)
|
||||
// Remove any expressions like "LOWER(column)"
|
||||
if parenIdx := strings.Index(col, "("); parenIdx != -1 {
|
||||
continue
|
||||
}
|
||||
if col != "" {
|
||||
columns = append(columns, col)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return columns
|
||||
}
|
||||
128
internal/database/sqlite.go
Normal file
128
internal/database/sqlite.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"uzdb/internal/config"
|
||||
"uzdb/internal/models"
|
||||
)
|
||||
|
||||
var (
|
||||
sqliteDB *gorm.DB
|
||||
sqliteMu sync.RWMutex
|
||||
)
|
||||
|
||||
// InitSQLite initializes the SQLite database for application data
|
||||
func InitSQLite(dbPath string, cfg *config.DatabaseConfig) (*gorm.DB, error) {
|
||||
sqliteMu.Lock()
|
||||
defer sqliteMu.Unlock()
|
||||
|
||||
// Ensure directory exists
|
||||
dir := filepath.Dir(dbPath)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create database directory: %w", err)
|
||||
}
|
||||
|
||||
// Configure GORM logger
|
||||
var gormLogger logger.Interface
|
||||
if config.Get().IsDevelopment() {
|
||||
gormLogger = logger.Default.LogMode(logger.Info)
|
||||
} else {
|
||||
gormLogger = logger.Default.LogMode(logger.Silent)
|
||||
}
|
||||
|
||||
// Open SQLite connection
|
||||
db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{
|
||||
Logger: gormLogger,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open SQLite database: %w", err)
|
||||
}
|
||||
|
||||
// Set connection pool settings
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get underlying DB: %w", err)
|
||||
}
|
||||
|
||||
sqlDB.SetMaxOpenConns(cfg.MaxOpenConns)
|
||||
sqlDB.SetMaxIdleConns(cfg.MaxIdleConns)
|
||||
sqlDB.SetConnMaxLifetime(time.Duration(cfg.MaxLifetime) * time.Minute)
|
||||
|
||||
// Enable WAL mode for better concurrency
|
||||
if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil {
|
||||
config.GetLogger().Warn("failed to enable WAL mode", zap.Error(err))
|
||||
}
|
||||
|
||||
// Run migrations
|
||||
if err := runMigrations(db); err != nil {
|
||||
return nil, fmt.Errorf("migration failed: %w", err)
|
||||
}
|
||||
|
||||
sqliteDB = db
|
||||
|
||||
config.GetLogger().Info("SQLite database initialized",
|
||||
zap.String("path", dbPath),
|
||||
zap.Int("max_open_conns", cfg.MaxOpenConns),
|
||||
zap.Int("max_idle_conns", cfg.MaxIdleConns),
|
||||
)
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// runMigrations runs database migrations
|
||||
func runMigrations(db *gorm.DB) error {
|
||||
migrations := []interface{}{
|
||||
&models.UserConnection{},
|
||||
&models.QueryHistory{},
|
||||
&models.SavedQuery{},
|
||||
}
|
||||
|
||||
for _, model := range migrations {
|
||||
if err := db.AutoMigrate(model); err != nil {
|
||||
return fmt.Errorf("failed to migrate %T: %w", model, err)
|
||||
}
|
||||
}
|
||||
|
||||
config.GetLogger().Info("database migrations completed")
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSQLiteDB returns the SQLite database instance
|
||||
func GetSQLiteDB() *gorm.DB {
|
||||
sqliteMu.RLock()
|
||||
defer sqliteMu.RUnlock()
|
||||
return sqliteDB
|
||||
}
|
||||
|
||||
// CloseSQLite closes the SQLite database connection
|
||||
func CloseSQLite() error {
|
||||
sqliteMu.Lock()
|
||||
defer sqliteMu.Unlock()
|
||||
|
||||
if sqliteDB == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sqlDB, err := sqliteDB.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := sqlDB.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close SQLite database: %w", err)
|
||||
}
|
||||
|
||||
sqliteDB = nil
|
||||
config.GetLogger().Info("SQLite database closed")
|
||||
return nil
|
||||
}
|
||||
369
internal/database/sqlite_driver.go
Normal file
369
internal/database/sqlite_driver.go
Normal file
@@ -0,0 +1,369 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"uzdb/internal/config"
|
||||
"uzdb/internal/models"
|
||||
)
|
||||
|
||||
// SQLiteConnection represents a SQLite connection
|
||||
type SQLiteConnection struct {
|
||||
db *sql.DB
|
||||
filePath string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewSQLiteConnection creates a new SQLite connection
|
||||
func NewSQLiteConnection(conn *models.UserConnection) (*SQLiteConnection, error) {
|
||||
filePath := conn.Database
|
||||
|
||||
// Ensure file exists for new connections
|
||||
db, err := sql.Open("sqlite", filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open SQLite connection: %w", err)
|
||||
}
|
||||
|
||||
// Configure connection pool
|
||||
db.SetMaxOpenConns(1) // SQLite only supports one writer
|
||||
db.SetMaxIdleConns(1)
|
||||
db.SetConnMaxLifetime(5 * time.Minute)
|
||||
|
||||
// Enable WAL mode
|
||||
if _, err := db.Exec("PRAGMA journal_mode=WAL"); err != nil {
|
||||
config.GetLogger().Warn("failed to enable WAL mode", zap.Error(err))
|
||||
}
|
||||
|
||||
// Test connection
|
||||
if err := db.Ping(); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("failed to ping SQLite: %w", err)
|
||||
}
|
||||
|
||||
sqliteConn := &SQLiteConnection{
|
||||
db: db,
|
||||
filePath: filePath,
|
||||
}
|
||||
|
||||
config.GetLogger().Info("SQLite connection established",
|
||||
zap.String("path", filePath),
|
||||
)
|
||||
|
||||
return sqliteConn, nil
|
||||
}
|
||||
|
||||
// GetDB returns the underlying sql.DB
|
||||
func (s *SQLiteConnection) GetDB() *sql.DB {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.db
|
||||
}
|
||||
|
||||
// Close closes the SQLite connection
|
||||
func (s *SQLiteConnection) Close() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.db != nil {
|
||||
if err := s.db.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close SQLite connection: %w", err)
|
||||
}
|
||||
s.db = nil
|
||||
config.GetLogger().Info("SQLite connection closed",
|
||||
zap.String("path", s.filePath),
|
||||
)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsConnected checks if the connection is alive
|
||||
func (s *SQLiteConnection) IsConnected() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if s.db == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
err := s.db.Ping()
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// ExecuteQuery executes a SQL query and returns results
|
||||
func (s *SQLiteConnection) ExecuteQuery(query string, args ...interface{}) (*models.QueryResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
db := s.GetDB()
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("connection is closed")
|
||||
}
|
||||
|
||||
rows, err := db.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query execution failed: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
result := &models.QueryResult{
|
||||
Success: true,
|
||||
Duration: time.Since(startTime).Milliseconds(),
|
||||
}
|
||||
|
||||
// Get column names
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get columns: %w", err)
|
||||
}
|
||||
result.Columns = columns
|
||||
|
||||
// Scan rows
|
||||
for rows.Next() {
|
||||
values := make([]interface{}, len(columns))
|
||||
valuePtrs := make([]interface{}, len(columns))
|
||||
for i := range values {
|
||||
valuePtrs[i] = &values[i]
|
||||
}
|
||||
|
||||
if err := rows.Scan(valuePtrs...); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan row: %w", err)
|
||||
}
|
||||
|
||||
row := make([]interface{}, len(columns))
|
||||
for i, v := range values {
|
||||
row[i] = convertValue(v)
|
||||
}
|
||||
result.Rows = append(result.Rows, row)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("row iteration error: %w", err)
|
||||
}
|
||||
|
||||
result.RowCount = int64(len(result.Rows))
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ExecuteStatement executes a SQL statement (INSERT, UPDATE, DELETE, etc.)
|
||||
func (s *SQLiteConnection) ExecuteStatement(stmt string, args ...interface{}) (*models.QueryResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
db := s.GetDB()
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("connection is closed")
|
||||
}
|
||||
|
||||
res, err := db.Exec(stmt, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("statement execution failed: %w", err)
|
||||
}
|
||||
|
||||
rowsAffected, _ := res.RowsAffected()
|
||||
lastInsertID, _ := res.LastInsertId()
|
||||
|
||||
result := &models.QueryResult{
|
||||
Success: true,
|
||||
AffectedRows: rowsAffected,
|
||||
Duration: time.Since(startTime).Milliseconds(),
|
||||
Rows: [][]interface{}{{lastInsertID}},
|
||||
Columns: []string{"last_insert_id"},
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetTables returns all tables in the database
|
||||
func (s *SQLiteConnection) GetTables(schema string) ([]models.Table, error) {
|
||||
db := s.GetDB()
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("connection is closed")
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT name, type
|
||||
FROM sqlite_master
|
||||
WHERE type IN ('table', 'view') AND name NOT LIKE 'sqlite_%'
|
||||
ORDER BY name
|
||||
`
|
||||
|
||||
rows, err := db.Query(query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query tables: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tables []models.Table
|
||||
for rows.Next() {
|
||||
var t models.Table
|
||||
if err := rows.Scan(&t.Name, &t.Type); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan table: %w", err)
|
||||
}
|
||||
|
||||
// Get row count for tables
|
||||
if t.Type == "table" {
|
||||
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM \"%s\"", t.Name)
|
||||
var rowCount int64
|
||||
if err := db.QueryRow(countQuery).Scan(&rowCount); err == nil {
|
||||
t.RowCount = rowCount
|
||||
}
|
||||
}
|
||||
|
||||
tables = append(tables, t)
|
||||
}
|
||||
|
||||
return tables, nil
|
||||
}
|
||||
|
||||
// GetTableStructure returns the structure of a table
|
||||
func (s *SQLiteConnection) GetTableStructure(tableName string) (*models.TableStructure, error) {
|
||||
db := s.GetDB()
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("connection is closed")
|
||||
}
|
||||
|
||||
structure := &models.TableStructure{
|
||||
TableName: tableName,
|
||||
}
|
||||
|
||||
// Get table info using PRAGMA
|
||||
query := fmt.Sprintf("PRAGMA table_info(\"%s\")", tableName)
|
||||
rows, err := db.Query(query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query table info: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var col models.TableColumn
|
||||
var pk int
|
||||
|
||||
if err := rows.Scan(&col.Name, &col.DataType, &col.Default, &col.Nullable, &pk, &col.IsPrimary); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan column: %w", err)
|
||||
}
|
||||
|
||||
col.IsPrimary = pk > 0
|
||||
col.Nullable = col.Nullable || !col.IsPrimary
|
||||
|
||||
structure.Columns = append(structure.Columns, col)
|
||||
}
|
||||
|
||||
// Get indexes
|
||||
idxQuery := fmt.Sprintf("PRAGMA index_list(\"%s\")", tableName)
|
||||
idxRows, err := db.Query(idxQuery)
|
||||
if err != nil {
|
||||
config.GetLogger().Warn("failed to query indexes", zap.Error(err))
|
||||
} else {
|
||||
defer idxRows.Close()
|
||||
|
||||
for idxRows.Next() {
|
||||
var idx models.TableIndex
|
||||
var origin string
|
||||
|
||||
if err := idxRows.Scan(&idx.Name, &idx.IsUnique, &origin); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
idx.IsPrimary = idx.Name == "sqlite_autoindex_"+tableName+"_1"
|
||||
idx.Type = origin
|
||||
|
||||
// Get index columns
|
||||
colQuery := fmt.Sprintf("PRAGMA index_info(\"%s\")", idx.Name)
|
||||
colRows, err := db.Query(colQuery)
|
||||
if err == nil {
|
||||
defer colRows.Close()
|
||||
for colRows.Next() {
|
||||
var seqno int
|
||||
var colName string
|
||||
if err := colRows.Scan(&seqno, &colName, &colName); err != nil {
|
||||
continue
|
||||
}
|
||||
idx.Columns = append(idx.Columns, colName)
|
||||
}
|
||||
}
|
||||
|
||||
structure.Indexes = append(structure.Indexes, idx)
|
||||
}
|
||||
}
|
||||
|
||||
// Get foreign keys
|
||||
fkQuery := fmt.Sprintf("PRAGMA foreign_key_list(\"%s\")", tableName)
|
||||
fkRows, err := db.Query(fkQuery)
|
||||
if err != nil {
|
||||
config.GetLogger().Warn("failed to query foreign keys", zap.Error(err))
|
||||
} else {
|
||||
defer fkRows.Close()
|
||||
|
||||
fkMap := make(map[string]*models.ForeignKey)
|
||||
for fkRows.Next() {
|
||||
var fk models.ForeignKey
|
||||
var id, seq int
|
||||
|
||||
if err := fkRows.Scan(&id, &seq, &fk.Name, &fk.ReferencedTable,
|
||||
&fk.Columns, &fk.ReferencedColumns, &fk.OnUpdate, &fk.OnDelete, ""); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle array fields properly
|
||||
fk.Columns = []string{}
|
||||
fk.ReferencedColumns = []string{}
|
||||
|
||||
var fromCol, toCol string
|
||||
if err := fkRows.Scan(&id, &seq, &fk.Name, &fk.ReferencedTable,
|
||||
&fromCol, &toCol, &fk.OnUpdate, &fk.OnDelete, ""); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
existingFk, exists := fkMap[fk.Name]
|
||||
if !exists {
|
||||
existingFk = &models.ForeignKey{
|
||||
Name: fk.Name,
|
||||
ReferencedTable: fk.ReferencedTable,
|
||||
OnDelete: fk.OnDelete,
|
||||
OnUpdate: fk.OnUpdate,
|
||||
Columns: []string{},
|
||||
ReferencedColumns: []string{},
|
||||
}
|
||||
fkMap[fk.Name] = existingFk
|
||||
}
|
||||
existingFk.Columns = append(existingFk.Columns, fromCol)
|
||||
existingFk.ReferencedColumns = append(existingFk.ReferencedColumns, toCol)
|
||||
}
|
||||
|
||||
for _, fk := range fkMap {
|
||||
structure.ForeignKeys = append(structure.ForeignKeys, *fk)
|
||||
}
|
||||
}
|
||||
|
||||
return structure, nil
|
||||
}
|
||||
|
||||
// GetMetadata returns database metadata
|
||||
func (s *SQLiteConnection) GetMetadata() (*models.DBMetadata, error) {
|
||||
db := s.GetDB()
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("connection is closed")
|
||||
}
|
||||
|
||||
metadata := &models.DBMetadata{
|
||||
Database: s.filePath,
|
||||
Version: "3", // SQLite version 3
|
||||
}
|
||||
|
||||
// Get SQLite version
|
||||
var version string
|
||||
err := db.QueryRow("SELECT sqlite_version()").Scan(&version)
|
||||
if err == nil {
|
||||
metadata.Version = version
|
||||
}
|
||||
|
||||
// Get current time
|
||||
metadata.ServerTime = time.Now().Format(time.RFC3339)
|
||||
|
||||
return metadata, nil
|
||||
}
|
||||
195
internal/handler/connection.go
Normal file
195
internal/handler/connection.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"uzdb/internal/config"
|
||||
"uzdb/internal/models"
|
||||
"uzdb/internal/services"
|
||||
"uzdb/internal/utils"
|
||||
)
|
||||
|
||||
// ConnectionHandler handles connection-related HTTP requests
|
||||
type ConnectionHandler struct {
|
||||
connectionSvc *services.ConnectionService
|
||||
}
|
||||
|
||||
// NewConnectionHandler creates a new connection handler
|
||||
func NewConnectionHandler(connectionSvc *services.ConnectionService) *ConnectionHandler {
|
||||
return &ConnectionHandler{
|
||||
connectionSvc: connectionSvc,
|
||||
}
|
||||
}
|
||||
|
||||
// GetAllConnections handles GET /api/connections
|
||||
func (h *ConnectionHandler) GetAllConnections(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
connections, err := h.connectionSvc.GetAllConnections(ctx)
|
||||
if err != nil {
|
||||
utils.ErrorResponse(c, http.StatusInternalServerError, err, "Failed to get connections")
|
||||
return
|
||||
}
|
||||
|
||||
utils.SuccessResponse(c, gin.H{
|
||||
"connections": connections,
|
||||
})
|
||||
}
|
||||
|
||||
// GetConnection handles GET /api/connections/:id
|
||||
func (h *ConnectionHandler) GetConnection(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
id := c.Param("id")
|
||||
|
||||
if id == "" {
|
||||
utils.ErrorResponse(c, http.StatusBadRequest, models.ErrValidationFailed, "Connection ID is required")
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := h.connectionSvc.GetConnectionByID(ctx, id)
|
||||
if err != nil {
|
||||
if err == models.ErrNotFound {
|
||||
utils.ErrorResponse(c, http.StatusNotFound, err, "Connection not found")
|
||||
return
|
||||
}
|
||||
utils.ErrorResponse(c, http.StatusInternalServerError, err, "Failed to get connection")
|
||||
return
|
||||
}
|
||||
|
||||
utils.SuccessResponse(c, gin.H{
|
||||
"connection": conn,
|
||||
})
|
||||
}
|
||||
|
||||
// CreateConnection handles POST /api/connections
|
||||
func (h *ConnectionHandler) CreateConnection(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
var req models.CreateConnectionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
utils.ErrorResponse(c, http.StatusBadRequest, models.ErrValidationFailed, "Invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := h.connectionSvc.CreateConnection(ctx, &req)
|
||||
if err != nil {
|
||||
if err == models.ErrValidationFailed {
|
||||
utils.ErrorResponse(c, http.StatusBadRequest, err, "Validation failed")
|
||||
return
|
||||
}
|
||||
utils.ErrorResponse(c, http.StatusInternalServerError, err, "Failed to create connection")
|
||||
return
|
||||
}
|
||||
|
||||
config.GetLogger().Info("connection created via API",
|
||||
zap.String("id", conn.ID),
|
||||
zap.String("name", conn.Name))
|
||||
|
||||
utils.CreatedResponse(c, gin.H{
|
||||
"connection": conn,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateConnection handles PUT /api/connections/:id
|
||||
func (h *ConnectionHandler) UpdateConnection(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
id := c.Param("id")
|
||||
|
||||
if id == "" {
|
||||
utils.ErrorResponse(c, http.StatusBadRequest, models.ErrValidationFailed, "Connection ID is required")
|
||||
return
|
||||
}
|
||||
|
||||
var req models.UpdateConnectionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
utils.ErrorResponse(c, http.StatusBadRequest, models.ErrValidationFailed, "Invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := h.connectionSvc.UpdateConnection(ctx, id, &req)
|
||||
if err != nil {
|
||||
if err == models.ErrNotFound {
|
||||
utils.ErrorResponse(c, http.StatusNotFound, err, "Connection not found")
|
||||
return
|
||||
}
|
||||
utils.ErrorResponse(c, http.StatusInternalServerError, err, "Failed to update connection")
|
||||
return
|
||||
}
|
||||
|
||||
utils.SuccessResponse(c, gin.H{
|
||||
"connection": conn,
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteConnection handles DELETE /api/connections/:id
|
||||
func (h *ConnectionHandler) DeleteConnection(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
id := c.Param("id")
|
||||
|
||||
if id == "" {
|
||||
utils.ErrorResponse(c, http.StatusBadRequest, models.ErrValidationFailed, "Connection ID is required")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.connectionSvc.DeleteConnection(ctx, id); err != nil {
|
||||
if err == models.ErrNotFound {
|
||||
utils.ErrorResponse(c, http.StatusNotFound, err, "Connection not found")
|
||||
return
|
||||
}
|
||||
utils.ErrorResponse(c, http.StatusInternalServerError, err, "Failed to delete connection")
|
||||
return
|
||||
}
|
||||
|
||||
utils.SuccessResponse(c, gin.H{
|
||||
"message": "Connection deleted successfully",
|
||||
})
|
||||
}
|
||||
|
||||
// TestConnection handles POST /api/connections/:id/test
|
||||
func (h *ConnectionHandler) TestConnection(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
id := c.Param("id")
|
||||
|
||||
if id == "" {
|
||||
utils.ErrorResponse(c, http.StatusBadRequest, models.ErrValidationFailed, "Connection ID is required")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.connectionSvc.TestConnection(ctx, id)
|
||||
if err != nil {
|
||||
if err == models.ErrNotFound {
|
||||
utils.ErrorResponse(c, http.StatusNotFound, err, "Connection not found")
|
||||
return
|
||||
}
|
||||
utils.ErrorResponse(c, http.StatusInternalServerError, err, "Failed to test connection")
|
||||
return
|
||||
}
|
||||
|
||||
statusCode := http.StatusOK
|
||||
if !result.Success {
|
||||
statusCode = http.StatusBadGateway
|
||||
}
|
||||
|
||||
c.JSON(statusCode, gin.H{
|
||||
"success": result.Success,
|
||||
"message": result.Message,
|
||||
"duration_ms": result.Duration,
|
||||
"metadata": result.Metadata,
|
||||
})
|
||||
}
|
||||
|
||||
// RegisterRoutes registers connection routes
|
||||
func (h *ConnectionHandler) RegisterRoutes(router *gin.RouterGroup) {
|
||||
connections := router.Group("/connections")
|
||||
{
|
||||
connections.GET("", h.GetAllConnections)
|
||||
connections.GET("/:id", h.GetConnection)
|
||||
connections.POST("", h.CreateConnection)
|
||||
connections.PUT("/:id", h.UpdateConnection)
|
||||
connections.DELETE("/:id", h.DeleteConnection)
|
||||
connections.POST("/:id/test", h.TestConnection)
|
||||
}
|
||||
}
|
||||
287
internal/handler/query.go
Normal file
287
internal/handler/query.go
Normal file
@@ -0,0 +1,287 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"uzdb/internal/models"
|
||||
"uzdb/internal/services"
|
||||
"uzdb/internal/utils"
|
||||
)
|
||||
|
||||
// QueryHandler handles query-related HTTP requests
|
||||
type QueryHandler struct {
|
||||
connectionSvc *services.ConnectionService
|
||||
querySvc *services.QueryService
|
||||
}
|
||||
|
||||
// NewQueryHandler creates a new query handler
|
||||
func NewQueryHandler(
|
||||
connectionSvc *services.ConnectionService,
|
||||
querySvc *services.QueryService,
|
||||
) *QueryHandler {
|
||||
return &QueryHandler{
|
||||
connectionSvc: connectionSvc,
|
||||
querySvc: querySvc,
|
||||
}
|
||||
}
|
||||
|
||||
// ExecuteQuery handles POST /api/query
|
||||
func (h *QueryHandler) ExecuteQuery(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
var req struct {
|
||||
ConnectionID string `json:"connection_id" binding:"required"`
|
||||
SQL string `json:"sql" binding:"required"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
utils.ErrorResponse(c, http.StatusBadRequest, models.ErrValidationFailed, "Invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.connectionSvc.ExecuteQuery(ctx, req.ConnectionID, req.SQL)
|
||||
if err != nil {
|
||||
utils.ErrorResponse(c, http.StatusInternalServerError, err, "Query execution failed")
|
||||
return
|
||||
}
|
||||
|
||||
utils.SuccessResponse(c, gin.H{
|
||||
"result": result,
|
||||
})
|
||||
}
|
||||
|
||||
// GetTables handles GET /api/connections/:id/tables
|
||||
func (h *QueryHandler) GetTables(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
connectionID := c.Param("id")
|
||||
|
||||
if connectionID == "" {
|
||||
utils.ErrorResponse(c, http.StatusBadRequest, models.ErrValidationFailed, "Connection ID is required")
|
||||
return
|
||||
}
|
||||
|
||||
tables, err := h.connectionSvc.GetTables(ctx, connectionID)
|
||||
if err != nil {
|
||||
if err == models.ErrNotFound {
|
||||
utils.ErrorResponse(c, http.StatusNotFound, err, "Connection not found")
|
||||
return
|
||||
}
|
||||
utils.ErrorResponse(c, http.StatusInternalServerError, err, "Failed to get tables")
|
||||
return
|
||||
}
|
||||
|
||||
utils.SuccessResponse(c, gin.H{
|
||||
"tables": tables,
|
||||
})
|
||||
}
|
||||
|
||||
// GetTableData handles GET /api/connections/:id/tables/:name/data
|
||||
func (h *QueryHandler) GetTableData(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
connectionID := c.Param("id")
|
||||
tableName := c.Param("name")
|
||||
|
||||
if connectionID == "" || tableName == "" {
|
||||
utils.ErrorResponse(c, http.StatusBadRequest, models.ErrValidationFailed, "Connection ID and table name are required")
|
||||
return
|
||||
}
|
||||
|
||||
// Parse limit and offset
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100"))
|
||||
offset, _ := strconv.Atoi(c.DefaultQuery("offset", "0"))
|
||||
|
||||
result, err := h.connectionSvc.GetTableData(ctx, connectionID, tableName, limit, offset)
|
||||
if err != nil {
|
||||
if err == models.ErrNotFound {
|
||||
utils.ErrorResponse(c, http.StatusNotFound, err, "Connection not found")
|
||||
return
|
||||
}
|
||||
utils.ErrorResponse(c, http.StatusInternalServerError, err, "Failed to get table data")
|
||||
return
|
||||
}
|
||||
|
||||
utils.SuccessResponse(c, gin.H{
|
||||
"result": result,
|
||||
"table": tableName,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
})
|
||||
}
|
||||
|
||||
// GetTableStructure handles GET /api/connections/:id/tables/:name/structure
|
||||
func (h *QueryHandler) GetTableStructure(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
connectionID := c.Param("id")
|
||||
tableName := c.Param("name")
|
||||
|
||||
if connectionID == "" || tableName == "" {
|
||||
utils.ErrorResponse(c, http.StatusBadRequest, models.ErrValidationFailed, "Connection ID and table name are required")
|
||||
return
|
||||
}
|
||||
|
||||
structure, err := h.connectionSvc.GetTableStructure(ctx, connectionID, tableName)
|
||||
if err != nil {
|
||||
if err == models.ErrNotFound {
|
||||
utils.ErrorResponse(c, http.StatusNotFound, err, "Connection not found")
|
||||
return
|
||||
}
|
||||
utils.ErrorResponse(c, http.StatusInternalServerError, err, "Failed to get table structure")
|
||||
return
|
||||
}
|
||||
|
||||
utils.SuccessResponse(c, gin.H{
|
||||
"structure": structure,
|
||||
})
|
||||
}
|
||||
|
||||
// GetQueryHistory handles GET /api/history
|
||||
func (h *QueryHandler) GetQueryHistory(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
connectionID := c.Query("connection_id")
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
|
||||
history, total, err := h.querySvc.GetQueryHistory(ctx, connectionID, page, pageSize)
|
||||
if err != nil {
|
||||
utils.ErrorResponse(c, http.StatusInternalServerError, err, "Failed to get query history")
|
||||
return
|
||||
}
|
||||
|
||||
totalPages := int(total) / pageSize
|
||||
if int(total)%pageSize > 0 {
|
||||
totalPages++
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": gin.H{
|
||||
"history": history,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": pageSize,
|
||||
"total_pages": totalPages,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// GetSavedQueries handles GET /api/saved-queries
|
||||
func (h *QueryHandler) GetSavedQueries(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
connectionID := c.Query("connection_id")
|
||||
|
||||
queries, err := h.querySvc.GetSavedQueries(ctx, connectionID)
|
||||
if err != nil {
|
||||
utils.ErrorResponse(c, http.StatusInternalServerError, err, "Failed to get saved queries")
|
||||
return
|
||||
}
|
||||
|
||||
utils.SuccessResponse(c, gin.H{
|
||||
"queries": queries,
|
||||
})
|
||||
}
|
||||
|
||||
// CreateSavedQuery handles POST /api/saved-queries
|
||||
func (h *QueryHandler) CreateSavedQuery(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
var req models.CreateSavedQueryRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
utils.ErrorResponse(c, http.StatusBadRequest, models.ErrValidationFailed, "Invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
query, err := h.querySvc.CreateSavedQuery(ctx, &req)
|
||||
if err != nil {
|
||||
utils.ErrorResponse(c, http.StatusInternalServerError, err, "Failed to save query")
|
||||
return
|
||||
}
|
||||
|
||||
utils.CreatedResponse(c, gin.H{
|
||||
"query": query,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateSavedQuery handles PUT /api/saved-queries/:id
|
||||
func (h *QueryHandler) UpdateSavedQuery(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
idStr := c.Param("id")
|
||||
|
||||
id, err := strconv.ParseUint(idStr, 10, 64)
|
||||
if err != nil {
|
||||
utils.ErrorResponse(c, http.StatusBadRequest, models.ErrValidationFailed, "Invalid query ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req models.UpdateSavedQueryRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
utils.ErrorResponse(c, http.StatusBadRequest, models.ErrValidationFailed, "Invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
query, err := h.querySvc.UpdateSavedQuery(ctx, uint(id), &req)
|
||||
if err != nil {
|
||||
if err == models.ErrNotFound {
|
||||
utils.ErrorResponse(c, http.StatusNotFound, err, "Saved query not found")
|
||||
return
|
||||
}
|
||||
utils.ErrorResponse(c, http.StatusInternalServerError, err, "Failed to update saved query")
|
||||
return
|
||||
}
|
||||
|
||||
utils.SuccessResponse(c, gin.H{
|
||||
"query": query,
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteSavedQuery handles DELETE /api/saved-queries/:id
|
||||
func (h *QueryHandler) DeleteSavedQuery(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
idStr := c.Param("id")
|
||||
|
||||
id, err := strconv.ParseUint(idStr, 10, 64)
|
||||
if err != nil {
|
||||
utils.ErrorResponse(c, http.StatusBadRequest, models.ErrValidationFailed, "Invalid query ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.querySvc.DeleteSavedQuery(ctx, uint(id)); err != nil {
|
||||
if err == models.ErrNotFound {
|
||||
utils.ErrorResponse(c, http.StatusNotFound, err, "Saved query not found")
|
||||
return
|
||||
}
|
||||
utils.ErrorResponse(c, http.StatusInternalServerError, err, "Failed to delete saved query")
|
||||
return
|
||||
}
|
||||
|
||||
utils.SuccessResponse(c, gin.H{
|
||||
"message": "Saved query deleted successfully",
|
||||
})
|
||||
}
|
||||
|
||||
// RegisterRoutes registers query routes
|
||||
func (h *QueryHandler) RegisterRoutes(router *gin.RouterGroup) {
|
||||
// Query execution
|
||||
router.POST("/query", h.ExecuteQuery)
|
||||
|
||||
// Table operations
|
||||
router.GET("/connections/:id/tables", h.GetTables)
|
||||
router.GET("/connections/:id/tables/:name/data", h.GetTableData)
|
||||
router.GET("/connections/:id/tables/:name/structure", h.GetTableStructure)
|
||||
|
||||
// Query history
|
||||
router.GET("/history", h.GetQueryHistory)
|
||||
|
||||
// Saved queries
|
||||
savedQueries := router.Group("/saved-queries")
|
||||
{
|
||||
savedQueries.GET("", h.GetSavedQueries)
|
||||
savedQueries.POST("", h.CreateSavedQuery)
|
||||
savedQueries.PUT("/:id", h.UpdateSavedQuery)
|
||||
savedQueries.DELETE("/:id", h.DeleteSavedQuery)
|
||||
}
|
||||
}
|
||||
95
internal/handler/server.go
Normal file
95
internal/handler/server.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"uzdb/internal/config"
|
||||
"uzdb/internal/middleware"
|
||||
"uzdb/internal/services"
|
||||
)
|
||||
|
||||
// HTTPServer represents the HTTP API server
|
||||
type HTTPServer struct {
|
||||
engine *gin.Engine
|
||||
server *http.Server
|
||||
}
|
||||
|
||||
// NewHTTPServer creates a new HTTP server
|
||||
func NewHTTPServer(
|
||||
connectionSvc *services.ConnectionService,
|
||||
querySvc *services.QueryService,
|
||||
) *HTTPServer {
|
||||
// Set Gin mode based on environment
|
||||
cfg := config.Get()
|
||||
if cfg.IsProduction() {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
}
|
||||
|
||||
engine := gin.New()
|
||||
|
||||
// Create handlers
|
||||
connectionHandler := NewConnectionHandler(connectionSvc)
|
||||
queryHandler := NewQueryHandler(connectionSvc, querySvc)
|
||||
|
||||
// Setup middleware
|
||||
engine.Use(middleware.RecoveryMiddleware())
|
||||
engine.Use(middleware.LoggerMiddleware())
|
||||
engine.Use(middleware.CORSMiddleware())
|
||||
engine.Use(middleware.SecureHeadersMiddleware())
|
||||
|
||||
// Setup routes
|
||||
api := engine.Group("/api")
|
||||
{
|
||||
connectionHandler.RegisterRoutes(api)
|
||||
queryHandler.RegisterRoutes(api)
|
||||
}
|
||||
|
||||
// Health check endpoint
|
||||
engine.GET("/health", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": "ok",
|
||||
"timestamp": time.Now().Format(time.RFC3339),
|
||||
})
|
||||
})
|
||||
|
||||
return &HTTPServer{
|
||||
engine: engine,
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the HTTP server
|
||||
func (s *HTTPServer) Start(port string) error {
|
||||
s.server = &http.Server{
|
||||
Addr: ":" + port,
|
||||
Handler: s.engine,
|
||||
ReadTimeout: 15 * time.Second,
|
||||
WriteTimeout: 15 * time.Second,
|
||||
IdleTimeout: 60 * time.Second,
|
||||
}
|
||||
|
||||
config.GetLogger().Info("starting HTTP API server",
|
||||
zap.String("port", port))
|
||||
|
||||
if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
return fmt.Errorf("failed to start HTTP server: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the HTTP server
|
||||
func (s *HTTPServer) Shutdown(ctx context.Context) error {
|
||||
if s.server == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
config.GetLogger().Info("shutting down HTTP server")
|
||||
|
||||
return s.server.Shutdown(ctx)
|
||||
}
|
||||
101
internal/middleware/cors.go
Normal file
101
internal/middleware/cors.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"uzdb/internal/config"
|
||||
)
|
||||
|
||||
// CORSMiddleware returns a CORS middleware
|
||||
func CORSMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
cfg := config.Get()
|
||||
|
||||
// Set CORS headers
|
||||
c.Header("Access-Control-Allow-Origin", getAllowedOrigin(c, cfg))
|
||||
c.Header("Access-Control-Allow-Credentials", "true")
|
||||
c.Header("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-Request-ID")
|
||||
c.Header("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, PATCH, DELETE")
|
||||
c.Header("Access-Control-Expose-Headers", "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers, Content-Type, X-Request-ID")
|
||||
c.Header("Access-Control-Max-Age", "43200") // 12 hours
|
||||
|
||||
// Handle preflight requests
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.AbortWithStatus(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// getAllowedOrigin returns the allowed origin based on configuration
|
||||
func getAllowedOrigin(c *gin.Context, cfg *config.Config) string {
|
||||
origin := c.GetHeader("Origin")
|
||||
|
||||
// If no origin header, return empty (same-origin)
|
||||
if origin == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// In development mode, allow all origins
|
||||
if cfg.IsDevelopment() {
|
||||
return origin
|
||||
}
|
||||
|
||||
// In production, validate against allowed origins
|
||||
allowedOrigins := []string{
|
||||
"http://localhost:3000",
|
||||
"http://localhost:8080",
|
||||
"http://127.0.0.1:3000",
|
||||
"http://127.0.0.1:8080",
|
||||
}
|
||||
|
||||
for _, allowed := range allowedOrigins {
|
||||
if origin == allowed {
|
||||
return origin
|
||||
}
|
||||
}
|
||||
|
||||
// Check for wildcard patterns
|
||||
for _, allowed := range allowedOrigins {
|
||||
if strings.HasSuffix(allowed, "*") {
|
||||
prefix := strings.TrimSuffix(allowed, "*")
|
||||
if strings.HasPrefix(origin, prefix) {
|
||||
return origin
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Default to empty (deny) in production if not matched
|
||||
if cfg.IsProduction() {
|
||||
return ""
|
||||
}
|
||||
|
||||
return origin
|
||||
}
|
||||
|
||||
// SecureHeadersMiddleware adds security-related headers
|
||||
func SecureHeadersMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Prevent MIME type sniffing
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
|
||||
// Enable XSS filter
|
||||
c.Header("X-XSS-Protection", "1; mode=block")
|
||||
|
||||
// Prevent clickjacking
|
||||
c.Header("X-Frame-Options", "DENY")
|
||||
|
||||
// Referrer policy
|
||||
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
|
||||
// Content Security Policy (adjust as needed)
|
||||
c.Header("Content-Security-Policy", "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'")
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
125
internal/middleware/error.go
Normal file
125
internal/middleware/error.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"uzdb/internal/config"
|
||||
"uzdb/internal/models"
|
||||
)
|
||||
|
||||
// RecoveryMiddleware returns a recovery middleware that handles panics
|
||||
func RecoveryMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
logger := config.GetLogger()
|
||||
|
||||
// Log the panic with stack trace
|
||||
logger.Error("panic recovered",
|
||||
zap.Any("error", err),
|
||||
zap.String("stack", string(debug.Stack())),
|
||||
zap.String("path", c.Request.URL.Path),
|
||||
zap.String("method", c.Request.Method),
|
||||
)
|
||||
|
||||
// Send error response
|
||||
c.JSON(http.StatusInternalServerError, models.ErrorResponse{
|
||||
Error: "INTERNAL_ERROR",
|
||||
Message: "An unexpected error occurred",
|
||||
Timestamp: time.Now(),
|
||||
Path: c.Request.URL.Path,
|
||||
})
|
||||
|
||||
c.Abort()
|
||||
}
|
||||
}()
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// ErrorMiddleware returns an error handling middleware
|
||||
func ErrorMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Next()
|
||||
|
||||
// Check if there are any errors
|
||||
if len(c.Errors) > 0 {
|
||||
handleErrors(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleErrors processes and formats errors
|
||||
func handleErrors(c *gin.Context) {
|
||||
logger := config.GetLogger()
|
||||
|
||||
for _, e := range c.Errors {
|
||||
err := e.Err
|
||||
|
||||
// Log the error
|
||||
logger.Error("request error",
|
||||
zap.String("error", err.Error()),
|
||||
zap.Any("type", e.Type),
|
||||
zap.String("path", c.Request.URL.Path),
|
||||
)
|
||||
|
||||
// Determine status code
|
||||
statusCode := http.StatusInternalServerError
|
||||
|
||||
// Map specific errors to status codes
|
||||
switch {
|
||||
case errors.Is(err, models.ErrNotFound):
|
||||
statusCode = http.StatusNotFound
|
||||
case errors.Is(err, models.ErrValidationFailed):
|
||||
statusCode = http.StatusBadRequest
|
||||
case errors.Is(err, models.ErrUnauthorized):
|
||||
statusCode = http.StatusUnauthorized
|
||||
case errors.Is(err, models.ErrForbidden):
|
||||
statusCode = http.StatusForbidden
|
||||
case errors.Is(err, models.ErrConnectionFailed):
|
||||
statusCode = http.StatusBadGateway
|
||||
}
|
||||
|
||||
// Send response if not already sent
|
||||
if !c.Writer.Written() {
|
||||
c.JSON(statusCode, models.ErrorResponse{
|
||||
Error: getErrorCode(err),
|
||||
Message: err.Error(),
|
||||
Timestamp: time.Now(),
|
||||
Path: c.Request.URL.Path,
|
||||
})
|
||||
}
|
||||
|
||||
// Only handle first error
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// getErrorCode maps errors to error codes
|
||||
func getErrorCode(err error) string {
|
||||
switch {
|
||||
case errors.Is(err, models.ErrNotFound):
|
||||
return string(models.CodeNotFound)
|
||||
case errors.Is(err, models.ErrValidationFailed):
|
||||
return string(models.CodeValidation)
|
||||
case errors.Is(err, models.ErrUnauthorized):
|
||||
return string(models.CodeUnauthorized)
|
||||
case errors.Is(err, models.ErrForbidden):
|
||||
return string(models.CodeForbidden)
|
||||
case errors.Is(err, models.ErrConnectionFailed):
|
||||
return string(models.CodeConnection)
|
||||
case errors.Is(err, models.ErrQueryFailed):
|
||||
return string(models.CodeQuery)
|
||||
case errors.Is(err, models.ErrEncryptionFailed):
|
||||
return string(models.CodeEncryption)
|
||||
default:
|
||||
return string(models.CodeInternal)
|
||||
}
|
||||
}
|
||||
98
internal/middleware/logger.go
Normal file
98
internal/middleware/logger.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
|
||||
"uzdb/internal/config"
|
||||
)
|
||||
|
||||
// LoggerMiddleware returns a logging middleware using Zap
|
||||
func LoggerMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
logger := config.GetLogger()
|
||||
|
||||
// Start timer
|
||||
start := time.Now()
|
||||
|
||||
// Process request
|
||||
c.Next()
|
||||
|
||||
// Calculate duration
|
||||
duration := time.Since(start)
|
||||
|
||||
// Get client IP
|
||||
clientIP := c.ClientIP()
|
||||
|
||||
// Get method
|
||||
method := c.Request.Method
|
||||
|
||||
// Get path
|
||||
path := c.Request.URL.Path
|
||||
|
||||
// Get status code
|
||||
statusCode := c.Writer.Status()
|
||||
|
||||
// Get body size
|
||||
bodySize := c.Writer.Size()
|
||||
|
||||
// Determine log level based on status code
|
||||
var level zapcore.Level
|
||||
switch {
|
||||
case statusCode >= 500:
|
||||
level = zapcore.ErrorLevel
|
||||
case statusCode >= 400:
|
||||
level = zapcore.WarnLevel
|
||||
default:
|
||||
level = zapcore.InfoLevel
|
||||
}
|
||||
|
||||
// Create fields
|
||||
fields := []zap.Field{
|
||||
zap.Int("status", statusCode),
|
||||
zap.String("method", method),
|
||||
zap.String("path", path),
|
||||
zap.String("ip", clientIP),
|
||||
zap.Duration("duration", duration),
|
||||
zap.Int("body_size", bodySize),
|
||||
}
|
||||
|
||||
// Add error message if any
|
||||
if len(c.Errors) > 0 {
|
||||
fields = append(fields, zap.String("error", c.Errors.String()))
|
||||
}
|
||||
|
||||
// Log the request
|
||||
logger.Log(level, "HTTP request", fields...)
|
||||
}
|
||||
}
|
||||
|
||||
// RequestIDMiddleware adds a request ID to each request
|
||||
func RequestIDMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Try to get request ID from header
|
||||
requestID := c.GetHeader("X-Request-ID")
|
||||
|
||||
// Generate if not present
|
||||
if requestID == "" {
|
||||
requestID = generateRequestID()
|
||||
}
|
||||
|
||||
// Set request ID in context
|
||||
c.Set("request_id", requestID)
|
||||
|
||||
// Add to response header
|
||||
c.Header("X-Request-ID", requestID)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// generateRequestID generates a simple request ID
|
||||
func generateRequestID() string {
|
||||
return time.Now().Format("20060102150405") + "-" +
|
||||
time.Now().Format("000000.000000")[7:]
|
||||
}
|
||||
85
internal/models/connection.go
Normal file
85
internal/models/connection.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// ConnectionType represents the type of database connection
|
||||
type ConnectionType string
|
||||
|
||||
const (
|
||||
// ConnectionTypeMySQL represents MySQL database
|
||||
ConnectionTypeMySQL ConnectionType = "mysql"
|
||||
// ConnectionTypePostgreSQL represents PostgreSQL database
|
||||
ConnectionTypePostgreSQL ConnectionType = "postgres"
|
||||
// ConnectionTypeSQLite represents SQLite database
|
||||
ConnectionTypeSQLite ConnectionType = "sqlite"
|
||||
)
|
||||
|
||||
// UserConnection represents a user's database connection configuration
|
||||
// This model is stored in the local SQLite database
|
||||
type UserConnection struct {
|
||||
ID string `gorm:"type:varchar(36);primaryKey" json:"id"`
|
||||
Name string `gorm:"type:varchar(100);not null" json:"name"`
|
||||
Type ConnectionType `gorm:"type:varchar(20);not null" json:"type"`
|
||||
Host string `gorm:"type:varchar(255)" json:"host,omitempty"`
|
||||
Port int `gorm:"type:integer" json:"port,omitempty"`
|
||||
Username string `gorm:"type:varchar(100)" json:"username,omitempty"`
|
||||
Password string `gorm:"type:text" json:"password"` // Encrypted password
|
||||
Database string `gorm:"type:varchar(255)" json:"database"`
|
||||
SSLMode string `gorm:"type:varchar(50)" json:"ssl_mode,omitempty"`
|
||||
Timeout int `gorm:"type:integer;default:30" json:"timeout"`
|
||||
CreatedAt time.Time `gorm:"autoCreateTime" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"autoUpdateTime" json:"updated_at"`
|
||||
|
||||
// For SQLite connections, Database field contains the file path
|
||||
}
|
||||
|
||||
// TableName returns the table name for UserConnection
|
||||
func (UserConnection) TableName() string {
|
||||
return "user_connections"
|
||||
}
|
||||
|
||||
// CreateConnectionRequest represents a request to create a new connection
|
||||
type CreateConnectionRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Type ConnectionType `json:"type" binding:"required"`
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
Database string `json:"database" binding:"required"`
|
||||
SSLMode string `json:"ssl_mode"`
|
||||
Timeout int `json:"timeout"`
|
||||
}
|
||||
|
||||
// UpdateConnectionRequest represents a request to update an existing connection
|
||||
type UpdateConnectionRequest struct {
|
||||
Name string `json:"name"`
|
||||
Type ConnectionType `json:"type"`
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
Database string `json:"database"`
|
||||
SSLMode string `json:"ssl_mode"`
|
||||
Timeout int `json:"timeout"`
|
||||
}
|
||||
|
||||
// Validate validates the connection request
|
||||
func (r *CreateConnectionRequest) Validate() error {
|
||||
switch r.Type {
|
||||
case ConnectionTypeMySQL, ConnectionTypePostgreSQL:
|
||||
if r.Host == "" {
|
||||
return ErrValidationFailed
|
||||
}
|
||||
if r.Port <= 0 || r.Port > 65535 {
|
||||
return ErrValidationFailed
|
||||
}
|
||||
case ConnectionTypeSQLite:
|
||||
if r.Database == "" {
|
||||
return ErrValidationFailed
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
74
internal/models/errors.go
Normal file
74
internal/models/errors.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package models
|
||||
|
||||
import "errors"
|
||||
|
||||
// Application errors
|
||||
var (
|
||||
// ErrNotFound resource not found
|
||||
ErrNotFound = errors.New("resource not found")
|
||||
|
||||
// ErrAlreadyExists resource already exists
|
||||
ErrAlreadyExists = errors.New("resource already exists")
|
||||
|
||||
// ErrValidationFailed validation failed
|
||||
ErrValidationFailed = errors.New("validation failed")
|
||||
|
||||
// ErrUnauthorized unauthorized access
|
||||
ErrUnauthorized = errors.New("unauthorized access")
|
||||
|
||||
// ErrForbidden forbidden access
|
||||
ErrForbidden = errors.New("forbidden access")
|
||||
|
||||
// ErrInternalServer internal server error
|
||||
ErrInternalServer = errors.New("internal server error")
|
||||
|
||||
// ErrConnectionFailed connection failed
|
||||
ErrConnectionFailed = errors.New("connection failed")
|
||||
|
||||
// ErrQueryFailed query execution failed
|
||||
ErrQueryFailed = errors.New("query execution failed")
|
||||
|
||||
// ErrEncryptionFailed encryption/decryption failed
|
||||
ErrEncryptionFailed = errors.New("encryption/decryption failed")
|
||||
|
||||
// ErrInvalidConfig invalid configuration
|
||||
ErrInvalidConfig = errors.New("invalid configuration")
|
||||
|
||||
// ErrDatabaseLocked database is locked
|
||||
ErrDatabaseLocked = errors.New("database is locked")
|
||||
|
||||
// ErrTimeout operation timeout
|
||||
ErrTimeout = errors.New("operation timeout")
|
||||
)
|
||||
|
||||
// ErrorCode represents error codes for API responses
|
||||
type ErrorCode string
|
||||
|
||||
const (
|
||||
// CodeSuccess successful operation
|
||||
CodeSuccess ErrorCode = "SUCCESS"
|
||||
|
||||
// CodeNotFound resource not found
|
||||
CodeNotFound ErrorCode = "NOT_FOUND"
|
||||
|
||||
// CodeValidation validation error
|
||||
CodeValidation ErrorCode = "VALIDATION_ERROR"
|
||||
|
||||
// CodeUnauthorized unauthorized
|
||||
CodeUnauthorized ErrorCode = "UNAUTHORIZED"
|
||||
|
||||
// CodeForbidden forbidden
|
||||
CodeForbidden ErrorCode = "FORBIDDEN"
|
||||
|
||||
// CodeInternal internal error
|
||||
CodeInternal ErrorCode = "INTERNAL_ERROR"
|
||||
|
||||
// CodeConnection connection error
|
||||
CodeConnection ErrorCode = "CONNECTION_ERROR"
|
||||
|
||||
// CodeQuery query error
|
||||
CodeQuery ErrorCode = "QUERY_ERROR"
|
||||
|
||||
// CodeEncryption encryption error
|
||||
CodeEncryption ErrorCode = "ENCRYPTION_ERROR"
|
||||
)
|
||||
58
internal/models/query.go
Normal file
58
internal/models/query.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// QueryHistory represents a record of executed queries
|
||||
type QueryHistory struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
ConnectionID string `gorm:"type:varchar(36);not null;index" json:"connection_id"`
|
||||
SQL string `gorm:"type:text;not null" json:"sql"`
|
||||
Duration int64 `gorm:"type:bigint" json:"duration_ms"` // Duration in milliseconds
|
||||
ExecutedAt time.Time `gorm:"autoCreateTime;index" json:"executed_at"`
|
||||
RowsAffected int64 `gorm:"type:bigint" json:"rows_affected"`
|
||||
Error string `gorm:"type:text" json:"error,omitempty"`
|
||||
Success bool `gorm:"type:boolean;default:false" json:"success"`
|
||||
ResultPreview string `gorm:"type:text" json:"result_preview,omitempty"` // JSON preview of first few rows
|
||||
}
|
||||
|
||||
// TableName returns the table name for QueryHistory
|
||||
func (QueryHistory) TableName() string {
|
||||
return "query_history"
|
||||
}
|
||||
|
||||
// SavedQuery represents a saved SQL query
|
||||
type SavedQuery struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
Name string `gorm:"type:varchar(255);not null" json:"name"`
|
||||
Description string `gorm:"type:text" json:"description"`
|
||||
SQL string `gorm:"type:text;not null" json:"sql"`
|
||||
ConnectionID string `gorm:"type:varchar(36);index" json:"connection_id"`
|
||||
Tags string `gorm:"type:text" json:"tags"` // Comma-separated tags
|
||||
CreatedAt time.Time `gorm:"autoCreateTime" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"autoUpdateTime" json:"updated_at"`
|
||||
}
|
||||
|
||||
// TableName returns the table name for SavedQuery
|
||||
func (SavedQuery) TableName() string {
|
||||
return "saved_queries"
|
||||
}
|
||||
|
||||
// CreateSavedQueryRequest represents a request to save a query
|
||||
type CreateSavedQueryRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
SQL string `json:"sql" binding:"required"`
|
||||
ConnectionID string `json:"connection_id"`
|
||||
Tags string `json:"tags"`
|
||||
}
|
||||
|
||||
// UpdateSavedQueryRequest represents a request to update a saved query
|
||||
type UpdateSavedQueryRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
SQL string `json:"sql"`
|
||||
ConnectionID string `json:"connection_id"`
|
||||
Tags string `json:"tags"`
|
||||
}
|
||||
128
internal/models/response.go
Normal file
128
internal/models/response.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// APIResponse represents a standard API response
|
||||
type APIResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// QueryResult represents the result of a SQL query execution
|
||||
type QueryResult struct {
|
||||
Columns []string `json:"columns"`
|
||||
Rows [][]interface{} `json:"rows"`
|
||||
RowCount int64 `json:"row_count"`
|
||||
AffectedRows int64 `json:"affected_rows"`
|
||||
Duration int64 `json:"duration_ms"`
|
||||
Success bool `json:"success"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// Table represents a database table
|
||||
type Table struct {
|
||||
Name string `json:"name"`
|
||||
Schema string `json:"schema,omitempty"`
|
||||
Type string `json:"type"` // table, view, etc.
|
||||
RowCount int64 `json:"row_count,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
}
|
||||
|
||||
// TableStructure represents the structure of a database table
|
||||
type TableStructure struct {
|
||||
TableName string `json:"table_name"`
|
||||
Schema string `json:"schema,omitempty"`
|
||||
Columns []TableColumn `json:"columns"`
|
||||
Indexes []TableIndex `json:"indexes,omitempty"`
|
||||
ForeignKeys []ForeignKey `json:"foreign_keys,omitempty"`
|
||||
}
|
||||
|
||||
// TableColumn represents a column in a database table
|
||||
type TableColumn struct {
|
||||
Name string `json:"name"`
|
||||
DataType string `json:"data_type"`
|
||||
Nullable bool `json:"nullable"`
|
||||
Default string `json:"default,omitempty"`
|
||||
IsPrimary bool `json:"is_primary"`
|
||||
IsUnique bool `json:"is_unique"`
|
||||
AutoIncrement bool `json:"auto_increment"`
|
||||
Length int `json:"length,omitempty"`
|
||||
Scale int `json:"scale,omitempty"`
|
||||
Comment string `json:"comment,omitempty"`
|
||||
}
|
||||
|
||||
// TableIndex represents an index on a database table
|
||||
type TableIndex struct {
|
||||
Name string `json:"name"`
|
||||
Columns []string `json:"columns"`
|
||||
IsUnique bool `json:"is_unique"`
|
||||
IsPrimary bool `json:"is_primary"`
|
||||
Type string `json:"type,omitempty"`
|
||||
}
|
||||
|
||||
// ForeignKey represents a foreign key constraint
|
||||
type ForeignKey struct {
|
||||
Name string `json:"name"`
|
||||
Columns []string `json:"columns"`
|
||||
ReferencedTable string `json:"referenced_table"`
|
||||
ReferencedColumns []string `json:"referenced_columns"`
|
||||
OnDelete string `json:"on_delete,omitempty"`
|
||||
OnUpdate string `json:"on_update,omitempty"`
|
||||
}
|
||||
|
||||
// ConnectionTestResult represents the result of a connection test
|
||||
type ConnectionTestResult struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
Duration int64 `json:"duration_ms"`
|
||||
Metadata *DBMetadata `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// DBMetadata represents database metadata
|
||||
type DBMetadata struct {
|
||||
Version string `json:"version"`
|
||||
Database string `json:"database"`
|
||||
User string `json:"user"`
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
ServerTime string `json:"server_time"`
|
||||
}
|
||||
|
||||
// ErrorResponse represents an error response
|
||||
type ErrorResponse struct {
|
||||
Error string `json:"error"`
|
||||
Message string `json:"message"`
|
||||
Code string `json:"code,omitempty"`
|
||||
Details map[string]interface{} `json:"details,omitempty"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Path string `json:"path,omitempty"`
|
||||
}
|
||||
|
||||
// PaginatedResponse represents a paginated response
|
||||
type PaginatedResponse struct {
|
||||
Data interface{} `json:"data"`
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
TotalPages int `json:"total_pages"`
|
||||
}
|
||||
|
||||
// NewAPIResponse creates a new API response
|
||||
func NewAPIResponse(data interface{}) *APIResponse {
|
||||
return &APIResponse{
|
||||
Success: true,
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
// NewErrorResponse creates a new error response
|
||||
func NewErrorResponse(message string) *APIResponse {
|
||||
return &APIResponse{
|
||||
Success: false,
|
||||
Error: message,
|
||||
}
|
||||
}
|
||||
382
internal/services/connection.go
Normal file
382
internal/services/connection.go
Normal file
@@ -0,0 +1,382 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"uzdb/internal/config"
|
||||
"uzdb/internal/database"
|
||||
"uzdb/internal/models"
|
||||
"uzdb/internal/utils"
|
||||
)
|
||||
|
||||
// ConnectionService manages database connections
|
||||
type ConnectionService struct {
|
||||
db *gorm.DB
|
||||
connManager *database.ConnectionManager
|
||||
encryptSvc *EncryptionService
|
||||
}
|
||||
|
||||
// NewConnectionService creates a new connection service
|
||||
func NewConnectionService(
|
||||
db *gorm.DB,
|
||||
connManager *database.ConnectionManager,
|
||||
encryptSvc *EncryptionService,
|
||||
) *ConnectionService {
|
||||
return &ConnectionService{
|
||||
db: db,
|
||||
connManager: connManager,
|
||||
encryptSvc: encryptSvc,
|
||||
}
|
||||
}
|
||||
|
||||
// GetAllConnections returns all user connections
|
||||
func (s *ConnectionService) GetAllConnections(ctx context.Context) ([]models.UserConnection, error) {
|
||||
var connections []models.UserConnection
|
||||
|
||||
result := s.db.WithContext(ctx).Find(&connections)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("failed to get connections: %w", result.Error)
|
||||
}
|
||||
|
||||
// Mask passwords in response
|
||||
for i := range connections {
|
||||
connections[i].Password = s.encryptSvc.MaskPasswordForLogging(connections[i].Password)
|
||||
}
|
||||
|
||||
config.GetLogger().Debug("retrieved all connections",
|
||||
zap.Int("count", len(connections)))
|
||||
|
||||
return connections, nil
|
||||
}
|
||||
|
||||
// GetConnectionByID returns a connection by ID
|
||||
func (s *ConnectionService) GetConnectionByID(ctx context.Context, id string) (*models.UserConnection, error) {
|
||||
var conn models.UserConnection
|
||||
|
||||
result := s.db.WithContext(ctx).First(&conn, "id = ?", id)
|
||||
if result.Error != nil {
|
||||
if result.Error == gorm.ErrRecordNotFound {
|
||||
return nil, models.ErrNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get connection: %w", result.Error)
|
||||
}
|
||||
|
||||
return &conn, nil
|
||||
}
|
||||
|
||||
// CreateConnection creates a new connection
|
||||
func (s *ConnectionService) CreateConnection(ctx context.Context, req *models.CreateConnectionRequest) (*models.UserConnection, error) {
|
||||
// Validate request
|
||||
if err := req.Validate(); err != nil {
|
||||
return nil, models.ErrValidationFailed
|
||||
}
|
||||
|
||||
// Encrypt password
|
||||
encryptedPassword := req.Password
|
||||
if req.Password != "" {
|
||||
var err error
|
||||
encryptedPassword, err = s.encryptSvc.EncryptPassword(req.Password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt password: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
conn := &models.UserConnection{
|
||||
ID: utils.GenerateID(),
|
||||
Name: req.Name,
|
||||
Type: req.Type,
|
||||
Host: req.Host,
|
||||
Port: req.Port,
|
||||
Username: req.Username,
|
||||
Password: encryptedPassword,
|
||||
Database: req.Database,
|
||||
SSLMode: req.SSLMode,
|
||||
Timeout: req.Timeout,
|
||||
}
|
||||
|
||||
if conn.Timeout <= 0 {
|
||||
conn.Timeout = 30
|
||||
}
|
||||
|
||||
result := s.db.WithContext(ctx).Create(conn)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("failed to create connection: %w", result.Error)
|
||||
}
|
||||
|
||||
// Mask password in response
|
||||
conn.Password = s.encryptSvc.MaskPasswordForLogging(conn.Password)
|
||||
|
||||
config.GetLogger().Info("connection created",
|
||||
zap.String("id", conn.ID),
|
||||
zap.String("name", conn.Name),
|
||||
zap.String("type", string(conn.Type)))
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// UpdateConnection updates an existing connection
|
||||
func (s *ConnectionService) UpdateConnection(ctx context.Context, id string, req *models.UpdateConnectionRequest) (*models.UserConnection, error) {
|
||||
// Get existing connection
|
||||
existing, err := s.GetConnectionByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Update fields
|
||||
if req.Name != "" {
|
||||
existing.Name = req.Name
|
||||
}
|
||||
if req.Type != "" {
|
||||
existing.Type = req.Type
|
||||
}
|
||||
if req.Host != "" {
|
||||
existing.Host = req.Host
|
||||
}
|
||||
if req.Port > 0 {
|
||||
existing.Port = req.Port
|
||||
}
|
||||
if req.Username != "" {
|
||||
existing.Username = req.Username
|
||||
}
|
||||
if req.Password != "" {
|
||||
// Encrypt new password
|
||||
encryptedPassword, err := s.encryptSvc.EncryptPassword(req.Password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt password: %w", err)
|
||||
}
|
||||
existing.Password = encryptedPassword
|
||||
}
|
||||
if req.Database != "" {
|
||||
existing.Database = req.Database
|
||||
}
|
||||
if req.SSLMode != "" {
|
||||
existing.SSLMode = req.SSLMode
|
||||
}
|
||||
if req.Timeout > 0 {
|
||||
existing.Timeout = req.Timeout
|
||||
}
|
||||
|
||||
result := s.db.WithContext(ctx).Save(existing)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("failed to update connection: %w", result.Error)
|
||||
}
|
||||
|
||||
// Remove cached connection if exists
|
||||
if err := s.connManager.RemoveConnection(id); err != nil {
|
||||
config.GetLogger().Warn("failed to remove cached connection", zap.Error(err))
|
||||
}
|
||||
|
||||
// Mask password in response
|
||||
existing.Password = s.encryptSvc.MaskPasswordForLogging(existing.Password)
|
||||
|
||||
config.GetLogger().Info("connection updated",
|
||||
zap.String("id", id),
|
||||
zap.String("name", existing.Name))
|
||||
|
||||
return existing, nil
|
||||
}
|
||||
|
||||
// DeleteConnection deletes a connection
|
||||
func (s *ConnectionService) DeleteConnection(ctx context.Context, id string) error {
|
||||
// Check if connection exists
|
||||
if _, err := s.GetConnectionByID(ctx, id); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Remove from connection manager
|
||||
if err := s.connManager.RemoveConnection(id); err != nil {
|
||||
config.GetLogger().Warn("failed to remove from connection manager", zap.Error(err))
|
||||
}
|
||||
|
||||
// Delete from database
|
||||
result := s.db.WithContext(ctx).Delete(&models.UserConnection{}, "id = ?", id)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("failed to delete connection: %w", result.Error)
|
||||
}
|
||||
|
||||
config.GetLogger().Info("connection deleted", zap.String("id", id))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestConnection tests a database connection
|
||||
func (s *ConnectionService) TestConnection(ctx context.Context, id string) (*models.ConnectionTestResult, error) {
|
||||
// Get connection config
|
||||
conn, err := s.GetConnectionByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Decrypt password
|
||||
password, err := s.encryptSvc.DecryptPassword(conn.Password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt password: %w", err)
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
// Create temporary connection
|
||||
tempConn, err := s.connManager.GetConnection(conn, password)
|
||||
if err != nil {
|
||||
return &models.ConnectionTestResult{
|
||||
Success: false,
|
||||
Message: fmt.Sprintf("Connection failed: %v", err),
|
||||
Duration: time.Since(startTime).Milliseconds(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Get metadata
|
||||
metadata, err := tempConn.GetMetadata()
|
||||
if err != nil {
|
||||
config.GetLogger().Warn("failed to get metadata", zap.Error(err))
|
||||
}
|
||||
|
||||
return &models.ConnectionTestResult{
|
||||
Success: true,
|
||||
Message: "Connection successful",
|
||||
Duration: time.Since(startTime).Milliseconds(),
|
||||
Metadata: metadata,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ExecuteQuery executes a SQL query on a connection
|
||||
func (s *ConnectionService) ExecuteQuery(ctx context.Context, connectionID, sql string) (*models.QueryResult, error) {
|
||||
// Get connection config
|
||||
conn, err := s.GetConnectionByID(ctx, connectionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Decrypt password
|
||||
password, err := s.encryptSvc.DecryptPassword(conn.Password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt password: %w", err)
|
||||
}
|
||||
|
||||
// Get or create connection
|
||||
dbConn, err := s.connManager.GetConnection(conn, password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get connection: %w", err)
|
||||
}
|
||||
|
||||
// Execute query
|
||||
startTime := time.Now()
|
||||
|
||||
var result *models.QueryResult
|
||||
if utils.IsReadOnlyQuery(sql) {
|
||||
result, err = dbConn.ExecuteQuery(sql)
|
||||
} else {
|
||||
result, err = dbConn.ExecuteStatement(sql)
|
||||
}
|
||||
|
||||
duration := time.Since(startTime)
|
||||
|
||||
// Record in history
|
||||
history := &models.QueryHistory{
|
||||
ConnectionID: connectionID,
|
||||
SQL: utils.TruncateString(sql, 10000),
|
||||
Duration: duration.Milliseconds(),
|
||||
Success: err == nil,
|
||||
}
|
||||
|
||||
if result != nil {
|
||||
history.RowsAffected = result.AffectedRows
|
||||
}
|
||||
if err != nil {
|
||||
history.Error = err.Error()
|
||||
}
|
||||
|
||||
s.db.WithContext(ctx).Create(history)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query execution failed: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetTables returns all tables for a connection
|
||||
func (s *ConnectionService) GetTables(ctx context.Context, connectionID string) ([]models.Table, error) {
|
||||
// Get connection config
|
||||
conn, err := s.GetConnectionByID(ctx, connectionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Decrypt password
|
||||
password, err := s.encryptSvc.DecryptPassword(conn.Password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt password: %w", err)
|
||||
}
|
||||
|
||||
// Get or create connection
|
||||
dbConn, err := s.connManager.GetConnection(conn, password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get connection: %w", err)
|
||||
}
|
||||
|
||||
return dbConn.GetTables("")
|
||||
}
|
||||
|
||||
// GetTableData returns data from a table
|
||||
func (s *ConnectionService) GetTableData(
|
||||
ctx context.Context,
|
||||
connectionID, tableName string,
|
||||
limit, offset int,
|
||||
) (*models.QueryResult, error) {
|
||||
// Validate limit
|
||||
if limit <= 0 || limit > 1000 {
|
||||
limit = 100
|
||||
}
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
|
||||
// Build query based on connection type
|
||||
conn, err := s.GetConnectionByID(ctx, connectionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var query string
|
||||
switch conn.Type {
|
||||
case models.ConnectionTypeMySQL:
|
||||
query = fmt.Sprintf("SELECT * FROM `%s` LIMIT %d OFFSET %d", tableName, limit, offset)
|
||||
case models.ConnectionTypePostgreSQL:
|
||||
query = fmt.Sprintf(`SELECT * FROM "%s" LIMIT %d OFFSET %d`, tableName, limit, offset)
|
||||
case models.ConnectionTypeSQLite:
|
||||
query = fmt.Sprintf(`SELECT * FROM "%s" LIMIT %d OFFSET %d`, tableName, limit, offset)
|
||||
default:
|
||||
return nil, models.ErrValidationFailed
|
||||
}
|
||||
|
||||
return s.ExecuteQuery(ctx, connectionID, query)
|
||||
}
|
||||
|
||||
// GetTableStructure returns the structure of a table
|
||||
func (s *ConnectionService) GetTableStructure(ctx context.Context, connectionID, tableName string) (*models.TableStructure, error) {
|
||||
// Get connection config
|
||||
conn, err := s.GetConnectionByID(ctx, connectionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Decrypt password
|
||||
password, err := s.encryptSvc.DecryptPassword(conn.Password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt password: %w", err)
|
||||
}
|
||||
|
||||
// Get or create connection
|
||||
dbConn, err := s.connManager.GetConnection(conn, password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get connection: %w", err)
|
||||
}
|
||||
|
||||
return dbConn.GetTableStructure(tableName)
|
||||
}
|
||||
199
internal/services/encryption.go
Normal file
199
internal/services/encryption.go
Normal file
@@ -0,0 +1,199 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"uzdb/internal/config"
|
||||
"uzdb/internal/models"
|
||||
"uzdb/internal/utils"
|
||||
)
|
||||
|
||||
// EncryptionService handles encryption and decryption of sensitive data
|
||||
type EncryptionService struct {
|
||||
key []byte
|
||||
cipher cipher.Block
|
||||
gcm cipher.AEAD
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
var (
|
||||
encryptionInstance *EncryptionService
|
||||
encryptionOnce sync.Once
|
||||
)
|
||||
|
||||
// GetEncryptionService returns the singleton encryption service instance
|
||||
func GetEncryptionService() *EncryptionService {
|
||||
return encryptionInstance
|
||||
}
|
||||
|
||||
// InitEncryptionService initializes the encryption service
|
||||
func InitEncryptionService(cfg *config.EncryptionConfig) (*EncryptionService, error) {
|
||||
var err error
|
||||
encryptionOnce.Do(func() {
|
||||
encryptionInstance = &EncryptionService{}
|
||||
err = encryptionInstance.init(cfg)
|
||||
})
|
||||
return encryptionInstance, err
|
||||
}
|
||||
|
||||
// init initializes the encryption service with a key
|
||||
func (s *EncryptionService) init(cfg *config.EncryptionConfig) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Try to load existing key or generate new one
|
||||
key, err := s.loadOrGenerateKey(cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load/generate key: %w", err)
|
||||
}
|
||||
|
||||
s.key = key
|
||||
|
||||
// Create AES cipher
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
s.cipher = block
|
||||
|
||||
// Create GCM mode
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create GCM: %w", err)
|
||||
}
|
||||
s.gcm = gcm
|
||||
|
||||
config.GetLogger().Info("encryption service initialized")
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadOrGenerateKey loads existing key or generates a new one
|
||||
func (s *EncryptionService) loadOrGenerateKey(cfg *config.EncryptionConfig) ([]byte, error) {
|
||||
// Use provided key if available
|
||||
if cfg.Key != "" {
|
||||
key := []byte(cfg.Key)
|
||||
// Ensure key is correct length (32 bytes for AES-256)
|
||||
if len(key) < 32 {
|
||||
// Pad key
|
||||
padded := make([]byte, 32)
|
||||
copy(padded, key)
|
||||
key = padded
|
||||
} else if len(key) > 32 {
|
||||
key = key[:32]
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// Try to load from file
|
||||
if cfg.KeyFile != "" {
|
||||
if data, err := os.ReadFile(cfg.KeyFile); err == nil {
|
||||
key := []byte(data)
|
||||
if len(key) >= 32 {
|
||||
return key[:32], nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Generate new key
|
||||
key := make([]byte, 32) // AES-256
|
||||
if _, err := io.ReadFull(rand.Reader, key); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate key: %w", err)
|
||||
}
|
||||
|
||||
// Save key to file if path provided
|
||||
if cfg.KeyFile != "" {
|
||||
if err := os.WriteFile(cfg.KeyFile, key, 0600); err != nil {
|
||||
config.GetLogger().Warn("failed to save encryption key", zap.Error(err))
|
||||
} else {
|
||||
config.GetLogger().Info("encryption key generated and saved",
|
||||
zap.String("path", cfg.KeyFile))
|
||||
}
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// Encrypt encrypts plaintext using AES-GCM
|
||||
func (s *EncryptionService) Encrypt(plaintext string) (string, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if s.gcm == nil {
|
||||
return "", models.ErrEncryptionFailed
|
||||
}
|
||||
|
||||
// Generate nonce
|
||||
nonce := make([]byte, s.gcm.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", fmt.Errorf("failed to generate nonce: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt
|
||||
ciphertext := s.gcm.Seal(nonce, nonce, []byte(plaintext), nil)
|
||||
|
||||
// Encode to base64
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts ciphertext using AES-GCM
|
||||
func (s *EncryptionService) Decrypt(ciphertext string) (string, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if s.gcm == nil {
|
||||
return "", models.ErrEncryptionFailed
|
||||
}
|
||||
|
||||
// Decode from base64
|
||||
data, err := base64.StdEncoding.DecodeString(ciphertext)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to decode ciphertext: %w", err)
|
||||
}
|
||||
|
||||
// Verify nonce size
|
||||
nonceSize := s.gcm.NonceSize()
|
||||
if len(data) < nonceSize {
|
||||
return "", models.ErrEncryptionFailed
|
||||
}
|
||||
|
||||
// Extract nonce and ciphertext
|
||||
nonce, ciphertextBytes := data[:nonceSize], data[nonceSize:]
|
||||
|
||||
// Decrypt
|
||||
plaintext, err := s.gcm.Open(nil, nonce, ciphertextBytes, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to decrypt: %w", err)
|
||||
}
|
||||
|
||||
return string(plaintext), nil
|
||||
}
|
||||
|
||||
// EncryptPassword encrypts a password for storage
|
||||
func (s *EncryptionService) EncryptPassword(password string) (string, error) {
|
||||
if password == "" {
|
||||
return "", nil
|
||||
}
|
||||
return s.Encrypt(password)
|
||||
}
|
||||
|
||||
// DecryptPassword decrypts a stored password
|
||||
func (s *EncryptionService) DecryptPassword(encryptedPassword string) (string, error) {
|
||||
if encryptedPassword == "" {
|
||||
return "", nil
|
||||
}
|
||||
return s.Decrypt(encryptedPassword)
|
||||
}
|
||||
|
||||
// MaskPasswordForLogging masks password for safe logging
|
||||
func (s *EncryptionService) MaskPasswordForLogging(password string) string {
|
||||
return utils.MaskPassword(password)
|
||||
}
|
||||
236
internal/services/query.go
Normal file
236
internal/services/query.go
Normal file
@@ -0,0 +1,236 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"uzdb/internal/config"
|
||||
"uzdb/internal/models"
|
||||
)
|
||||
|
||||
// QueryService handles query-related operations
|
||||
type QueryService struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewQueryService creates a new query service
|
||||
func NewQueryService(db *gorm.DB) *QueryService {
|
||||
return &QueryService{
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
// GetQueryHistory returns query history with pagination
|
||||
func (s *QueryService) GetQueryHistory(
|
||||
ctx context.Context,
|
||||
connectionID string,
|
||||
page, pageSize int,
|
||||
) ([]models.QueryHistory, int64, error) {
|
||||
if page <= 0 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize <= 0 || pageSize > 100 {
|
||||
pageSize = 20
|
||||
}
|
||||
|
||||
var total int64
|
||||
var history []models.QueryHistory
|
||||
|
||||
query := s.db.WithContext(ctx).Model(&models.QueryHistory{})
|
||||
|
||||
if connectionID != "" {
|
||||
query = query.Where("connection_id = ?", connectionID)
|
||||
}
|
||||
|
||||
// Get total count
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to count history: %w", err)
|
||||
}
|
||||
|
||||
// Get paginated results
|
||||
offset := (page - 1) * pageSize
|
||||
if err := query.Order("executed_at DESC").
|
||||
Offset(offset).
|
||||
Limit(pageSize).
|
||||
Find(&history).Error; err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to get history: %w", err)
|
||||
}
|
||||
|
||||
config.GetLogger().Debug("retrieved query history",
|
||||
zap.String("connection_id", connectionID),
|
||||
zap.Int("page", page),
|
||||
zap.Int("page_size", pageSize),
|
||||
zap.Int64("total", total))
|
||||
|
||||
return history, total, nil
|
||||
}
|
||||
|
||||
// GetSavedQueries returns all saved queries
|
||||
func (s *QueryService) GetSavedQueries(
|
||||
ctx context.Context,
|
||||
connectionID string,
|
||||
) ([]models.SavedQuery, error) {
|
||||
var queries []models.SavedQuery
|
||||
|
||||
query := s.db.WithContext(ctx)
|
||||
if connectionID != "" {
|
||||
query = query.Where("connection_id = ?", connectionID)
|
||||
}
|
||||
|
||||
result := query.Order("name ASC").Find(&queries)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("failed to get saved queries: %w", result.Error)
|
||||
}
|
||||
|
||||
return queries, nil
|
||||
}
|
||||
|
||||
// GetSavedQueryByID returns a saved query by ID
|
||||
func (s *QueryService) GetSavedQueryByID(ctx context.Context, id uint) (*models.SavedQuery, error) {
|
||||
var query models.SavedQuery
|
||||
|
||||
result := s.db.WithContext(ctx).First(&query, "id = ?", id)
|
||||
if result.Error != nil {
|
||||
if result.Error == gorm.ErrRecordNotFound {
|
||||
return nil, models.ErrNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get saved query: %w", result.Error)
|
||||
}
|
||||
|
||||
return &query, nil
|
||||
}
|
||||
|
||||
// CreateSavedQuery creates a new saved query
|
||||
func (s *QueryService) CreateSavedQuery(
|
||||
ctx context.Context,
|
||||
req *models.CreateSavedQueryRequest,
|
||||
) (*models.SavedQuery, error) {
|
||||
query := &models.SavedQuery{
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
SQL: req.SQL,
|
||||
ConnectionID: req.ConnectionID,
|
||||
Tags: req.Tags,
|
||||
}
|
||||
|
||||
result := s.db.WithContext(ctx).Create(query)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("failed to create saved query: %w", result.Error)
|
||||
}
|
||||
|
||||
config.GetLogger().Info("saved query created",
|
||||
zap.Uint("id", query.ID),
|
||||
zap.String("name", query.Name))
|
||||
|
||||
return query, nil
|
||||
}
|
||||
|
||||
// UpdateSavedQuery updates an existing saved query
|
||||
func (s *QueryService) UpdateSavedQuery(
|
||||
ctx context.Context,
|
||||
id uint,
|
||||
req *models.UpdateSavedQueryRequest,
|
||||
) (*models.SavedQuery, error) {
|
||||
// Get existing query
|
||||
existing, err := s.GetSavedQueryByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Update fields
|
||||
if req.Name != "" {
|
||||
existing.Name = req.Name
|
||||
}
|
||||
if req.Description != "" {
|
||||
existing.Description = req.Description
|
||||
}
|
||||
if req.SQL != "" {
|
||||
existing.SQL = req.SQL
|
||||
}
|
||||
if req.ConnectionID != "" {
|
||||
existing.ConnectionID = req.ConnectionID
|
||||
}
|
||||
if req.Tags != "" {
|
||||
existing.Tags = req.Tags
|
||||
}
|
||||
|
||||
result := s.db.WithContext(ctx).Save(existing)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("failed to update saved query: %w", result.Error)
|
||||
}
|
||||
|
||||
config.GetLogger().Info("saved query updated",
|
||||
zap.Uint("id", id),
|
||||
zap.String("name", existing.Name))
|
||||
|
||||
return existing, nil
|
||||
}
|
||||
|
||||
// DeleteSavedQuery deletes a saved query
|
||||
func (s *QueryService) DeleteSavedQuery(ctx context.Context, id uint) error {
|
||||
// Check if exists
|
||||
if _, err := s.GetSavedQueryByID(ctx, id); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result := s.db.WithContext(ctx).Delete(&models.SavedQuery{}, "id = ?", id)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("failed to delete saved query: %w", result.Error)
|
||||
}
|
||||
|
||||
config.GetLogger().Info("saved query deleted", zap.Uint("id", id))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClearOldHistory clears query history older than specified days
|
||||
func (s *QueryService) ClearOldHistory(ctx context.Context, days int) (int64, error) {
|
||||
if days <= 0 {
|
||||
days = 30 // Default to 30 days
|
||||
}
|
||||
|
||||
cutoffTime := time.Now().AddDate(0, 0, -days)
|
||||
|
||||
result := s.db.WithContext(ctx).
|
||||
Where("executed_at < ?", cutoffTime).
|
||||
Delete(&models.QueryHistory{})
|
||||
|
||||
if result.Error != nil {
|
||||
return 0, fmt.Errorf("failed to clear old history: %w", result.Error)
|
||||
}
|
||||
|
||||
config.GetLogger().Info("cleared old query history",
|
||||
zap.Int64("deleted_count", result.RowsAffected),
|
||||
zap.Int("days", days))
|
||||
|
||||
return result.RowsAffected, nil
|
||||
}
|
||||
|
||||
// GetRecentQueries returns recent queries for quick access
|
||||
func (s *QueryService) GetRecentQueries(
|
||||
ctx context.Context,
|
||||
connectionID string,
|
||||
limit int,
|
||||
) ([]models.QueryHistory, error) {
|
||||
if limit <= 0 || limit > 50 {
|
||||
limit = 10
|
||||
}
|
||||
|
||||
var queries []models.QueryHistory
|
||||
|
||||
query := s.db.WithContext(ctx).
|
||||
Where("connection_id = ? AND success = ?", connectionID, true).
|
||||
Order("executed_at DESC").
|
||||
Limit(limit).
|
||||
Find(&queries)
|
||||
|
||||
if query.Error != nil {
|
||||
return nil, fmt.Errorf("failed to get recent queries: %w", query.Error)
|
||||
}
|
||||
|
||||
return queries, nil
|
||||
}
|
||||
136
internal/utils/errors.go
Normal file
136
internal/utils/errors.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// GenerateID generates a unique ID (UUID-like)
|
||||
func GenerateID() string {
|
||||
b := make([]byte, 16)
|
||||
rand.Read(b)
|
||||
return formatUUID(b)
|
||||
}
|
||||
|
||||
// formatUUID formats bytes as UUID string
|
||||
func formatUUID(b []byte) string {
|
||||
uuid := make([]byte, 36)
|
||||
hex.Encode(uuid[0:8], b[0:4])
|
||||
hex.Encode(uuid[9:13], b[4:6])
|
||||
hex.Encode(uuid[14:18], b[6:8])
|
||||
hex.Encode(uuid[19:23], b[8:10])
|
||||
hex.Encode(uuid[24:], b[10:])
|
||||
uuid[8] = '-'
|
||||
uuid[13] = '-'
|
||||
uuid[18] = '-'
|
||||
uuid[23] = '-'
|
||||
return string(uuid)
|
||||
}
|
||||
|
||||
// FormatDuration formats duration in milliseconds
|
||||
func FormatDuration(d time.Duration) int64 {
|
||||
return d.Milliseconds()
|
||||
}
|
||||
|
||||
// SanitizeSQL removes potentially dangerous SQL patterns
|
||||
// Note: This is not a replacement for parameterized queries
|
||||
func SanitizeSQL(sql string) string {
|
||||
// Remove multiple semicolons
|
||||
sql = strings.ReplaceAll(sql, ";;", ";")
|
||||
// Trim whitespace
|
||||
sql = strings.TrimSpace(sql)
|
||||
return sql
|
||||
}
|
||||
|
||||
// TruncateString truncates a string to max length
|
||||
func TruncateString(s string, maxLen int) string {
|
||||
if len(s) <= maxLen {
|
||||
return s
|
||||
}
|
||||
return s[:maxLen]
|
||||
}
|
||||
|
||||
// GenerateRandomBytes generates random bytes
|
||||
func GenerateRandomBytes(n int) ([]byte, error) {
|
||||
b := make([]byte, n)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate random bytes: %w", err)
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// EncodeBase64 encodes bytes to base64 string
|
||||
func EncodeBase64(data []byte) string {
|
||||
return base64.StdEncoding.EncodeToString(data)
|
||||
}
|
||||
|
||||
// DecodeBase64 decodes base64 string to bytes
|
||||
func DecodeBase64(s string) ([]byte, error) {
|
||||
return base64.StdEncoding.DecodeString(s)
|
||||
}
|
||||
|
||||
// MaskPassword masks password for logging
|
||||
func MaskPassword(password string) string {
|
||||
if len(password) <= 4 {
|
||||
return strings.Repeat("*", len(password))
|
||||
}
|
||||
return password[:2] + strings.Repeat("*", len(password)-2)
|
||||
}
|
||||
|
||||
// ContainsAny checks if string contains any of the substrings
|
||||
func ContainsAny(s string, substrings ...string) bool {
|
||||
for _, sub := range substrings {
|
||||
if strings.Contains(s, sub) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsReadOnlyQuery checks if SQL query is read-only
|
||||
func IsReadOnlyQuery(sql string) bool {
|
||||
sql = strings.TrimSpace(strings.ToUpper(sql))
|
||||
|
||||
// Read-only operations
|
||||
readOnlyPrefixes := []string{
|
||||
"SELECT",
|
||||
"SHOW",
|
||||
"DESCRIBE",
|
||||
"EXPLAIN",
|
||||
"WITH",
|
||||
}
|
||||
|
||||
for _, prefix := range readOnlyPrefixes {
|
||||
if strings.HasPrefix(sql, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// IsDDLQuery checks if SQL query is DDL
|
||||
func IsDDLQuery(sql string) bool {
|
||||
sql = strings.TrimSpace(strings.ToUpper(sql))
|
||||
|
||||
ddlKeywords := []string{
|
||||
"CREATE",
|
||||
"ALTER",
|
||||
"DROP",
|
||||
"TRUNCATE",
|
||||
"RENAME",
|
||||
}
|
||||
|
||||
for _, keyword := range ddlKeywords {
|
||||
if strings.HasPrefix(sql, keyword) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
102
internal/utils/response.go
Normal file
102
internal/utils/response.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"uzdb/internal/config"
|
||||
"uzdb/internal/models"
|
||||
)
|
||||
|
||||
// ErrorResponse sends an error response
|
||||
func ErrorResponse(c *gin.Context, statusCode int, err error, message string) {
|
||||
logger := config.GetLogger()
|
||||
|
||||
response := models.ErrorResponse{
|
||||
Error: getErrorCode(err),
|
||||
Message: message,
|
||||
Timestamp: time.Now(),
|
||||
Path: c.Request.URL.Path,
|
||||
}
|
||||
|
||||
if message == "" {
|
||||
response.Message = err.Error()
|
||||
}
|
||||
|
||||
// Log the error
|
||||
logger.Error("error response",
|
||||
zap.Int("status_code", statusCode),
|
||||
zap.String("error", response.Error),
|
||||
zap.String("message", response.Message),
|
||||
zap.String("path", response.Path),
|
||||
)
|
||||
|
||||
c.JSON(statusCode, response)
|
||||
}
|
||||
|
||||
// SuccessResponse sends a success response
|
||||
func SuccessResponse(c *gin.Context, data interface{}) {
|
||||
c.JSON(http.StatusOK, models.NewAPIResponse(data))
|
||||
}
|
||||
|
||||
// CreatedResponse sends a created response
|
||||
func CreatedResponse(c *gin.Context, data interface{}) {
|
||||
c.JSON(http.StatusCreated, models.NewAPIResponse(data))
|
||||
}
|
||||
|
||||
// getErrorCode maps errors to error codes
|
||||
func getErrorCode(err error) string {
|
||||
if err == nil {
|
||||
return string(models.CodeSuccess)
|
||||
}
|
||||
|
||||
switch {
|
||||
case errors.Is(err, models.ErrNotFound):
|
||||
return string(models.CodeNotFound)
|
||||
case errors.Is(err, models.ErrValidationFailed):
|
||||
return string(models.CodeValidation)
|
||||
case errors.Is(err, models.ErrUnauthorized):
|
||||
return string(models.CodeUnauthorized)
|
||||
case errors.Is(err, models.ErrForbidden):
|
||||
return string(models.CodeForbidden)
|
||||
case errors.Is(err, models.ErrConnectionFailed):
|
||||
return string(models.CodeConnection)
|
||||
case errors.Is(err, models.ErrQueryFailed):
|
||||
return string(models.CodeQuery)
|
||||
case errors.Is(err, models.ErrEncryptionFailed):
|
||||
return string(models.CodeEncryption)
|
||||
default:
|
||||
return string(models.CodeInternal)
|
||||
}
|
||||
}
|
||||
|
||||
// WrapError wraps an error with context
|
||||
func WrapError(err error, message string) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
return &wrappedError{
|
||||
err: err,
|
||||
message: message,
|
||||
}
|
||||
}
|
||||
|
||||
type wrappedError struct {
|
||||
err error
|
||||
message string
|
||||
}
|
||||
|
||||
func (w *wrappedError) Error() string {
|
||||
if w.message != "" {
|
||||
return w.message + ": " + w.err.Error()
|
||||
}
|
||||
return w.err.Error()
|
||||
}
|
||||
|
||||
func (w *wrappedError) Unwrap() error {
|
||||
return w.err
|
||||
}
|
||||
Reference in New Issue
Block a user