package services

import (
	"context"
	"errors"
	"fmt"
	"strings"
	"time"

	"go.uber.org/zap"
	"gorm.io/gorm"

	"uzdb/internal/config"
	"uzdb/internal/database"
	"uzdb/internal/models"
	"uzdb/internal/utils"
)

// ConnectionService manages saved database connections: it persists their
// configuration through gorm, encrypts stored passwords, and runs queries
// against the target databases through a database.ConnectionManager.
type ConnectionService struct {
	db          *gorm.DB                    // application store for UserConnection and QueryHistory rows
	connManager *database.ConnectionManager // hands out live connections; entries are evicted via RemoveConnection
	encryptSvc  *EncryptionService          // encrypts, decrypts and masks stored passwords
}

// NewConnectionService creates a new connection service from its dependencies.
func NewConnectionService(
	db *gorm.DB,
	connManager *database.ConnectionManager,
	encryptSvc *EncryptionService,
) *ConnectionService {
	return &ConnectionService{
		db:          db,
		connManager: connManager,
		encryptSvc:  encryptSvc,
	}
}

// GetAllConnections returns every saved connection. The Password field of each
// returned connection is replaced by its masked form, so the result is safe to
// hand to callers outside the service.
func (s *ConnectionService) GetAllConnections(ctx context.Context) ([]models.UserConnection, error) {
	var connections []models.UserConnection
	result := s.db.WithContext(ctx).Find(&connections)
	if result.Error != nil {
		return nil, fmt.Errorf("failed to get connections: %w", result.Error)
	}

	// Mask passwords in response
	for i := range connections {
		connections[i].Password = s.encryptSvc.MaskPasswordForLogging(connections[i].Password)
	}

	config.GetLogger().Debug("retrieved all connections", zap.Int("count", len(connections)))
	return connections, nil
}

// GetConnectionByID returns the connection with the given ID. Unlike
// GetAllConnections, the Password field is left as stored (not masked), because
// other methods of this service decrypt it.
//
// Returns models.ErrNotFound when no row matches id.
func (s *ConnectionService) GetConnectionByID(ctx context.Context, id string) (*models.UserConnection, error) {
	var conn models.UserConnection
	result := s.db.WithContext(ctx).First(&conn, "id = ?", id)
	if result.Error != nil {
		// errors.Is also matches when the driver or a callback wraps the sentinel.
		if errors.Is(result.Error, gorm.ErrRecordNotFound) {
			return nil, models.ErrNotFound
		}
		return nil, fmt.Errorf("failed to get connection: %w", result.Error)
	}
	return &conn, nil
}

// CreateConnection validates req, encrypts its password (if any), persists a new
// connection and returns it with the password masked.
//
// Returns models.ErrValidationFailed when req.Validate fails.
func (s *ConnectionService) CreateConnection(ctx context.Context, req *models.CreateConnectionRequest) (*models.UserConnection, error) {
	// Callers receive only the sentinel; the specific validation error is not propagated.
	if err := req.Validate(); err != nil {
		return nil, models.ErrValidationFailed
	}

	// An empty password is stored as-is (unencrypted empty string).
	encryptedPassword := req.Password
	if req.Password != "" {
		var err error
		encryptedPassword, err = s.encryptSvc.EncryptPassword(req.Password)
		if err != nil {
			return nil, fmt.Errorf("failed to encrypt password: %w", err)
		}
	}

	conn := &models.UserConnection{
		ID:       utils.GenerateID(),
		Name:     req.Name,
		Type:     req.Type,
		Host:     req.Host,
		Port:     req.Port,
		Username: req.Username,
		Password: encryptedPassword,
		Database: req.Database,
		SSLMode:  req.SSLMode,
		Timeout:  req.Timeout,
	}
	// Default timeout when none was supplied (unit presumably seconds — confirm with models).
	if conn.Timeout <= 0 {
		conn.Timeout = 30
	}

	result := s.db.WithContext(ctx).Create(conn)
	if result.Error != nil {
		return nil, fmt.Errorf("failed to create connection: %w", result.Error)
	}

	// Mask password in response
	conn.Password = s.encryptSvc.MaskPasswordForLogging(conn.Password)

	config.GetLogger().Info("connection created",
		zap.String("id", conn.ID),
		zap.String("name", conn.Name),
		zap.String("type", string(conn.Type)))

	return conn, nil
}

