package middleware import ( "net/http" "strings" "github.com/gin-gonic/gin" "uzdb/internal/config" ) // CORSMiddleware returns a CORS middleware func CORSMiddleware() gin.HandlerFunc { return func(c *gin.Context) { cfg := config.Get() // Set CORS headers c.Header("Access-Control-Allow-Origin", getAllowedOrigin(c, cfg)) c.Header("Access-Control-Allow-Credentials", "true") c.Header("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-Request-ID") c.Header("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, PATCH, DELETE") c.Header("Access-Control-Expose-Headers", "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers, Content-Type, X-Request-ID") c.Header("Access-Control-Max-Age", "43200") // 12 hours // Handle preflight requests if c.Request.Method == "OPTIONS" { c.AbortWithStatus(http.StatusNoContent) return } c.Next() } } // getAllowedOrigin returns the allowed origin based on configuration func getAllowedOrigin(c *gin.Context, cfg *config.Config) string { origin := c.GetHeader("Origin") // If no origin header, return empty (same-origin) if origin == "" { return "" } // In development mode, allow all origins if cfg.IsDevelopment() { return origin } // In production, validate against allowed origins allowedOrigins := []string{ "http://localhost:3000", "http://localhost:8080", "http://127.0.0.1:3000", "http://127.0.0.1:8080", } for _, allowed := range allowedOrigins { if origin == allowed { return origin } } // Check for wildcard patterns for _, allowed := range allowedOrigins { if strings.HasSuffix(allowed, "*") { prefix := strings.TrimSuffix(allowed, "*") if strings.HasPrefix(origin, prefix) { return origin } } } // Default to empty (deny) in production if not matched if cfg.IsProduction() { return "" } return origin } // SecureHeadersMiddleware adds security-related headers func SecureHeadersMiddleware() gin.HandlerFunc { return func(c *gin.Context) { // Prevent MIME type sniffing c.Header("X-Content-Type-Options", "nosniff") // Enable XSS filter c.Header("X-XSS-Protection", "1; mode=block") // Prevent clickjacking c.Header("X-Frame-Options", "DENY") // Referrer policy c.Header("Referrer-Policy", "strict-origin-when-cross-origin") // Content Security Policy (adjust as needed) c.Header("Content-Security-Policy", "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'") c.Next() } }