package middleware import ( "errors" "net/http" "runtime/debug" "time" "github.com/gin-gonic/gin" "go.uber.org/zap" "uzdb/internal/config" "uzdb/internal/models" ) // RecoveryMiddleware returns a recovery middleware that handles panics func RecoveryMiddleware() gin.HandlerFunc { return func(c *gin.Context) { defer func() { if err := recover(); err != nil { logger := config.GetLogger() // Log the panic with stack trace logger.Error("panic recovered", zap.Any("error", err), zap.String("stack", string(debug.Stack())), zap.String("path", c.Request.URL.Path), zap.String("method", c.Request.Method), ) // Send error response c.JSON(http.StatusInternalServerError, models.ErrorResponse{ Error: "INTERNAL_ERROR", Message: "An unexpected error occurred", Timestamp: time.Now(), Path: c.Request.URL.Path, }) c.Abort() } }() c.Next() } } // ErrorMiddleware returns an error handling middleware func ErrorMiddleware() gin.HandlerFunc { return func(c *gin.Context) { c.Next() // Check if there are any errors if len(c.Errors) > 0 { handleErrors(c) } } } // handleErrors processes and formats errors func handleErrors(c *gin.Context) { logger := config.GetLogger() for _, e := range c.Errors { err := e.Err // Log the error logger.Error("request error", zap.String("error", err.Error()), zap.Any("type", e.Type), zap.String("path", c.Request.URL.Path), ) // Determine status code statusCode := http.StatusInternalServerError // Map specific errors to status codes switch { case errors.Is(err, models.ErrNotFound): statusCode = http.StatusNotFound case errors.Is(err, models.ErrValidationFailed): statusCode = http.StatusBadRequest case errors.Is(err, models.ErrUnauthorized): statusCode = http.StatusUnauthorized case errors.Is(err, models.ErrForbidden): statusCode = http.StatusForbidden case errors.Is(err, models.ErrConnectionFailed): statusCode = http.StatusBadGateway } // Send response if not already sent if !c.Writer.Written() { c.JSON(statusCode, models.ErrorResponse{ Error: getErrorCode(err), Message: err.Error(), Timestamp: time.Now(), Path: c.Request.URL.Path, }) } // Only handle first error break } } // getErrorCode maps errors to error codes func getErrorCode(err error) string { switch { case errors.Is(err, models.ErrNotFound): return string(models.CodeNotFound) case errors.Is(err, models.ErrValidationFailed): return string(models.CodeValidation) case errors.Is(err, models.ErrUnauthorized): return string(models.CodeUnauthorized) case errors.Is(err, models.ErrForbidden): return string(models.CodeForbidden) case errors.Is(err, models.ErrConnectionFailed): return string(models.CodeConnection) case errors.Is(err, models.ErrQueryFailed): return string(models.CodeQuery) case errors.Is(err, models.ErrEncryptionFailed): return string(models.CodeEncryption) default: return string(models.CodeInternal) } }