// UpdateConnection applies the non-empty / positive fields of req to the
// connection identified by id, saves it, evicts any live connection held by the
// connection manager so the new settings take effect, and returns the updated
// connection with the password masked.
//
// Returns models.ErrNotFound when id does not exist.
func (s *ConnectionService) UpdateConnection(ctx context.Context, id string, req *models.UpdateConnectionRequest) (*models.UserConnection, error) {
	existing, err := s.GetConnectionByID(ctx, id)
	if err != nil {
		return nil, err
	}

	// Zero values in req mean "leave unchanged".
	if req.Name != "" {
		existing.Name = req.Name
	}
	if req.Type != "" {
		existing.Type = req.Type
	}
	if req.Host != "" {
		existing.Host = req.Host
	}
	if req.Port > 0 {
		existing.Port = req.Port
	}
	if req.Username != "" {
		existing.Username = req.Username
	}
	if req.Password != "" {
		encryptedPassword, err := s.encryptSvc.EncryptPassword(req.Password)
		if err != nil {
			return nil, fmt.Errorf("failed to encrypt password: %w", err)
		}
		existing.Password = encryptedPassword
	}
	if req.Database != "" {
		existing.Database = req.Database
	}
	if req.SSLMode != "" {
		existing.SSLMode = req.SSLMode
	}
	if req.Timeout > 0 {
		existing.Timeout = req.Timeout
	}

	result := s.db.WithContext(ctx).Save(existing)
	if result.Error != nil {
		return nil, fmt.Errorf("failed to update connection: %w", result.Error)
	}

	// Drop any live connection built from the old settings; failure is non-fatal.
	if err := s.connManager.RemoveConnection(id); err != nil {
		config.GetLogger().Warn("failed to remove cached connection", zap.Error(err))
	}

	// Mask password in response
	existing.Password = s.encryptSvc.MaskPasswordForLogging(existing.Password)

	config.GetLogger().Info("connection updated",
		zap.String("id", id),
		zap.String("name", existing.Name))

	return existing, nil
}

// DeleteConnection removes the connection identified by id from the connection
// manager and from the store.
//
// Returns models.ErrNotFound when id does not exist.
func (s *ConnectionService) DeleteConnection(ctx context.Context, id string) error {
	if _, err := s.GetConnectionByID(ctx, id); err != nil {
		return err
	}

	// Best-effort eviction; the row is deleted regardless.
	if err := s.connManager.RemoveConnection(id); err != nil {
		config.GetLogger().Warn("failed to remove from connection manager", zap.Error(err))
	}

	result := s.db.WithContext(ctx).Delete(&models.UserConnection{}, "id = ?", id)
	if result.Error != nil {
		return fmt.Errorf("failed to delete connection: %w", result.Error)
	}

	config.GetLogger().Info("connection deleted", zap.String("id", id))
	return nil
}

// TestConnection tries to connect using the saved connection id and reports the
// outcome. A failure to connect is reported in the result (Success=false) with a
// nil error. A non-nil error is returned only when the saved connection cannot
// be loaded or its password cannot be decrypted. Metadata is best-effort and may
// be nil on success.
func (s *ConnectionService) TestConnection(ctx context.Context, id string) (*models.ConnectionTestResult, error) {
	conn, err := s.GetConnectionByID(ctx, id)
	if err != nil {
		return nil, err
	}

	password, err := s.decryptPassword(conn)
	if err != nil {
		return nil, err
	}

	startTime := time.Now()

	// Obtain a connection from the manager (get-or-create; it may be reused and
	// stays held by the manager afterwards).
	tempConn, err := s.connManager.GetConnection(conn, password)
	if err != nil {
		return &models.ConnectionTestResult{
			Success:  false,
			Message:  fmt.Sprintf("Connection failed: %v", err),
			Duration: time.Since(startTime).Milliseconds(),
		}, nil
	}

	metadata, err := tempConn.GetMetadata()
	if err != nil {
		config.GetLogger().Warn("failed to get metadata", zap.Error(err))
	}

	return &models.ConnectionTestResult{
		Success:  true,
		Message:  "Connection successful",
		Duration: time.Since(startTime).Milliseconds(),
		Metadata: metadata,
	}, nil
}

