304 lines
		
	
	
		
			6.2 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			304 lines
		
	
	
		
			6.2 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package nf
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"crypto/tls"
 | 
						|
	"errors"
 | 
						|
	"fmt"
 | 
						|
	"io"
 | 
						|
	"log"
 | 
						|
	"net"
 | 
						|
	"net/http"
 | 
						|
	"path"
 | 
						|
	"regexp"
 | 
						|
	"sync"
 | 
						|
 | 
						|
	"github.com/loveuer/nf/internal/bytesconv"
 | 
						|
)
 | 
						|
 | 
						|
var (
 | 
						|
	_ IRouter = (*App)(nil)
 | 
						|
 | 
						|
	regSafePrefix         = regexp.MustCompile("[^a-zA-Z0-9/-]+")
 | 
						|
	regRemoveRepeatedChar = regexp.MustCompile("/{2,}")
 | 
						|
)
 | 
						|
 | 
						|
type App struct {
 | 
						|
	RouterGroup
 | 
						|
	config *Config
 | 
						|
	groups []*RouterGroup
 | 
						|
	server *http.Server
 | 
						|
 | 
						|
	trees methodTrees
 | 
						|
 | 
						|
	pool *sync.Pool
 | 
						|
 | 
						|
	maxParams   uint16
 | 
						|
	maxSections uint16
 | 
						|
 | 
						|
	redirectTrailingSlash  bool // true
 | 
						|
	redirectFixedPath      bool // false
 | 
						|
	handleMethodNotAllowed bool // false
 | 
						|
	useRawPath             bool // false
 | 
						|
	unescapePathValues     bool // true
 | 
						|
	removeExtraSlash       bool // false
 | 
						|
}
 | 
						|
 | 
						|
func (a *App) allocateContext() *Ctx {
 | 
						|
	var (
 | 
						|
		skippedNodes = make([]skippedNode, 0, a.maxSections)
 | 
						|
		v            = make(Params, 0, a.maxParams)
 | 
						|
	)
 | 
						|
 | 
						|
	ctx := Ctx{
 | 
						|
		lock:         sync.Mutex{},
 | 
						|
		app:          a,
 | 
						|
		index:        -1,
 | 
						|
		locals:       make(map[string]any),
 | 
						|
		handlers:     make([]HandlerFunc, 0),
 | 
						|
		skippedNodes: &skippedNodes,
 | 
						|
		params:       &v,
 | 
						|
	}
 | 
						|
 | 
						|
	return &ctx
 | 
						|
}
 | 
						|
 | 
						|
func (a *App) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
 | 
						|
	var (
 | 
						|
		err error
 | 
						|
		c   = a.pool.Get().(*Ctx)
 | 
						|
		nfe = new(Err)
 | 
						|
	)
 | 
						|
 | 
						|
	c.reset(writer, request)
 | 
						|
 | 
						|
	if err = c.verify(); err != nil {
 | 
						|
		if errors.As(err, nfe) {
 | 
						|
			_ = c.Status(nfe.Status).SendString(nfe.Msg)
 | 
						|
			return
 | 
						|
		}
 | 
						|
 | 
						|
		_ = c.Status(500).SendString(err.Error())
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	a.handleHTTPRequest(c)
 | 
						|
 | 
						|
	a.pool.Put(c)
 | 
						|
}
 | 
						|
 | 
						|
