package api import ( "context" "crypto/tls" "errors" "io" "log" "net" "net/http" "path" "regexp" "sync" "github.com/loveuer/upp/internal/bytesconv" "github.com/loveuer/upp/pkg/interfaces" ) var ( _ IRouter = (*App)(nil) regSafePrefix = regexp.MustCompile("[^a-zA-Z0-9/-]+") regRemoveRepeatedChar = regexp.MustCompile("/{2,}") ) type App struct { RouterGroup Upp interfaces.Upp config *Config groups []*RouterGroup server *http.Server trees methodTrees pool *sync.Pool maxParams uint16 maxSections uint16 redirectTrailingSlash bool // true redirectFixedPath bool // false handleMethodNotAllowed bool // false useRawPath bool // false unescapePathValues bool // true removeExtraSlash bool // false } func (a *App) allocateContext() *Ctx { var ( skippedNodes = make([]skippedNode, 0, a.maxSections) v = make(Params, 0, a.maxParams) ) ctx := Ctx{ lock: sync.Mutex{}, app: a, index: -1, locals: make(map[string]any), handlers: make([]HandlerFunc, 0), skippedNodes: &skippedNodes, params: &v, } return &ctx } func (a *App) ServeHTTP(writer http.ResponseWriter, request *http.Request) { var ( err error c = a.pool.Get().(*Ctx) nfe = new(Err) ) c.reset(writer, request) if err = c.verify(); err != nil { if errors.As(err, nfe) { _ = c.Status(nfe.Status).SendString(nfe.Msg) return } _ = c.Status(500).SendString(err.Error()) return } a.handleHTTPRequest(c) a.pool.Put(c) } func (a *App) run(ln net.Listener) error { srv := &http.Server{Handler: a} if a.config.DisableHttpErrorLog { srv.ErrorLog = log.New(io.Discard, "", 0) } a.server = srv err := a.server.Serve(ln) if !errors.Is(err, http.ErrServerClosed) || a.config.ErrServeClose { return err } return nil } func (a *App) Run(address string) error { ln, err := net.Listen("tcp", address) if err != nil { return err } return a.run(ln) } func (a *App) RunTLS(address string, tlsConfig *tls.Config) error { ln, err := tls.Listen("tcp", address, tlsConfig) if err != nil { return err } return a.run(ln) } func (a *App) RunListener(ln net.Listener) error { a.server = &http.Server{Addr: ln.Addr().String()} return a.run(ln) } func (a *App) Shutdown(ctx context.Context) error { return a.server.Shutdown(ctx) } func (a *App) addRoute(method, path string, handlers ...HandlerFunc) { elsePanic(path[0] == '/', "path must begin with '/'") elsePanic(method != "", "HTTP method can not be empty") elsePanic(len(handlers) > 0, "without enable not implement, there must be at least one handler") // if !a.config.DisableMessagePrint { // fmt.Printf("[NF] Add Route: %-8s - %-25s (%2d handlers)\n", method, path, len(handlers)) // } root := a.trees.get(method) if root == nil { root = new(node) root.fullPath = "/" a.trees = append(a.trees, methodTree{method: method, root: root}) } root.addRoute(path, handlers...) if paramsCount := countParams(path); paramsCount > a.maxParams { a.maxParams = paramsCount } if sectionsCount := countSections(path); sectionsCount > a.maxSections { a.maxSections = sectionsCount } } func (a *App) handleHTTPRequest(c *Ctx) { var err error httpMethod := c.Request.Method rPath := c.Request.URL.Path unescape := false if a.useRawPath && len(c.Request.URL.RawPath) > 0 { rPath = c.Request.URL.RawPath unescape = a.unescapePathValues } if a.removeExtraSlash { rPath = cleanPath(rPath) } // Find root of the tree for the given HTTP method t := a.trees for i, tl := 0, len(t); i < tl; i++ { if t[i].method != httpMethod { continue } root := t[i].root // Find route in tree value := root.getValue(rPath, c.params, c.skippedNodes, unescape) if value.params != nil { c.params = value.params } if value.handlers != nil { c.handlers = value.handlers c.fullPath = value.fullPath if err = c.Next(); err != nil { serveError(c, errorHandler) } return } if httpMethod != http.MethodConnect && rPath != "/" { if value.tsr && a.redirectTrailingSlash { redirectTrailingSlash(c) return } if a.redirectFixedPath && redirectFixedPath(c, root, a.redirectFixedPath) { return } } break } if a.handleMethodNotAllowed { // According to RFC 7231 section 6.5.5, MUST generate an Allow header field in response // containing a list of the target resource's currently supported methods. allowed := make([]string, 0, len(t)-1) for _, tree := range a.trees { if tree.method == httpMethod { continue } if value := tree.root.getValue(rPath, nil, c.skippedNodes, unescape); value.handlers != nil { allowed = append(allowed, tree.method) } } if len(allowed) > 0 { c.handlers = a.combineHandlers(a.config.MethodNotAllowedHandler) _ = c.Next() return } } c.handlers = a.combineHandlers(a.config.NotFoundHandler) _ = c.Next() return } func errorHandler(c *Ctx) error { return c.Status(500).SendString(_500) } func serveError(c *Ctx, handler HandlerFunc) { err := c.Next() if c.writermem.Written() { return } _ = handler(c) _ = err } func redirectTrailingSlash(c *Ctx) { req := c.Request p := req.URL.Path if prefix := path.Clean(c.Request.Header.Get("X-Forwarded-Prefix")); prefix != "." { prefix = regSafePrefix.ReplaceAllString(prefix, "") prefix = regRemoveRepeatedChar.ReplaceAllString(prefix, "/") p = prefix + "/" + req.URL.Path } req.URL.Path = p + "/" if length := len(p); length > 1 && p[length-1] == '/' { req.URL.Path = p[:length-1] } redirectRequest(c) } func redirectFixedPath(c *Ctx, root *node, trailingSlash bool) bool { req := c.Request rPath := req.URL.Path if fixedPath, ok := root.findCaseInsensitivePath(cleanPath(rPath), trailingSlash); ok { req.URL.Path = bytesconv.BytesToString(fixedPath) redirectRequest(c) return true } return false } func redirectRequest(c *Ctx) { req := c.Request // rPath := req.URL.Path rURL := req.URL.String() code := http.StatusMovedPermanently // Permanent redirect, request with GET method if req.Method != http.MethodGet { code = http.StatusTemporaryRedirect } // debugPrint("redirecting request %d: %s --> %s", code, rPath, rURL) http.Redirect(c.Writer, req, rURL, code) c.writermem.WriteHeaderNow() }