// ExecuteQuery runs sql on the connection identified by connectionID and records
// the attempt in query history. Statements that utils.IsReadOnlyQuery classifies
// as read-only go through ExecuteQuery; everything else goes through
// ExecuteStatement.
func (s *ConnectionService) ExecuteQuery(ctx context.Context, connectionID, sql string) (*models.QueryResult, error) {
	conn, err := s.GetConnectionByID(ctx, connectionID)
	if err != nil {
		return nil, err
	}

	password, err := s.decryptPassword(conn)
	if err != nil {
		return nil, err
	}

	dbConn, err := s.connManager.GetConnection(conn, password)
	if err != nil {
		return nil, fmt.Errorf("failed to get connection: %w", err)
	}

	startTime := time.Now()
	var result *models.QueryResult
	if utils.IsReadOnlyQuery(sql) {
		result, err = dbConn.ExecuteQuery(sql)
	} else {
		result, err = dbConn.ExecuteStatement(sql)
	}
	duration := time.Since(startTime)

	// Record in history; SQL is truncated to bound the stored row size.
	history := &models.QueryHistory{
		ConnectionID: connectionID,
		SQL:          utils.TruncateString(sql, 10000),
		Duration:     duration.Milliseconds(),
		Success:      err == nil,
	}
	if result != nil {
		history.RowsAffected = result.AffectedRows
	}
	if err != nil {
		history.Error = err.Error()
	}
	// History is best-effort: a failure here must not mask the query outcome,
	// but it should not disappear silently either.
	if histErr := s.db.WithContext(ctx).Create(history).Error; histErr != nil {
		config.GetLogger().Warn("failed to record query history",
			zap.String("connection_id", connectionID),
			zap.Error(histErr))
	}

	if err != nil {
		return nil, fmt.Errorf("query execution failed: %w", err)
	}
	return result, nil
}

// GetTables returns the tables visible on the connection identified by
// connectionID (the underlying GetTables is called with an empty argument).
func (s *ConnectionService) GetTables(ctx context.Context, connectionID string) ([]models.Table, error) {
	conn, err := s.GetConnectionByID(ctx, connectionID)
	if err != nil {
		return nil, err
	}

	password, err := s.decryptPassword(conn)
	if err != nil {
		return nil, err
	}

	dbConn, err := s.connManager.GetConnection(conn, password)
	if err != nil {
		return nil, fmt.Errorf("failed to get connection: %w", err)
	}

	return dbConn.GetTables("")
}

// GetTableData returns one page of rows from tableName on the connection
// identified by connectionID.
//
// A limit outside 1..1000 is replaced by 100, and a negative offset by 0.
// tableName is treated as a single quoted identifier, with embedded quote
// characters escaped, so it cannot inject SQL. Returns
// models.ErrValidationFailed for an empty tableName or an unsupported
// connection type.
func (s *ConnectionService) GetTableData(
	ctx context.Context,
	connectionID, tableName string,
	limit, offset int,
) (*models.QueryResult, error) {
	if tableName == "" {
		return nil, models.ErrValidationFailed
	}
	// Out-of-range limits fall back to the default page size rather than being clamped.
	if limit <= 0 || limit > 1000 {
		limit = 100
	}
	if offset < 0 {
		offset = 0
	}

	conn, err := s.GetConnectionByID(ctx, connectionID)
	if err != nil {
		return nil, err
	}

	// Identifier quote character per dialect.
	var quote string
	switch conn.Type {
	case models.ConnectionTypeMySQL:
		quote = "`"
	case models.ConnectionTypePostgreSQL, models.ConnectionTypeSQLite:
		quote = `"`
	default:
		return nil, models.ErrValidationFailed
	}

	query := fmt.Sprintf("SELECT * FROM %s LIMIT %d OFFSET %d",
		quoteSQLIdentifier(tableName, quote), limit, offset)
	return s.ExecuteQuery(ctx, connectionID, query)
}

// GetTableStructure returns the structure of tableName on the connection
// identified by connectionID.
func (s *ConnectionService) GetTableStructure(ctx context.Context, connectionID, tableName string) (*models.TableStructure, error) {
	conn, err := s.GetConnectionByID(ctx, connectionID)
	if err != nil {
		return nil, err
	}

	password, err := s.decryptPassword(conn)
	if err != nil {
		return nil, err
	}

	dbConn, err := s.connManager.GetConnection(conn, password)
	if err != nil {
		return nil, fmt.Errorf("failed to get connection: %w", err)
	}

	return dbConn.GetTableStructure(tableName)
}

// decryptPassword returns the plaintext password of conn. CreateConnection stores
// an empty password without encrypting it, so an empty stored value is returned
// as-is instead of being passed to the decrypter.
func (s *ConnectionService) decryptPassword(conn *models.UserConnection) (string, error) {
	if conn.Password == "" {
		return "", nil
	}
	password, err := s.encryptSvc.DecryptPassword(conn.Password)
	if err != nil {
		return "", fmt.Errorf("failed to decrypt password: %w", err)
	}
	return password, nil
}

// quoteSQLIdentifier wraps name in quote, doubling any quote characters inside
// name so it cannot close the quoted identifier early. Doubling is the escaping
// rule for MySQL backtick identifiers and for standard double-quoted identifiers
// (PostgreSQL, SQLite).
func quoteSQLIdentifier(name, quote string) string {
	return quote + strings.ReplaceAll(name, quote, quote+quote) + quote
}