func (a *App) run(ln net.Listener) error {
 | 
						|
	srv := &http.Server{Handler: a}
 | 
						|
 | 
						|
	if a.config.DisableHttpErrorLog {
 | 
						|
		srv.ErrorLog = log.New(io.Discard, "", 0)
 | 
						|
	}
 | 
						|
 | 
						|
	a.server = srv
 | 
						|
 | 
						|
	if !a.config.DisableBanner {
 | 
						|
		fmt.Println(banner + "nf serve at: " + ln.Addr().String() + "\n")
 | 
						|
	}
 | 
						|
 | 
						|
	err := a.server.Serve(ln)
 | 
						|
	if !errors.Is(err, http.ErrServerClosed) || a.config.ErrServeClose {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (a *App) Run(address string) error {
 | 
						|
	ln, err := net.Listen("tcp", address)
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	return a.run(ln)
 | 
						|
}
 | 
						|
 | 
						|
func (a *App) RunTLS(address string, tlsConfig *tls.Config) error {
 | 
						|
	ln, err := tls.Listen("tcp", address, tlsConfig)
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	return a.run(ln)
 | 
						|
}
 | 
						|
 | 
						|
func (a *App) RunListener(ln net.Listener) error {
 | 
						|
	a.server = &http.Server{Addr: ln.Addr().String()}
 | 
						|
 | 
						|
	return a.run(ln)
 | 
						|
}
 | 
						|
 | 
						|
func (a *App) Shutdown(ctx context.Context) error {
 | 
						|
	return a.server.Shutdown(ctx)
 | 
						|
}
 | 
						|
 | 
						|
func (a *App) addRoute(method, path string, handlers ...HandlerFunc) {
 | 
						|
	elsePanic(path[0] == '/', "path must begin with '/'")
 | 
						|
	elsePanic(method != "", "HTTP method can not be empty")
 | 
						|
	elsePanic(len(handlers) > 0, "without enable not implement, there must be at least one handler")
 | 
						|
 | 
						|
	if !a.config.DisableMessagePrint {
 | 
						|
		fmt.Printf("[NF] Add Route: %-8s - %-25s (%2d handlers)\n", method, path, len(handlers))
 | 
						|
	}
 | 
						|
 | 
						|
	root := a.trees.get(method)
 | 
						|
	if root == nil {
 | 
						|
		root = new(node)
 | 
						|
		root.fullPath = "/"
 | 
						|
		a.trees = append(a.trees, methodTree{method: method, root: root})
 | 
						|
	}
 | 
						|
 | 
						|
	root.addRoute(path, handlers...)
 | 
						|
 | 
						|
	if paramsCount := countParams(path); paramsCount > a.maxParams {
 | 
						|
		a.maxParams = paramsCount
 | 
						|
	}
 | 
						|
 | 
						|
	if sectionsCount := countSections(path); sectionsCount > a.maxSections {
 | 
						|
		a.maxSections = sectionsCount
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (a *App) handleHTTPRequest(c *Ctx) {
 | 
						|
	var err error
 | 
						|
 | 
						|
	httpMethod := c.Request.Method
 | 
						|
	rPath := c.Request.URL.Path
 | 
						|
	unescape := false
 | 
						|
	if a.useRawPath && len(c.Request.URL.RawPath) > 0 {
 | 
						|
		rPath = c.Request.URL.RawPath
 | 
						|
		unescape = a.unescapePathValues
 | 
						|
	}
 | 
						|
 | 
						|
	if a.removeExtraSlash {
 | 
						|
		rPath = cleanPath(rPath)
 | 
						|
	}
 | 
						|
 | 
						|
	// Find root of the tree for the given HTTP method
 | 
						|
	t := a.trees
 | 
						|
	for i, tl := 0, len(t); i < tl; i++ {
 | 
						|
		if t[i].method != httpMethod {
 | 
						|
			continue
 | 
						|
		}
 | 
						|
		root := t[i].root
 | 
						|
		// Find route in tree
 | 
						|
		value := root.getValue(rPath, c.params, c.skippedNodes, unescape)
 | 
						|
		if value.params != nil {
 | 
						|
			c.params = value.params
 | 
						|
		}
 | 
						|
 | 
						|
		if value.handlers != nil {
 | 
						|
			c.handlers = value.handlers
 | 
						|
			c.fullPath = value.fullPath
 | 
						|
 | 
						|
			if err = c.Next(); err != nil {
 | 
						|
				serveError(c, errorHandler)
 | 
						|
			}
 | 
						|
 | 
						|
			return
 | 
						|
		}
 | 
						|
		if httpMethod != http.MethodConnect && rPath != "/" {
 | 
						|
			if value.tsr && a.redirectTrailingSlash {
 | 
						|
				redirectTrailingSlash(c)
 | 
						|
				return
 | 
						|
			}
 | 
						|
			if a.redirectFixedPath && redirectFixedPath(c, root, a.redirectFixedPath) {
 | 
						|
				return
 | 
						|
			}
 | 
						|
		}
 | 
						|
		break
 | 
						|
	}
 | 
						|
 | 
						|
	if a.handleMethodNotAllowed {
 | 
						|
		// According to RFC 7231 section 6.5.5, MUST generate an Allow header field in response
 | 
						|
		// containing a list of the target resource's currently supported methods.
 | 
						|
		allowed := make([]string, 0, len(t)-1)
 | 
						|
		for _, tree := range a.trees {
 | 
						|
			if tree.method == httpMethod {
 | 
						|
				continue
 | 
						|
			}
 | 
						|
			if value := tree.root.getValue(rPath, nil, c.skippedNodes, unescape); value.handlers != nil {
 | 
						|
				allowed = append(allowed, tree.method)
 | 
						|
			}
 | 
						|
		}
 | 
						|
 | 
						|
		if len(allowed) > 0 {
 | 
						|
			c.handlers = a.combineHandlers(a.config.MethodNotAllowedHandler)
 | 
						|
 | 
						|
			_ = c.Next()
 | 
						|
 | 
						|
			return
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	c.handlers = a.combineHandlers(a.config.NotFoundHandler)
 | 
						|
 | 
						|
	_ = c.Next()
 | 
						|
 | 
						|
	return
 | 
						|
}
 | 
						|
 | 
						|
func errorHandler(c *Ctx) error {
 | 
						|
	return c.Status(500).SendString(_500)
 | 
						|
}
 | 
						|
 | 
						|
func serveError(c *Ctx, handler HandlerFunc) {
 | 
						|
	err := c.Next()
 | 
						|
 | 
						|
	if c.writermem.Written() {
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	_ = handler(c)
 | 
						|
	_ = err
 | 
						|
}
 | 
						|
 | 
						|
func redirectTrailingSlash(c *Ctx) {
 | 
						|
	req := c.Request
 | 
						|
	p := req.URL.Path
 | 
						|
	if prefix := path.Clean(c.Request.Header.Get("X-Forwarded-Prefix")); prefix != "." {
 | 
						|
		prefix = regSafePrefix.ReplaceAllString(prefix, "")
 | 
						|
		prefix = regRemoveRepeatedChar.ReplaceAllString(prefix, "/")
 | 
						|
 | 
						|
		p = prefix + "/" + req.URL.Path
 | 
						|
	}
 | 
						|
	req.URL.Path = p + "/"
 | 
						|
	if length := len(p); length > 1 && p[length-1] == '/' {
 | 
						|
		req.URL.Path = p[:length-1]
 | 
						|
	}
 | 
						|
 | 
						|
	redirectRequest(c)
 | 
						|
}
 | 
						|
 | 
						|
func redirectFixedPath(c *Ctx, root *node, trailingSlash bool) bool {
 | 
						|
	req := c.Request
 | 
						|
	rPath := req.URL.Path
 | 
						|
 | 
						|
	if fixedPath, ok := root.findCaseInsensitivePath(cleanPath(rPath), trailingSlash); ok {
 | 
						|
		req.URL.Path = bytesconv.BytesToString(fixedPath)
 | 
						|
		redirectRequest(c)
 | 
						|
		return true
 | 
						|
	}
 | 
						|
	return false
 | 
						|
}
 | 
						|
 | 
						|
func redirectRequest(c *Ctx) {
 | 
						|
	req := c.Request
 | 
						|
	// rPath := req.URL.Path
 | 
						|
	rURL := req.URL.String()
 | 
						|
 | 
						|
	code := http.StatusMovedPermanently // Permanent redirect, request with GET method
 | 
						|
	if req.Method != http.MethodGet {
 | 
						|
		code = http.StatusTemporaryRedirect
 | 
						|
	}
 | 
						|
 | 
						|
	// debugPrint("redirecting request %d: %s --> %s", code, rPath, rURL)
 | 
						|
 | 
						|
	http.Redirect(c.Writer, req, rURL, code)
 | 
						|
	c.writermem.WriteHeaderNow()
 | 
						|
}
 |