Compare commits

...

17 Commits

33 changed files with 2237 additions and 677 deletions

1
.gitignore vendored
View File

@ -1,3 +1,4 @@
.idea .idea
.vscode .vscode
.DS_Store .DS_Store
xtest

220
app.go
View File

@ -5,41 +5,59 @@ import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
"github.com/loveuer/nf/internal/bytesconv"
"io" "io"
"log" "log"
"net" "net"
"net/http" "net/http"
"strings" "path"
"regexp"
)
var (
_ IRouter = (*App)(nil)
regSafePrefix = regexp.MustCompile("[^a-zA-Z0-9/-]+")
regRemoveRepeatedChar = regexp.MustCompile("/{2,}")
) )
type App struct { type App struct {
*RouterGroup RouterGroup
config *Config config *Config
router *router
groups []*RouterGroup groups []*RouterGroup
server *http.Server server *http.Server
trees methodTrees
maxParams uint16
maxSections uint16
redirectTrailingSlash bool // true
redirectFixedPath bool // false
handleMethodNotAllowed bool // false
useRawPath bool // false
unescapePathValues bool // true
removeExtraSlash bool // false
} }
func (a *App) ServeHTTP(writer http.ResponseWriter, request *http.Request) { func (a *App) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
c := newContext(a, writer, request) var (
err error
c = newContext(a, writer, request)
nfe = new(Err)
)
for _, group := range a.groups { if err = c.verify(); err != nil {
if strings.HasPrefix(request.URL.Path, group.prefix) { if errors.As(err, nfe) {
c.handlers = append(c.handlers, group.middlewares...) _ = c.Status(nfe.Status).SendString(nfe.Msg)
} return
} }
if err := a.router.handle(c); err != nil { _ = c.Status(500).SendString(err.Error())
var ne = &Err{} return
if errors.As(err, ne) {
writer.WriteHeader(ne.Status)
} else {
writer.WriteHeader(500)
} }
_, _ = writer.Write([]byte(err.Error())) a.handleHTTPRequest(c)
}
} }
func (a *App) run(ln net.Listener) error { func (a *App) run(ln net.Listener) error {
@ -90,3 +108,171 @@ func (a *App) RunListener(ln net.Listener) error {
func (a *App) Shutdown(ctx context.Context) error { func (a *App) Shutdown(ctx context.Context) error {
return a.server.Shutdown(ctx) return a.server.Shutdown(ctx)
} }
func (a *App) addRoute(method, path string, handlers ...HandlerFunc) {
elsePanic(path[0] == '/', "path must begin with '/'")
elsePanic(method != "", "HTTP method can not be empty")
elsePanic(len(handlers) > 0, "without enable not implement, there must be at least one handler")
if !a.config.DisableMessagePrint {
fmt.Printf("[NF] Add Route: %-8s - %-25s (%2d handlers)\n", method, path, len(handlers))
}
root := a.trees.get(method)
if root == nil {
root = new(node)
root.fullPath = "/"
a.trees = append(a.trees, methodTree{method: method, root: root})
}
root.addRoute(path, handlers...)
if paramsCount := countParams(path); paramsCount > a.maxParams {
a.maxParams = paramsCount
}
if sectionsCount := countSections(path); sectionsCount > a.maxSections {
a.maxSections = sectionsCount
}
}
func (a *App) handleHTTPRequest(c *Ctx) {
var (
err error
)
httpMethod := c.Request.Method
rPath := c.Request.URL.Path
unescape := false
if a.useRawPath && len(c.Request.URL.RawPath) > 0 {
rPath = c.Request.URL.RawPath
unescape = a.unescapePathValues
}
if a.removeExtraSlash {
rPath = cleanPath(rPath)
}
// Find root of the tree for the given HTTP method
t := a.trees
for i, tl := 0, len(t); i < tl; i++ {
if t[i].method != httpMethod {
continue
}
root := t[i].root
// Find route in tree
value := root.getValue(rPath, c.params, c.skippedNodes, unescape)
if value.params != nil {
c.params = value.params
}
if value.handlers != nil {
c.handlers = value.handlers
c.fullPath = value.fullPath
if err = c.Next(); err != nil {
serveError(c, errorHandler)
}
return
}
if httpMethod != http.MethodConnect && rPath != "/" {
if value.tsr && a.redirectTrailingSlash {
redirectTrailingSlash(c)
return
}
if a.redirectFixedPath && redirectFixedPath(c, root, a.redirectFixedPath) {
return
}
}
break
}
if a.handleMethodNotAllowed {
// According to RFC 7231 section 6.5.5, MUST generate an Allow header field in response
// containing a list of the target resource's currently supported methods.
allowed := make([]string, 0, len(t)-1)
for _, tree := range a.trees {
if tree.method == httpMethod {
continue
}
if value := tree.root.getValue(rPath, nil, c.skippedNodes, unescape); value.handlers != nil {
allowed = append(allowed, tree.method)
}
}
if len(allowed) > 0 {
c.handlers = a.combineHandlers(a.config.MethodNotAllowedHandler)
_ = c.Next()
return
}
}
c.handlers = a.combineHandlers(a.config.NotFoundHandler)
_ = c.Next()
return
}
func errorHandler(c *Ctx) error {
return c.Status(500).SendString(_500)
}
func serveError(c *Ctx, handler HandlerFunc) {
err := c.Next()
if c.writermem.Written() {
return
}
_ = handler(c)
_ = err
}
func redirectTrailingSlash(c *Ctx) {
req := c.Request
p := req.URL.Path
if prefix := path.Clean(c.Request.Header.Get("X-Forwarded-Prefix")); prefix != "." {
prefix = regSafePrefix.ReplaceAllString(prefix, "")
prefix = regRemoveRepeatedChar.ReplaceAllString(prefix, "/")
p = prefix + "/" + req.URL.Path
}
req.URL.Path = p + "/"
if length := len(p); length > 1 && p[length-1] == '/' {
req.URL.Path = p[:length-1]
}
redirectRequest(c)
}
func redirectFixedPath(c *Ctx, root *node, trailingSlash bool) bool {
req := c.Request
rPath := req.URL.Path
if fixedPath, ok := root.findCaseInsensitivePath(cleanPath(rPath), trailingSlash); ok {
req.URL.Path = bytesconv.BytesToString(fixedPath)
redirectRequest(c)
return true
}
return false
}
func redirectRequest(c *Ctx) {
req := c.Request
//rPath := req.URL.Path
rURL := req.URL.String()
code := http.StatusMovedPermanently // Permanent redirect, request with GET method
if req.Method != http.MethodGet {
code = http.StatusTemporaryRedirect
}
//debugPrint("redirecting request %d: %s --> %s", code, rPath, rURL)
http.Redirect(c.writer, req, rURL, code)
c.writermem.WriteHeaderNow()
}

157
ctx.go
View File

@ -3,45 +3,63 @@ package nf
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"github.com/loveuer/nf/internal/sse"
"io" "io"
"log"
"mime/multipart" "mime/multipart"
"net" "net"
"net/http" "net/http"
"strings" "strings"
"sync"
) )
type Ctx struct { type Ctx struct {
// origin objects lock sync.Mutex
writermem responseWriter
writer http.ResponseWriter writer http.ResponseWriter
Request *http.Request Request *http.Request
// request info
path string path string
Method string method string
// response info
StatusCode int StatusCode int
app *App app *App
params map[string]string params *Params
index int index int
handlers []HandlerFunc handlers []HandlerFunc
locals map[string]interface{} locals map[string]interface{}
skippedNodes *[]skippedNode
fullPath string
} }
func newContext(app *App, writer http.ResponseWriter, request *http.Request) *Ctx { func newContext(app *App, writer http.ResponseWriter, request *http.Request) *Ctx {
return &Ctx{
skippedNodes := make([]skippedNode, 0, app.maxSections)
v := make(Params, 0, app.maxParams)
ctx := &Ctx{
lock: sync.Mutex{},
writer: writer, writer: writer,
Request: request, Request: request,
path: request.URL.Path, path: request.URL.Path,
Method: request.Method, method: request.Method,
StatusCode: 200, StatusCode: 200,
app: app, app: app,
index: -1, index: -1,
locals: map[string]interface{}{}, locals: map[string]interface{}{},
handlers: make([]HandlerFunc, 0), handlers: make([]HandlerFunc, 0),
skippedNodes: &skippedNodes,
params: &v,
} }
ctx.writermem = responseWriter{
ResponseWriter: ctx.writer,
size: -1,
status: 0,
}
return ctx
} }
func (c *Ctx) Locals(key string, value ...interface{}) interface{} { func (c *Ctx) Locals(key string, value ...interface{}) interface{} {
@ -53,6 +71,16 @@ func (c *Ctx) Locals(key string, value ...interface{}) interface{} {
return data return data
} }
func (c *Ctx) Method(overWrite ...string) string {
method := c.Request.Method
if len(overWrite) > 0 && overWrite[0] != "" {
c.Request.Method = overWrite[0]
}
return method
}
func (c *Ctx) Path(overWrite ...string) string { func (c *Ctx) Path(overWrite ...string) string {
path := c.Request.URL.Path path := c.Request.URL.Path
if len(overWrite) > 0 && overWrite[0] != "" { if len(overWrite) > 0 && overWrite[0] != "" {
@ -62,15 +90,43 @@ func (c *Ctx) Path(overWrite ...string) string {
return path return path
} }
func (c *Ctx) Cookies(key string, defaultValue ...string) string {
var (
dv = ""
)
if len(defaultValue) > 0 {
dv = defaultValue[0]
}
cookie, err := c.Request.Cookie(key)
if err != nil || cookie.Value == "" {
return dv
}
return cookie.Value
}
func (c *Ctx) Next() error { func (c *Ctx) Next() error {
c.index++ c.index++
s := len(c.handlers)
for ; c.index < s; c.index++ { if c.index >= len(c.handlers) {
if err := c.handlers[c.index](c); err != nil { return nil
}
var (
err error
handler = c.handlers[c.index]
)
if handler != nil {
if err = handler(c); err != nil {
return err return err
} }
} }
c.index++
return nil return nil
} }
@ -88,18 +144,39 @@ func (c *Ctx) verify() error {
} }
func (c *Ctx) Param(key string) string { func (c *Ctx) Param(key string) string {
return c.params[key] return c.params.ByName(key)
}
func (c *Ctx) SetParam(key, value string) {
c.lock.Lock()
defer c.lock.Unlock()
params := append(*c.params, Param{Key: key, Value: value})
c.params = &params
} }
func (c *Ctx) Form(key string) string { func (c *Ctx) Form(key string) string {
return c.Request.FormValue(key) return c.Request.FormValue(key)
} }
// FormValue fiber ctx function
func (c *Ctx) FormValue(key string) string {
return c.Request.FormValue(key)
}
func (c *Ctx) FormFile(key string) (*multipart.FileHeader, error) { func (c *Ctx) FormFile(key string) (*multipart.FileHeader, error) {
_, fh, err := c.Request.FormFile(key) _, fh, err := c.Request.FormFile(key)
return fh, err return fh, err
} }
func (c *Ctx) MultipartForm() (*multipart.Form, error) {
if err := c.Request.ParseMultipartForm(c.app.config.BodyLimit); err != nil {
return nil, err
}
return c.Request.MultipartForm, nil
}
func (c *Ctx) Query(key string) string { func (c *Ctx) Query(key string) string {
return c.Request.URL.Query().Get(key) return c.Request.URL.Query().Get(key)
} }
@ -127,8 +204,6 @@ func (c *Ctx) BodyParser(out interface{}) error {
ctype = strings.ToLower(c.Request.Header.Get("Content-Type")) ctype = strings.ToLower(c.Request.Header.Get("Content-Type"))
) )
log.Printf("BodyParser: Content-Type=%s", ctype)
ctype = parseVendorSpecificContentType(ctype) ctype = parseVendorSpecificContentType(ctype)
ctypeEnd := strings.IndexByte(ctype, ';') ctypeEnd := strings.IndexByte(ctype, ';')
@ -139,9 +214,9 @@ func (c *Ctx) BodyParser(out interface{}) error {
if strings.HasSuffix(ctype, "json") { if strings.HasSuffix(ctype, "json") {
bs, err := io.ReadAll(c.Request.Body) bs, err := io.ReadAll(c.Request.Body)
if err != nil { if err != nil {
log.Printf("BodyParser: read all err=%v", err)
return err return err
} }
_ = c.Request.Body.Close()
c.Request.Body = io.NopCloser(bytes.NewReader(bs)) c.Request.Body = io.NopCloser(bytes.NewReader(bs))
@ -169,11 +244,6 @@ func (c *Ctx) BodyParser(out interface{}) error {
} }
func (c *Ctx) QueryParser(out interface{}) error { func (c *Ctx) QueryParser(out interface{}) error {
//v := reflect.ValueOf(out)
//
//if v.Kind() == reflect.Ptr && v.Elem().Kind() != reflect.Map {
//}
return parseToStruct("query", out, c.Request.URL.Query()) return parseToStruct("query", out, c.Request.URL.Query())
} }
@ -182,17 +252,27 @@ func (c *Ctx) QueryParser(out interface{}) error {
=============================================================== */ =============================================================== */
func (c *Ctx) Status(code int) *Ctx { func (c *Ctx) Status(code int) *Ctx {
c.StatusCode = code c.lock.Lock()
c.writer.WriteHeader(code) defer c.lock.Unlock()
c.writermem.WriteHeader(code)
c.StatusCode = c.writermem.status
return c return c
} }
func (c *Ctx) Set(key string, value string) { func (c *Ctx) Set(key string, value string) {
c.writer.Header().Set(key, value) c.writermem.Header().Set(key, value)
} }
func (c *Ctx) SetHeader(key string, value string) { func (c *Ctx) SetHeader(key string, value string) {
c.writer.Header().Set(key, value) c.writermem.Header().Set(key, value)
}
func (c *Ctx) SendStatus(code int) error {
c.Status(code)
c.writermem.WriteHeaderNow()
return nil
} }
func (c *Ctx) SendString(data string) error { func (c *Ctx) SendString(data string) error {
@ -203,13 +283,13 @@ func (c *Ctx) SendString(data string) error {
func (c *Ctx) Writef(format string, values ...interface{}) (int, error) { func (c *Ctx) Writef(format string, values ...interface{}) (int, error) {
c.SetHeader("Content-Type", "text/plain") c.SetHeader("Content-Type", "text/plain")
return c.writer.Write([]byte(fmt.Sprintf(format, values...))) return c.Write([]byte(fmt.Sprintf(format, values...)))
} }
func (c *Ctx) JSON(data interface{}) error { func (c *Ctx) JSON(data interface{}) error {
c.SetHeader("Content-Type", MIMEApplicationJSON) c.SetHeader("Content-Type", MIMEApplicationJSON)
encoder := json.NewEncoder(c.writer) encoder := json.NewEncoder(&c.writermem)
if err := encoder.Encode(data); err != nil { if err := encoder.Encode(data); err != nil {
return err return err
@ -218,12 +298,25 @@ func (c *Ctx) JSON(data interface{}) error {
return nil return nil
} }
func (c *Ctx) RawWriter() http.ResponseWriter { func (c *Ctx) SSEvent(event string, data interface{}) error {
return c.writer c.Set("Content-Type", "text/event-stream")
c.Set("Cache-Control", "no-cache")
c.Set("Transfer-Encoding", "chunked")
return sse.Encode(c.writer, sse.Event{Event: event, Data: data})
} }
func (c *Ctx) Write(data []byte) (int, error) { func (c *Ctx) Flush() error {
return c.writer.Write(data) if f, ok := c.writer.(http.Flusher); ok {
f.Flush()
return nil
}
return errors.New("http.Flusher is not implemented")
}
func (c *Ctx) RawWriter() http.ResponseWriter {
return c.writer
} }
func (c *Ctx) HTML(html string) error { func (c *Ctx) HTML(html string) error {
@ -231,3 +324,7 @@ func (c *Ctx) HTML(html string) error {
_, err := c.writer.Write([]byte(html)) _, err := c.writer.Write([]byte(html))
return err return err
} }
func (c *Ctx) Write(data []byte) (int, error) {
return c.writermem.Write(data)
}

11
go.mod
View File

@ -1,3 +1,14 @@
module github.com/loveuer/nf module github.com/loveuer/nf
go 1.20 go 1.20
require (
github.com/fatih/color v1.17.0
github.com/google/uuid v1.6.0
)
require (
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
golang.org/x/sys v0.18.0 // indirect
)

13
go.sum Normal file
View File

@ -0,0 +1,13 @@
github.com/fatih/color v1.17.0 h1:GlRw1BRJxkpqUCBKzKOw098ed57fEsKeNjpTe3cSjK4=
github.com/fatih/color v1.17.0/go.mod h1:YZ7TlrGPkiz6ku9fK3TLD/pl3CpsiFyu8N92HLgmosI=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=

View File

@ -1,80 +0,0 @@
package nf
import (
"fmt"
"log"
"net/http"
)
type RouterGroup struct {
prefix string
middlewares []HandlerFunc // support middleware
parent *RouterGroup // support nesting
app *App // all groups share a Engine instance
}
// Group is defined to create a new RouterGroup
// remember all groups share the same Engine instance
func (group *RouterGroup) Group(prefix string) *RouterGroup {
app := group.app
newGroup := &RouterGroup{
prefix: group.prefix + prefix,
parent: group,
app: app,
}
app.groups = append(app.groups, newGroup)
return newGroup
}
func (group *RouterGroup) verifyHandlers(path string, handlers ...HandlerFunc) []HandlerFunc {
if len(handlers) == 0 {
if !group.app.config.EnableNotImplementHandler {
panic(fmt.Sprintf("missing handler in route: %s", path))
}
handlers = append(handlers, ToDoHandler)
}
for _, handler := range handlers {
if handler == nil {
panic(fmt.Sprintf("nil handler found in route: %s", path))
}
}
return handlers
}
func (group *RouterGroup) addRoute(method string, comp string, handlers ...HandlerFunc) {
handlers = group.verifyHandlers(comp, handlers...)
pattern := group.prefix + comp
log.Printf("Add Route %4s - %s", method, pattern)
group.app.router.addRoute(method, pattern, handlers...)
}
func (group *RouterGroup) Get(pattern string, handlers ...HandlerFunc) {
group.addRoute(http.MethodGet, pattern, handlers...)
}
func (group *RouterGroup) Post(pattern string, handlers ...HandlerFunc) {
group.addRoute(http.MethodPost, pattern, handlers...)
}
func (group *RouterGroup) Put(pattern string, handlers ...HandlerFunc) {
group.addRoute(http.MethodPut, pattern, handlers...)
}
func (group *RouterGroup) Delete(pattern string, handlers ...HandlerFunc) {
group.addRoute(http.MethodDelete, pattern, handlers...)
}
func (group *RouterGroup) Patch(pattern string, handlers ...HandlerFunc) {
group.addRoute(http.MethodPatch, pattern, handlers...)
}
func (group *RouterGroup) Head(pattern string, handlers ...HandlerFunc) {
group.addRoute(http.MethodHead, pattern, handlers...)
}
func (group *RouterGroup) Use(middlewares ...HandlerFunc) {
group.middlewares = append(group.middlewares, middlewares...)
}

View File

@ -5,5 +5,5 @@ import "fmt"
type HandlerFunc func(*Ctx) error type HandlerFunc func(*Ctx) error
func ToDoHandler(c *Ctx) error { func ToDoHandler(c *Ctx) error {
return c.Status(501).SendString(fmt.Sprintf("%s - %s Not Implemented", c.Method, c.Path())) return c.Status(501).SendString(fmt.Sprintf("%s - %s Not Implemented", c.Method(), c.Path()))
} }

View File

@ -0,0 +1,26 @@
// Copyright 2020 Gin Core Team. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
//go:build !go1.20
package bytesconv
import (
"unsafe"
)
// StringToBytes converts string to byte slice without a memory allocation.
func StringToBytes(s string) []byte {
return *(*[]byte)(unsafe.Pointer(
&struct {
string
Cap int
}{s, len(s)},
))
}
// BytesToString converts byte slice to string without a memory allocation.
func BytesToString(b []byte) string {
return *(*string)(unsafe.Pointer(&b))
}

View File

@ -0,0 +1,23 @@
// Copyright 2023 Gin Core Team. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
//go:build go1.20
package bytesconv
import (
"unsafe"
)
// StringToBytes converts string to byte slice without a memory allocation.
// For more details, see https://github.com/golang/go/issues/53003#issuecomment-1140276077.
func StringToBytes(s string) []byte {
return unsafe.Slice(unsafe.StringData(s), len(s))
}
// BytesToString converts byte slice to string without a memory allocation.
// For more details, see https://github.com/golang/go/issues/53003#issuecomment-1140276077.
func BytesToString(b []byte) string {
return unsafe.String(unsafe.SliceData(b), len(b))
}

View File

@ -0,0 +1,99 @@
// Copyright 2020 Gin Core Team. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package bytesconv
import (
"bytes"
"math/rand"
"strings"
"testing"
"time"
)
var testString = "Albert Einstein: Logic will get you from A to B. Imagination will take you everywhere."
var testBytes = []byte(testString)
func rawBytesToStr(b []byte) string {
return string(b)
}
func rawStrToBytes(s string) []byte {
return []byte(s)
}
// go test -v
func TestBytesToString(t *testing.T) {
data := make([]byte, 1024)
for i := 0; i < 100; i++ {
rand.Read(data)
if rawBytesToStr(data) != BytesToString(data) {
t.Fatal("don't match")
}
}
}
const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
const (
letterIdxBits = 6 // 6 bits to represent a letter index
letterIdxMask = 1<<letterIdxBits - 1 // All 1-bits, as many as letterIdxBits
letterIdxMax = 63 / letterIdxBits // # of letter indices fitting in 63 bits
)
var src = rand.NewSource(time.Now().UnixNano())
func RandStringBytesMaskImprSrcSB(n int) string {
sb := strings.Builder{}
sb.Grow(n)
// A src.Int63() generates 63 random bits, enough for letterIdxMax characters!
for i, cache, remain := n-1, src.Int63(), letterIdxMax; i >= 0; {
if remain == 0 {
cache, remain = src.Int63(), letterIdxMax
}
if idx := int(cache & letterIdxMask); idx < len(letterBytes) {
sb.WriteByte(letterBytes[idx])
i--
}
cache >>= letterIdxBits
remain--
}
return sb.String()
}
func TestStringToBytes(t *testing.T) {
for i := 0; i < 100; i++ {
s := RandStringBytesMaskImprSrcSB(64)
if !bytes.Equal(rawStrToBytes(s), StringToBytes(s)) {
t.Fatal("don't match")
}
}
}
// go test -v -run=none -bench=^BenchmarkBytesConv -benchmem=true
func BenchmarkBytesConvBytesToStrRaw(b *testing.B) {
for i := 0; i < b.N; i++ {
rawBytesToStr(testBytes)
}
}
func BenchmarkBytesConvBytesToStr(b *testing.B) {
for i := 0; i < b.N; i++ {
BytesToString(testBytes)
}
}
func BenchmarkBytesConvStrToBytesRaw(b *testing.B) {
for i := 0; i < b.N; i++ {
rawStrToBytes(testString)
}
}
func BenchmarkBytesConvStrToBytes(b *testing.B) {
for i := 0; i < b.N; i++ {
StringToBytes(testString)
}
}

106
internal/sse/sse-encoder.go Normal file
View File

@ -0,0 +1,106 @@
package sse
import (
"encoding/json"
"fmt"
"io"
"net/http"
"reflect"
"strconv"
"strings"
)
// Server-Sent Events
// W3C Working Draft 29 October 2009
// http://www.w3.org/TR/2009/WD-eventsource-20091029/
const ContentType = "text/event-stream"
var contentType = []string{ContentType}
var noCache = []string{"no-cache"}
var fieldReplacer = strings.NewReplacer(
"\n", "\\n",
"\r", "\\r")
var dataReplacer = strings.NewReplacer(
"\n", "\ndata:",
"\r", "\\r")
type Event struct {
Event string
Id string
Retry uint
Data interface{}
}
func Encode(writer io.Writer, event Event) error {
w := checkWriter(writer)
writeId(w, event.Id)
writeEvent(w, event.Event)
writeRetry(w, event.Retry)
return writeData(w, event.Data)
}
func writeId(w stringWriter, id string) {
if len(id) > 0 {
w.WriteString("id:")
fieldReplacer.WriteString(w, id)
w.WriteString("\n")
}
}
func writeEvent(w stringWriter, event string) {
if len(event) > 0 {
w.WriteString("event:")
fieldReplacer.WriteString(w, event)
w.WriteString("\n")
}
}
func writeRetry(w stringWriter, retry uint) {
if retry > 0 {
w.WriteString("retry:")
w.WriteString(strconv.FormatUint(uint64(retry), 10))
w.WriteString("\n")
}
}
func writeData(w stringWriter, data interface{}) error {
w.WriteString("data:")
switch kindOfData(data) {
case reflect.Struct, reflect.Slice, reflect.Map:
err := json.NewEncoder(w).Encode(data)
if err != nil {
return err
}
w.WriteString("\n")
default:
dataReplacer.WriteString(w, fmt.Sprint(data))
w.WriteString("\n\n")
}
return nil
}
func (r Event) Render(w http.ResponseWriter) error {
r.WriteContentType(w)
return Encode(w, r)
}
func (r Event) WriteContentType(w http.ResponseWriter) {
header := w.Header()
header["Content-Type"] = contentType
if _, exist := header["Cache-Control"]; !exist {
header["Cache-Control"] = noCache
}
}
func kindOfData(data interface{}) reflect.Kind {
value := reflect.ValueOf(data)
valueType := value.Kind()
if valueType == reflect.Ptr {
valueType = value.Elem().Kind()
}
return valueType
}

24
internal/sse/writer.go Normal file
View File

@ -0,0 +1,24 @@
package sse
import "io"
type stringWriter interface {
io.Writer
WriteString(string) (int, error)
}
type stringWrapper struct {
io.Writer
}
func (w stringWrapper) WriteString(str string) (int, error) {
return w.Writer.Write([]byte(str))
}
func checkWriter(writer io.Writer) stringWriter {
if w, ok := writer.(stringWriter); ok {
return w
} else {
return stringWrapper{writer}
}
}

View File

@ -2,9 +2,11 @@ package nf
import ( import (
"fmt" "fmt"
"log" "github.com/google/uuid"
"github.com/loveuer/nf/nft/log"
"os" "os"
"runtime/debug" "runtime/debug"
"strings"
"time" "time"
) )
@ -17,6 +19,9 @@ func NewRecover(enableStackTrace bool) HandlerFunc {
} else { } else {
os.Stderr.WriteString(fmt.Sprintf("recovered from panic: %v\n", r)) os.Stderr.WriteString(fmt.Sprintf("recovered from panic: %v\n", r))
} }
//serveError(c, 500, []byte(fmt.Sprint(r)))
_ = c.Status(500).SendString(fmt.Sprint(r))
} }
}() }()
@ -24,51 +29,44 @@ func NewRecover(enableStackTrace bool) HandlerFunc {
} }
} }
func NewLogger() HandlerFunc { func NewLogger(traceHeader ...string) HandlerFunc {
l := log.New(os.Stdout, "[NF] ", 0) Header := "X-Trace-ID"
if len(traceHeader) > 0 && traceHeader[0] != "" {
durationFormat := func(num int64) string { Header = traceHeader[0]
var (
unit = "ns"
)
if num > 1000 {
num = num / 1000
unit = "µs"
}
if num > 1000 {
num = num / 1000
unit = "ms"
}
if num > 1000 {
num = num / 1000
unit = " s"
}
return fmt.Sprintf("%v %s", num, unit)
} }
return func(c *Ctx) error { return func(c *Ctx) error {
start := time.Now() var (
now = time.Now()
trace = c.Get(Header)
logFn func(msg string, data ...any)
ip = c.IP()
)
if trace == "" {
trace = uuid.Must(uuid.NewV7()).String()
}
c.SetHeader(Header, trace)
traces := strings.Split(trace, "-")
shortTrace := traces[len(traces)-1]
err := c.Next() err := c.Next()
duration := time.Since(now)
var ( msg := fmt.Sprintf("NF | %s | %15s | %3d | %s | %6s | %s", shortTrace, ip, c.StatusCode, HumanDuration(duration.Nanoseconds()), c.Method(), c.Path())
duration = time.Now().Sub(start).Nanoseconds()
status = c.StatusCode
path = c.path
method = c.Request.Method
)
l.Printf("%s | %5s | %d | %s | %s", switch {
start.Format("06/01/02T15:04:05"), case c.StatusCode >= 500:
method, logFn = log.Error
status, case c.StatusCode >= 400:
durationFormat(duration), logFn = log.Warn
path, default:
) logFn = log.Info
}
logFn(msg)
return err return err
} }

32
nf.go
View File

@ -3,11 +3,14 @@ package nf
const ( const (
banner = " _ _ _ ___ _ \n | \\| |___| |_ | __|__ _ _ _ _ __| |\n | .` / _ \\ _| | _/ _ \\ || | ' \\/ _` |\n |_|\\_\\___/\\__| |_|\\___/\\_,_|_||_\\__,_|\n " banner = " _ _ _ ___ _ \n | \\| |___| |_ | __|__ _ _ _ _ __| |\n | .` / _ \\ _| | _/ _ \\ || | ' \\/ _` |\n |_|\\_\\___/\\__| |_|\\___/\\_,_|_||_\\__,_|\n "
_404 = "<!doctype html><html lang=\"en\"><head><meta charset=\"UTF-8\"><meta name=\"viewport\" content=\"width=device-width,user-scalable=no,initial-scale=1,maximum-scale=1,minimum-scale=1\"><meta http-equiv=\"X-UA-Compatible\" content=\"ie=edge\"><title>Not Found</title><style>body{background:#333;margin:0;color:#ccc;display:flex;align-items:center;max-height:100vh;height:100vh;justify-content:center}textarea{min-height:5rem;min-width:20rem;text-align:center;border:none;background:0 0;color:#ccc;resize:none;user-input:none;user-select:none;cursor:default;-webkit-user-select:none;-webkit-touch-callout:none;-moz-user-select:none;-ms-user-select:none;outline:0}</style></head><body><textarea id=\"banner\" readonly=\"readonly\"></textarea><script type=\"text/javascript\">let htmlCodes = [\n ' _ _ _ ___ _ ',\n '| \\\\| |___| |_ | __|__ _ _ _ _ __| |',\n '| .` / _ \\\\ _| | _/ _ \\\\ || | \\' \\\\/ _` |',\n '|_|\\\\_\\\\___/\\\\__| |_|\\\\___/\\\\_,_|_||_\\\\__,_|'\n].join('\\n');\ndocument.querySelector('#banner').value = htmlCodes</script></body></html>" _404 = "<!doctype html><html lang=\"en\"><head><meta charset=\"UTF-8\"><meta name=\"viewport\" content=\"width=device-width,user-scalable=no,initial-scale=1,maximum-scale=1,minimum-scale=1\"><meta http-equiv=\"X-UA-Compatible\" content=\"ie=edge\"><title>Not Found</title><style>body{background:#333;margin:0;color:#ccc;display:flex;align-items:center;max-height:100vh;height:100vh;justify-content:center}textarea{min-height:5rem;min-width:20rem;text-align:center;border:none;background:0 0;color:#ccc;resize:none;user-input:none;user-select:none;cursor:default;-webkit-user-select:none;-webkit-touch-callout:none;-moz-user-select:none;-ms-user-select:none;outline:0}</style></head><body><textarea id=\"banner\" readonly=\"readonly\"></textarea><script type=\"text/javascript\">let htmlCodes = [\n ' _ _ _ ___ _ ',\n '| \\\\| |___| |_ | __|__ _ _ _ _ __| |',\n '| .` / _ \\\\ _| | _/ _ \\\\ || | \\' \\\\/ _` |',\n '|_|\\\\_\\\\___/\\\\__| |_|\\\\___/\\\\_,_|_||_\\\\__,_|'\n].join('\\n');\ndocument.querySelector('#banner').value = htmlCodes</script></body></html>"
_405 = `405 Method Not Allowed`
_500 = `500 Internal Server Error`
) )
type Map map[string]interface{} type Map map[string]interface{}
type Config struct { type Config struct {
DisableMessagePrint bool `json:"-"`
// Default: 4 * 1024 * 1024 // Default: 4 * 1024 * 1024
BodyLimit int64 `json:"-"` BodyLimit int64 `json:"-"`
@ -19,8 +22,9 @@ type Config struct {
DisableRecover bool `json:"-"` DisableRecover bool `json:"-"`
DisableHttpErrorLog bool `json:"-"` DisableHttpErrorLog bool `json:"-"`
EnableNotImplementHandler bool `json:"-"` //EnableNotImplementHandler bool `json:"-"`
NotFoundHandler HandlerFunc `json:"-"` NotFoundHandler HandlerFunc `json:"-"`
MethodNotAllowedHandler HandlerFunc `json:"-"`
} }
var ( var (
@ -31,16 +35,33 @@ var (
_, err := c.Status(404).Write([]byte(_404)) _, err := c.Status(404).Write([]byte(_404))
return err return err
}, },
MethodNotAllowedHandler: func(c *Ctx) error {
c.Set("Content-Type", MIMETextPlain)
_, err := c.Status(405).Write([]byte(_405))
return err
},
} }
) )
func New(config ...Config) *App { func New(config ...Config) *App {
app := &App{ app := &App{
router: newRouter(), RouterGroup: RouterGroup{
Handlers: nil,
basePath: "/",
root: true,
},
redirectTrailingSlash: true, // true
redirectFixedPath: false, // false
handleMethodNotAllowed: true, // false
useRawPath: false, // false
unescapePathValues: true, // true
removeExtraSlash: false, // false
} }
if len(config) > 0 { if len(config) > 0 {
app.config = &config[0] app.config = &config[0]
if app.config.BodyLimit == 0 { if app.config.BodyLimit == 0 {
app.config.BodyLimit = defaultConfig.BodyLimit app.config.BodyLimit = defaultConfig.BodyLimit
} }
@ -49,12 +70,15 @@ func New(config ...Config) *App {
app.config.NotFoundHandler = defaultConfig.NotFoundHandler app.config.NotFoundHandler = defaultConfig.NotFoundHandler
} }
if app.config.MethodNotAllowedHandler == nil {
app.config.MethodNotAllowedHandler = defaultConfig.MethodNotAllowedHandler
}
} else { } else {
app.config = defaultConfig app.config = defaultConfig
} }
app.RouterGroup = &RouterGroup{app: app} app.RouterGroup.app = app
app.groups = []*RouterGroup{app.RouterGroup}
if !app.config.DisableLogger { if !app.config.DisableLogger {
app.Use(NewLogger()) app.Use(NewLogger())

67
nft/log/default.go Normal file
View File

@ -0,0 +1,67 @@
package log
import (
"fmt"
"os"
"sync"
)
var (
nilLogger = func(prefix, timestamp, msg string, data ...any) {}
normalLogger = func(prefix, timestamp, msg string, data ...any) {
fmt.Printf(prefix+"| "+timestamp+" | "+msg+"\n", data...)
}
panicLogger = func(prefix, timestamp, msg string, data ...any) {
panic(fmt.Sprintf(prefix+"| "+timestamp+" | "+msg+"\n", data...))
}
fatalLogger = func(prefix, timestamp, msg string, data ...any) {
fmt.Printf(prefix+"| "+timestamp+" | "+msg+"\n", data...)
os.Exit(1)
}
defaultLogger = &logger{
Mutex: sync.Mutex{},
timeFormat: "2006-01-02T15:04:05",
writer: os.Stdout,
level: LogLevelInfo,
debug: nilLogger,
info: normalLogger,
warn: normalLogger,
error: normalLogger,
panic: panicLogger,
fatal: fatalLogger,
}
)
func SetTimeFormat(format string) {
defaultLogger.SetTimeFormat(format)
}
func SetLogLevel(level LogLevel) {
defaultLogger.SetLogLevel(level)
}
func Debug(msg string, data ...any) {
defaultLogger.Debug(msg, data...)
}
func Info(msg string, data ...any) {
defaultLogger.Info(msg, data...)
}
func Warn(msg string, data ...any) {
defaultLogger.Warn(msg, data...)
}
func Error(msg string, data ...any) {
defaultLogger.Error(msg, data...)
}
func Panic(msg string, data ...any) {
defaultLogger.Panic(msg, data...)
}
func Fatal(msg string, data ...any) {
defaultLogger.Fatal(msg, data...)
}

115
nft/log/log.go Normal file
View File

@ -0,0 +1,115 @@
package log
import (
"github.com/fatih/color"
"io"
"sync"
"time"
)
type LogLevel uint32
const (
LogLevelDebug = iota
LogLevelInfo
LogLevelWarn
LogLevelError
LogLevelPanic
LogLevelFatal
)
type logger struct {
sync.Mutex
timeFormat string
writer io.Writer
level LogLevel
debug func(prefix, timestamp, msg string, data ...any)
info func(prefix, timestamp, msg string, data ...any)
warn func(prefix, timestamp, msg string, data ...any)
error func(prefix, timestamp, msg string, data ...any)
panic func(prefix, timestamp, msg string, data ...any)
fatal func(prefix, timestamp, msg string, data ...any)
}
var (
red = color.New(color.FgRed)
hired = color.New(color.FgHiRed)
green = color.New(color.FgGreen)
yellow = color.New(color.FgYellow)
white = color.New(color.FgWhite)
)
func (l *logger) SetTimeFormat(format string) {
l.Lock()
defer l.Unlock()
l.timeFormat = format
}
func (l *logger) SetLogLevel(level LogLevel) {
l.Lock()
defer l.Unlock()
if level > LogLevelDebug {
l.debug = nilLogger
} else {
l.debug = normalLogger
}
if level > LogLevelInfo {
l.info = nilLogger
} else {
l.info = normalLogger
}
if level > LogLevelWarn {
l.warn = nilLogger
} else {
l.warn = normalLogger
}
if level > LogLevelError {
l.error = nilLogger
} else {
l.error = normalLogger
}
if level > LogLevelPanic {
l.panic = nilLogger
} else {
l.panic = panicLogger
}
if level > LogLevelFatal {
l.fatal = nilLogger
} else {
l.fatal = fatalLogger
}
}
func (l *logger) Debug(msg string, data ...any) {
l.debug(white.Sprint("Debug "), time.Now().Format(l.timeFormat), msg, data...)
}
func (l *logger) Info(msg string, data ...any) {
l.info(green.Sprint("Info "), time.Now().Format(l.timeFormat), msg, data...)
}
func (l *logger) Warn(msg string, data ...any) {
l.warn(yellow.Sprint("Warn "), time.Now().Format(l.timeFormat), msg, data...)
}
func (l *logger) Error(msg string, data ...any) {
l.error(red.Sprint("Error "), time.Now().Format(l.timeFormat), msg, data...)
}
func (l *logger) Panic(msg string, data ...any) {
l.panic(hired.Sprint("Panic "), time.Now().Format(l.timeFormat), msg, data...)
}
func (l *logger) Fatal(msg string, data ...any) {
l.fatal(hired.Sprint("Fatal "), time.Now().Format(l.timeFormat), msg, data...)
}
type WroteLogger interface {
Info(msg string, data ...any)
}

21
nft/log/new.go Normal file
View File

@ -0,0 +1,21 @@
package log
import (
"os"
"sync"
)
func New() *logger {
return &logger{
Mutex: sync.Mutex{},
timeFormat: "2006-01-02T15:04:05",
writer: os.Stdout,
level: LogLevelInfo,
debug: nilLogger,
info: normalLogger,
warn: normalLogger,
error: normalLogger,
panic: panicLogger,
fatal: fatalLogger,
}
}

View File

@ -2,7 +2,6 @@ package resp
import ( import (
"errors" "errors"
"fmt"
"github.com/loveuer/nf" "github.com/loveuer/nf"
) )
@ -15,28 +14,28 @@ type Error struct {
func (e Error) Error() string { func (e Error) Error() string {
if e.msg != "" { if e.msg != "" {
return fmt.Sprintf("%s: %s", e.msg, e.err.Error()) return e.msg
} }
switch e.status { switch e.status {
case 200: case 200:
return fmt.Sprintf("%s: %s", MSG200, e.err.Error()) return MSG200
case 202: case 202:
return fmt.Sprintf("%s: %s", MSG202, e.err.Error()) return MSG202
case 400: case 400:
return fmt.Sprintf("%s: %s", MSG400, e.err.Error()) return MSG400
case 401: case 401:
return fmt.Sprintf("%s: %s", MSG401, e.err.Error()) return MSG401
case 403: case 403:
return fmt.Sprintf("%s: %s", MSG403, e.err.Error()) return MSG403
case 404: case 404:
return fmt.Sprintf("%s: %s", MSG404, e.err.Error()) return MSG404
case 429: case 429:
return fmt.Sprintf("%s: %s", MSG429, e.err.Error()) return MSG429
case 500: case 500:
return fmt.Sprintf("%s: %s", MSG500, e.err.Error()) return MSG500
case 501: case 501:
return fmt.Sprintf("%s: %s", MSG501, e.err.Error()) return MSG501
} }
return e.err.Error() return e.err.Error()

View File

@ -1 +0,0 @@
package nf

133
response_writer.go Normal file
View File

@ -0,0 +1,133 @@
package nf
import (
"bufio"
"io"
"log"
"net"
"net/http"
)
const (
noWritten = -1
defaultStatus = http.StatusOK
)
// ResponseWriter ...
type ResponseWriter interface {
http.ResponseWriter
http.Hijacker
http.Flusher
http.CloseNotifier
// Status returns the HTTP response status code of the current request.
Status() int
// Size returns the number of bytes already written into the response http body.
// See Written()
Size() int
// WriteString writes the string into the response body.
WriteString(string) (int, error)
// Written returns true if the response body was already written.
Written() bool
// WriteHeaderNow forces to write the http header (status code + headers).
WriteHeaderNow()
// Pusher get the http.Pusher for server push
Pusher() http.Pusher
}
type responseWriter struct {
http.ResponseWriter
size int
status int
}
var _ ResponseWriter = (*responseWriter)(nil)
func (w *responseWriter) Unwrap() http.ResponseWriter {
return w.ResponseWriter
}
func (w *responseWriter) reset(writer http.ResponseWriter) {
w.ResponseWriter = writer
w.size = noWritten
w.status = defaultStatus
}
func (w *responseWriter) WriteHeader(code int) {
if code > 0 && w.status != code {
if w.Written() {
log.Printf("[NF] WARNING: Headers were already written. Wanted to override status code %d with %d", w.status, code)
return
}
w.status = code
}
}
func (w *responseWriter) WriteHeaderNow() {
if !w.Written() {
w.size = 0
if w.status == 0 {
w.status = 200
}
w.ResponseWriter.WriteHeader(w.status)
}
}
func (w *responseWriter) Write(data []byte) (n int, err error) {
w.WriteHeaderNow()
n, err = w.ResponseWriter.Write(data)
w.size += n
return
}
func (w *responseWriter) WriteString(s string) (n int, err error) {
w.WriteHeaderNow()
n, err = io.WriteString(w.ResponseWriter, s)
w.size += n
return
}
func (w *responseWriter) Status() int {
return w.status
}
func (w *responseWriter) Size() int {
return w.size
}
func (w *responseWriter) Written() bool {
return w.size != noWritten
}
// Hijack implements the http.Hijacker interface.
func (w *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if w.size < 0 {
w.size = 0
}
return w.ResponseWriter.(http.Hijacker).Hijack()
}
// CloseNotify implements the http.CloseNotifier interface.
func (w *responseWriter) CloseNotify() <-chan bool {
return w.ResponseWriter.(http.CloseNotifier).CloseNotify()
}
// Flush implements the http.Flusher interface.
func (w *responseWriter) Flush() {
w.WriteHeaderNow()
w.ResponseWriter.(http.Flusher).Flush()
}
func (w *responseWriter) Pusher() (pusher http.Pusher) {
if pusher, ok := w.ResponseWriter.(http.Pusher); ok {
return pusher
}
return nil
}

View File

@ -1,98 +0,0 @@
package nf
import "strings"
type router struct {
roots map[string]*_node
handlers map[string][]HandlerFunc
}
func newRouter() *router {
return &router{
roots: make(map[string]*_node),
handlers: make(map[string][]HandlerFunc),
}
}
// Only one * is allowed
func parsePattern(pattern string) []string {
vs := strings.Split(pattern, "/")
parts := make([]string, 0)
for _, item := range vs {
if item != "" {
parts = append(parts, item)
if item[0] == '*' {
break
}
}
}
return parts
}
func (r *router) addRoute(method string, pattern string, handlers ...HandlerFunc) {
parts := parsePattern(pattern)
key := method + "-" + pattern
_, ok := r.roots[method]
if !ok {
r.roots[method] = &_node{}
}
r.roots[method].insert(pattern, parts, 0)
r.handlers[key] = handlers
}
func (r *router) getRoute(method string, path string) (*_node, map[string]string) {
searchParts := parsePattern(path)
params := make(map[string]string)
root, ok := r.roots[method]
if !ok {
return nil, nil
}
n := root.search(searchParts, 0)
if n != nil {
parts := parsePattern(n.pattern)
for index, part := range parts {
if part[0] == ':' {
params[part[1:]] = searchParts[index]
}
if part[0] == '*' && len(part) > 1 {
params[part[1:]] = strings.Join(searchParts[index:], "/")
break
}
}
return n, params
}
return nil, nil
}
func (r *router) getRoutes(method string) []*_node {
root, ok := r.roots[method]
if !ok {
return nil
}
nodes := make([]*_node, 0)
root.travel(&nodes)
return nodes
}
func (r *router) handle(c *Ctx) error {
if err := c.verify(); err != nil {
return err
}
node, params := r.getRoute(c.Method, c.path)
if node != nil {
c.params = params
key := c.Method + "-" + node.pattern
c.handlers = append(c.handlers, r.handlers[key]...)
} else {
return c.app.config.NotFoundHandler(c)
}
return c.Next()
}

155
routergroup.go Normal file
View File

@ -0,0 +1,155 @@
package nf
import (
"math"
"net/http"
"path"
"regexp"
)
var (
// regEnLetter matches english letters for http method name
regEnLetter = regexp.MustCompile("^[A-Z]+$")
// anyMethods for RouterGroup Any method
anyMethods = []string{
http.MethodGet, http.MethodPost, http.MethodPut, http.MethodPatch,
http.MethodHead, http.MethodOptions, http.MethodDelete, http.MethodConnect,
http.MethodTrace,
}
)
// IRouter defines all router handle interface includes single and group router.
type IRouter interface {
IRoutes
Group(string, ...HandlerFunc) *RouterGroup
}
// IRoutes defines all router handle interface.
type IRoutes interface {
Use(...HandlerFunc) IRoutes
Handle(string, string, ...HandlerFunc) IRoutes
Any(string, ...HandlerFunc) IRoutes
Get(string, ...HandlerFunc) IRoutes
Post(string, ...HandlerFunc) IRoutes
Delete(string, ...HandlerFunc) IRoutes
Patch(string, ...HandlerFunc) IRoutes
Put(string, ...HandlerFunc) IRoutes
Options(string, ...HandlerFunc) IRoutes
Head(string, ...HandlerFunc) IRoutes
Match([]string, string, ...HandlerFunc) IRoutes
//StaticFile(string, string) IRoutes
//StaticFileFS(string, string, http.FileSystem) IRoutes
//Static(string, string) IRoutes
//StaticFS(string, http.FileSystem) IRoutes
}
type RouterGroup struct {
Handlers []HandlerFunc
basePath string
app *App
root bool
}
var _ IRouter = (*RouterGroup)(nil)
func (group *RouterGroup) Use(middleware ...HandlerFunc) IRoutes {
group.Handlers = append(group.Handlers, middleware...)
return group.returnObj()
}
func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) *RouterGroup {
return &RouterGroup{
Handlers: group.combineHandlers(handlers...),
basePath: group.calculateAbsolutePath(relativePath),
app: group.app,
}
}
func (group *RouterGroup) BasePath() string {
return group.basePath
}
func (group *RouterGroup) handle(httpMethod, relativePath string, handlers ...HandlerFunc) IRoutes {
absolutePath := group.calculateAbsolutePath(relativePath)
handlers = group.combineHandlers(handlers...)
group.app.addRoute(httpMethod, absolutePath, handlers...)
return group.returnObj()
}
func (group *RouterGroup) Handle(httpMethod, relativePath string, handlers ...HandlerFunc) IRoutes {
if matched := regEnLetter.MatchString(httpMethod); !matched {
panic("http method " + httpMethod + " is not valid")
}
return group.handle(httpMethod, relativePath, handlers...)
}
func (group *RouterGroup) Post(relativePath string, handlers ...HandlerFunc) IRoutes {
return group.handle(http.MethodPost, relativePath, handlers...)
}
func (group *RouterGroup) Get(relativePath string, handlers ...HandlerFunc) IRoutes {
return group.handle(http.MethodGet, relativePath, handlers...)
}
func (group *RouterGroup) Delete(relativePath string, handlers ...HandlerFunc) IRoutes {
return group.handle(http.MethodDelete, relativePath, handlers...)
}
func (group *RouterGroup) Patch(relativePath string, handlers ...HandlerFunc) IRoutes {
return group.handle(http.MethodPatch, relativePath, handlers...)
}
func (group *RouterGroup) Put(relativePath string, handlers ...HandlerFunc) IRoutes {
return group.handle(http.MethodPut, relativePath, handlers...)
}
func (group *RouterGroup) Options(relativePath string, handlers ...HandlerFunc) IRoutes {
return group.handle(http.MethodOptions, relativePath, handlers...)
}
func (group *RouterGroup) Head(relativePath string, handlers ...HandlerFunc) IRoutes {
return group.handle(http.MethodHead, relativePath, handlers...)
}
// Any registers a route that matches all the HTTP methods.
// GET, POST, PUT, PATCH, HEAD, OPTIONS, DELETE, CONNECT, TRACE.
func (group *RouterGroup) Any(relativePath string, handlers ...HandlerFunc) IRoutes {
for _, method := range anyMethods {
group.handle(method, relativePath, handlers...)
}
return group.returnObj()
}
func (group *RouterGroup) Match(methods []string, relativePath string, handlers ...HandlerFunc) IRoutes {
for _, method := range methods {
group.handle(method, relativePath, handlers...)
}
return group.returnObj()
}
const abortIndex int8 = math.MaxInt8 >> 1
func (group *RouterGroup) combineHandlers(handlers ...HandlerFunc) []HandlerFunc {
finalSize := len(group.Handlers) + len(handlers)
elsePanic(finalSize < int(abortIndex), "too many handlers")
mergedHandlers := make([]HandlerFunc, finalSize)
copy(mergedHandlers, group.Handlers)
copy(mergedHandlers[len(group.Handlers):], handlers)
return mergedHandlers
}
func (group *RouterGroup) calculateAbsolutePath(relativePath string) string {
return path.Join(group.basePath, relativePath)
}
func (group *RouterGroup) returnObj() IRoutes {
if group.root {
return group.app
}
return group
}

903
tree.go
View File

@ -1,76 +1,891 @@
package nf package nf
import ( import (
"bytes"
"net/url"
"strings" "strings"
"unicode"
"unicode/utf8"
"github.com/loveuer/nf/internal/bytesconv"
) )
type _node struct { var (
pattern string strColon = []byte(":")
part string strStar = []byte("*")
children []*_node strSlash = []byte("/")
isWild bool )
// Param is a single URL parameter, consisting of a key and a value.
type Param struct {
Key string
Value string
} }
func (n *_node) insert(pattern string, parts []string, height int) { // Params is a Param-slice, as returned by the router.
if len(parts) == height { // The slice is ordered, the first URL parameter is also the first slice value.
n.pattern = pattern // It is therefore safe to read values by the index.
type Params []Param
// Get returns the value of the first Param which key matches the given name and a boolean true.
// If no matching Param is found, an empty string is returned and a boolean false .
func (ps Params) Get(name string) (string, bool) {
for _, entry := range ps {
if entry.Key == name {
return entry.Value, true
}
}
return "", false
}
// ByName returns the value of the first Param which key matches the given name.
// If no matching Param is found, an empty string is returned.
func (ps Params) ByName(name string) (va string) {
va, _ = ps.Get(name)
return return
} }
part := parts[height] type methodTree struct {
child := n.matchChild(part) method string
if child == nil { root *node
child = &_node{part: part, isWild: part[0] == ':' || part[0] == '*'}
n.children = append(n.children, child)
}
child.insert(pattern, parts, height+1)
} }
func (n *_node) search(parts []string, height int) *_node { type methodTrees []methodTree
if len(parts) == height || strings.HasPrefix(n.part, "*") {
if n.pattern == "" { func (trees methodTrees) get(method string) *node {
for _, tree := range trees {
if tree.method == method {
return tree.root
}
}
return nil return nil
} }
func min(a, b int) int {
if a <= b {
return a
}
return b
}
func longestCommonPrefix(a, b string) int {
i := 0
max := min(len(a), len(b))
for i < max && a[i] == b[i] {
i++
}
return i
}
// addChild will add a child node, keeping wildcardChild at the end
func (n *node) addChild(child *node) {
if n.wildChild && len(n.children) > 0 {
wildcardChild := n.children[len(n.children)-1]
n.children = append(n.children[:len(n.children)-1], child, wildcardChild)
} else {
n.children = append(n.children, child)
}
}
func countParams(path string) uint16 {
var n uint16
s := bytesconv.StringToBytes(path)
n += uint16(bytes.Count(s, strColon))
n += uint16(bytes.Count(s, strStar))
return n return n
} }
part := parts[height] func countSections(path string) uint16 {
children := n.matchChildren(part) s := bytesconv.StringToBytes(path)
return uint16(bytes.Count(s, strSlash))
}
for _, child := range children { type nodeType uint8
result := child.search(parts, height+1)
if result != nil { const (
return result static nodeType = iota
root
param
catchAll
)
type node struct {
path string
indices string
wildChild bool
nType nodeType
priority uint32
children []*node // child nodes, at most 1 :param style node at the end of the array
handlers []HandlerFunc
fullPath string
}
// Increments priority of the given child and reorders if necessary
func (n *node) incrementChildPrio(pos int) int {
cs := n.children
cs[pos].priority++
prio := cs[pos].priority
// Adjust position (move to front)
newPos := pos
for ; newPos > 0 && cs[newPos-1].priority < prio; newPos-- {
// Swap node positions
cs[newPos-1], cs[newPos] = cs[newPos], cs[newPos-1]
}
// Build new index char string
if newPos != pos {
n.indices = n.indices[:newPos] + // Unchanged prefix, might be empty
n.indices[pos:pos+1] + // The index char we move
n.indices[newPos:pos] + n.indices[pos+1:] // Rest without char at 'pos'
}
return newPos
}
// addRoute adds a node with the given handle to the path.
// Not concurrency-safe!
func (n *node) addRoute(path string, handlers ...HandlerFunc) {
fullPath := path
n.priority++
// Empty tree
if len(n.path) == 0 && len(n.children) == 0 {
n.insertChild(path, fullPath, handlers...)
n.nType = root
return
}
parentFullPathIndex := 0
walk:
for {
// Find the longest common prefix.
// This also implies that the common prefix contains no ':' or '*'
// since the existing key can't contain those chars.
i := longestCommonPrefix(path, n.path)
// Split edge
if i < len(n.path) {
child := node{
path: n.path[i:],
wildChild: n.wildChild,
nType: static,
indices: n.indices,
children: n.children,
handlers: n.handlers,
priority: n.priority - 1,
fullPath: n.fullPath,
}
n.children = []*node{&child}
// []byte for proper unicode char conversion, see #65
n.indices = bytesconv.BytesToString([]byte{n.path[i]})
n.path = path[:i]
n.handlers = nil
n.wildChild = false
n.fullPath = fullPath[:parentFullPathIndex+i]
}
// Make new node a child of this node
if i < len(path) {
path = path[i:]
c := path[0]
// '/' after param
if n.nType == param && c == '/' && len(n.children) == 1 {
parentFullPathIndex += len(n.path)
n = n.children[0]
n.priority++
continue walk
}
// Check if a child with the next path byte exists
for i, max := 0, len(n.indices); i < max; i++ {
if c == n.indices[i] {
parentFullPathIndex += len(n.path)
i = n.incrementChildPrio(i)
n = n.children[i]
continue walk
} }
} }
// Otherwise insert it
if c != ':' && c != '*' && n.nType != catchAll {
// []byte for proper unicode char conversion, see #65
n.indices += bytesconv.BytesToString([]byte{c})
child := &node{
fullPath: fullPath,
}
n.addChild(child)
n.incrementChildPrio(len(n.indices) - 1)
n = child
} else if n.wildChild {
// inserting a wildcard node, need to check if it conflicts with the existing wildcard
n = n.children[len(n.children)-1]
n.priority++
// Check if the wildcard matches
if len(path) >= len(n.path) && n.path == path[:len(n.path)] &&
// Adding a child to a catchAll is not possible
n.nType != catchAll &&
// Check for longer wildcard, e.g. :name and :names
(len(n.path) >= len(path) || path[len(n.path)] == '/') {
continue walk
}
// Wildcard conflict
pathSeg := path
if n.nType != catchAll {
pathSeg = strings.SplitN(pathSeg, "/", 2)[0]
}
prefix := fullPath[:strings.Index(fullPath, pathSeg)] + n.path
panic("'" + pathSeg +
"' in new path '" + fullPath +
"' conflicts with existing wildcard '" + n.path +
"' in existing prefix '" + prefix +
"'")
}
n.insertChild(path, fullPath, handlers...)
return
}
// Otherwise add handle to current node
if n.handlers != nil {
panic("handlers are already registered for path '" + fullPath + "'")
}
n.handlers = handlers
n.fullPath = fullPath
return
}
}
// Search for a wildcard segment and check the name for invalid characters.
// Returns -1 as index, if no wildcard was found.
func findWildcard(path string) (wildcard string, i int, valid bool) {
// Find start
for start, c := range []byte(path) {
// A wildcard starts with ':' (param) or '*' (catch-all)
if c != ':' && c != '*' {
continue
}
// Find end and check for invalid characters
valid = true
for end, c := range []byte(path[start+1:]) {
switch c {
case '/':
return path[start : start+1+end], start, valid
case ':', '*':
valid = false
}
}
return path[start:], start, valid
}
return "", -1, false
}
func (n *node) insertChild(path string, fullPath string, handlers ...HandlerFunc) {
for {
// Find prefix until first wildcard
wildcard, i, valid := findWildcard(path)
if i < 0 { // No wildcard found
break
}
// The wildcard name must only contain one ':' or '*' character
if !valid {
panic("only one wildcard per path segment is allowed, has: '" +
wildcard + "' in path '" + fullPath + "'")
}
// check if the wildcard has a name
if len(wildcard) < 2 {
panic("wildcards must be named with a non-empty name in path '" + fullPath + "'")
}
if wildcard[0] == ':' { // param
if i > 0 {
// Insert prefix before the current wildcard
n.path = path[:i]
path = path[i:]
}
child := &node{
nType: param,
path: wildcard,
fullPath: fullPath,
}
n.addChild(child)
n.wildChild = true
n = child
n.priority++
// if the path doesn't end with the wildcard, then there
// will be another subpath starting with '/'
if len(wildcard) < len(path) {
path = path[len(wildcard):]
child := &node{
priority: 1,
fullPath: fullPath,
}
n.addChild(child)
n = child
continue
}
// Otherwise we're done. Insert the handle in the new leaf
n.handlers = handlers
return
}
// catchAll
if i+len(wildcard) != len(path) {
panic("catch-all routes are only allowed at the end of the path in path '" + fullPath + "'")
}
if len(n.path) > 0 && n.path[len(n.path)-1] == '/' {
pathSeg := ""
if len(n.children) != 0 {
pathSeg = strings.SplitN(n.children[0].path, "/", 2)[0]
}
panic("catch-all wildcard '" + path +
"' in new path '" + fullPath +
"' conflicts with existing path segment '" + pathSeg +
"' in existing prefix '" + n.path + pathSeg +
"'")
}
// currently fixed width 1 for '/'
i--
if path[i] != '/' {
panic("no / before catch-all in path '" + fullPath + "'")
}
n.path = path[:i]
// First node: catchAll node with empty path
child := &node{
wildChild: true,
nType: catchAll,
fullPath: fullPath,
}
n.addChild(child)
n.indices = string('/')
n = child
n.priority++
// second node: node holding the variable
child = &node{
path: path[i:],
nType: catchAll,
handlers: handlers,
priority: 1,
fullPath: fullPath,
}
n.children = []*node{child}
return
}
// If no wildcard was found, simply insert the path and handle
n.path = path
n.handlers = handlers
n.fullPath = fullPath
}
// nodeValue holds return values of (*Node).getValue method
type nodeValue struct {
handlers []HandlerFunc
params *Params
tsr bool
fullPath string
}
type skippedNode struct {
path string
node *node
paramsCount int16
}
// Returns the handle registered with the given path (key). The values of
// wildcards are saved to a map.
// If no handle can be found, a TSR (trailing slash redirect) recommendation is
// made if a handle exists with an extra (without the) trailing slash for the
// given path.
func (n *node) getValue(path string, params *Params, skippedNodes *[]skippedNode, unescape bool) (value nodeValue) {
var globalParamsCount int16
walk: // Outer loop for walking the tree
for {
prefix := n.path
if len(path) > len(prefix) {
if path[:len(prefix)] == prefix {
path = path[len(prefix):]
// Try all the non-wildcard children first by matching the indices
idxc := path[0]
for i, c := range []byte(n.indices) {
if c == idxc {
// strings.HasPrefix(n.children[len(n.children)-1].path, ":") == n.wildChild
if n.wildChild {
index := len(*skippedNodes)
*skippedNodes = (*skippedNodes)[:index+1]
(*skippedNodes)[index] = skippedNode{
path: prefix + path,
node: &node{
path: n.path,
wildChild: n.wildChild,
nType: n.nType,
priority: n.priority,
children: n.children,
handlers: n.handlers,
fullPath: n.fullPath,
},
paramsCount: globalParamsCount,
}
}
n = n.children[i]
continue walk
}
}
if !n.wildChild {
// If the path at the end of the loop is not equal to '/' and the current node has no child nodes
// the current node needs to roll back to last valid skippedNode
if path != "/" {
for length := len(*skippedNodes); length > 0; length-- {
skippedNode := (*skippedNodes)[length-1]
*skippedNodes = (*skippedNodes)[:length-1]
if strings.HasSuffix(skippedNode.path, path) {
path = skippedNode.path
n = skippedNode.node
if value.params != nil {
*value.params = (*value.params)[:skippedNode.paramsCount]
}
globalParamsCount = skippedNode.paramsCount
continue walk
}
}
}
// Nothing found.
// We can recommend to redirect to the same URL without a
// trailing slash if a leaf exists for that path.
value.tsr = path == "/" && n.handlers != nil
return value
}
// Handle wildcard child, which is always at the end of the array
n = n.children[len(n.children)-1]
globalParamsCount++
switch n.nType {
case param:
// fix truncate the parameter
// tree_test.go line: 204
// Find param end (either '/' or path end)
end := 0
for end < len(path) && path[end] != '/' {
end++
}
// Save param value
if params != nil {
// Preallocate capacity if necessary
if cap(*params) < int(globalParamsCount) {
newParams := make(Params, len(*params), globalParamsCount)
copy(newParams, *params)
*params = newParams
}
if value.params == nil {
value.params = params
}
// Expand slice within preallocated capacity
i := len(*value.params)
*value.params = (*value.params)[:i+1]
val := path[:end]
if unescape {
if v, err := url.QueryUnescape(val); err == nil {
val = v
}
}
(*value.params)[i] = Param{
Key: n.path[1:],
Value: val,
}
}
// we need to go deeper!
if end < len(path) {
if len(n.children) > 0 {
path = path[end:]
n = n.children[0]
continue walk
}
// ... but we can't
value.tsr = len(path) == end+1
return value
}
if value.handlers = n.handlers; value.handlers != nil {
value.fullPath = n.fullPath
return value
}
if len(n.children) == 1 {
// No handle found. Check if a handle for this path + a
// trailing slash exists for TSR recommendation
n = n.children[0]
value.tsr = (n.path == "/" && n.handlers != nil) || (n.path == "" && n.indices == "/")
}
return value
case catchAll:
// Save param value
if params != nil {
// Preallocate capacity if necessary
if cap(*params) < int(globalParamsCount) {
newParams := make(Params, len(*params), globalParamsCount)
copy(newParams, *params)
*params = newParams
}
if value.params == nil {
value.params = params
}
// Expand slice within preallocated capacity
i := len(*value.params)
*value.params = (*value.params)[:i+1]
val := path
if unescape {
if v, err := url.QueryUnescape(path); err == nil {
val = v
}
}
(*value.params)[i] = Param{
Key: n.path[2:],
Value: val,
}
}
value.handlers = n.handlers
value.fullPath = n.fullPath
return value
default:
panic("invalid node type")
}
}
}
if path == prefix {
// If the current path does not equal '/' and the node does not have a registered handle and the most recently matched node has a child node
// the current node needs to roll back to last valid skippedNode
if n.handlers == nil && path != "/" {
for length := len(*skippedNodes); length > 0; length-- {
skippedNode := (*skippedNodes)[length-1]
*skippedNodes = (*skippedNodes)[:length-1]
if strings.HasSuffix(skippedNode.path, path) {
path = skippedNode.path
n = skippedNode.node
if value.params != nil {
*value.params = (*value.params)[:skippedNode.paramsCount]
}
globalParamsCount = skippedNode.paramsCount
continue walk
}
}
// n = latestNode.children[len(latestNode.children)-1]
}
// We should have reached the node containing the handle.
// Check if this node has a handle registered.
if value.handlers = n.handlers; value.handlers != nil {
value.fullPath = n.fullPath
return value
}
// If there is no handle for this route, but this route has a
// wildcard child, there must be a handle for this path with an
// additional trailing slash
if path == "/" && n.wildChild && n.nType != root {
value.tsr = true
return value
}
if path == "/" && n.nType == static {
value.tsr = true
return value
}
// No handle found. Check if a handle for this path + a
// trailing slash exists for trailing slash recommendation
for i, c := range []byte(n.indices) {
if c == '/' {
n = n.children[i]
value.tsr = (len(n.path) == 1 && n.handlers != nil) ||
(n.nType == catchAll && n.children[0].handlers != nil)
return value
}
}
return value
}
// Nothing found. We can recommend to redirect to the same URL with an
// extra trailing slash if a leaf exists for that path
value.tsr = path == "/" ||
(len(prefix) == len(path)+1 && prefix[len(path)] == '/' &&
path == prefix[:len(prefix)-1] && n.handlers != nil)
// roll back to last valid skippedNode
if !value.tsr && path != "/" {
for length := len(*skippedNodes); length > 0; length-- {
skippedNode := (*skippedNodes)[length-1]
*skippedNodes = (*skippedNodes)[:length-1]
if strings.HasSuffix(skippedNode.path, path) {
path = skippedNode.path
n = skippedNode.node
if value.params != nil {
*value.params = (*value.params)[:skippedNode.paramsCount]
}
globalParamsCount = skippedNode.paramsCount
continue walk
}
}
}
return value
}
}
// Makes a case-insensitive lookup of the given path and tries to find a handler.
// It can optionally also fix trailing slashes.
// It returns the case-corrected path and a bool indicating whether the lookup
// was successful.
func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) ([]byte, bool) {
const stackBufSize = 128
// Use a static sized buffer on the stack in the common case.
// If the path is too long, allocate a buffer on the heap instead.
buf := make([]byte, 0, stackBufSize)
if length := len(path) + 1; length > stackBufSize {
buf = make([]byte, 0, length)
}
ciPath := n.findCaseInsensitivePathRec(
path,
buf, // Preallocate enough memory for new path
[4]byte{}, // Empty rune buffer
fixTrailingSlash,
)
return ciPath, ciPath != nil
}
// Shift bytes in array by n bytes left
func shiftNRuneBytes(rb [4]byte, n int) [4]byte {
switch n {
case 0:
return rb
case 1:
return [4]byte{rb[1], rb[2], rb[3], 0}
case 2:
return [4]byte{rb[2], rb[3]}
case 3:
return [4]byte{rb[3]}
default:
return [4]byte{}
}
}
// Recursive case-insensitive lookup function used by n.findCaseInsensitivePath
func (n *node) findCaseInsensitivePathRec(path string, ciPath []byte, rb [4]byte, fixTrailingSlash bool) []byte {
npLen := len(n.path)
walk: // Outer loop for walking the tree
for len(path) >= npLen && (npLen == 0 || strings.EqualFold(path[1:npLen], n.path[1:])) {
// Add common prefix to result
oldPath := path
path = path[npLen:]
ciPath = append(ciPath, n.path...)
if len(path) == 0 {
// We should have reached the node containing the handle.
// Check if this node has a handle registered.
if n.handlers != nil {
return ciPath
}
// No handle found.
// Try to fix the path by adding a trailing slash
if fixTrailingSlash {
for i, c := range []byte(n.indices) {
if c == '/' {
n = n.children[i]
if (len(n.path) == 1 && n.handlers != nil) ||
(n.nType == catchAll && n.children[0].handlers != nil) {
return append(ciPath, '/')
}
return nil return nil
} }
func (n *_node) travel(list *([]*_node)) {
if n.pattern != "" {
*list = append(*list, n)
}
for _, child := range n.children {
child.travel(list)
}
}
func (n *_node) matchChild(part string) *_node {
for _, child := range n.children {
if child.part == part || child.isWild {
return child
} }
} }
return nil return nil
} }
func (n *_node) matchChildren(part string) []*_node { // If this node does not have a wildcard (param or catchAll) child,
nodes := make([]*_node, 0) // we can just look up the next child node and continue to walk down
for _, child := range n.children { // the tree
if child.part == part || child.isWild { if !n.wildChild {
nodes = append(nodes, child) // Skip rune bytes already processed
rb = shiftNRuneBytes(rb, npLen)
if rb[0] != 0 {
// Old rune not finished
idxc := rb[0]
for i, c := range []byte(n.indices) {
if c == idxc {
// continue with child node
n = n.children[i]
npLen = len(n.path)
continue walk
} }
} }
return nodes } else {
// Process a new rune
var rv rune
// Find rune start.
// Runes are up to 4 byte long,
// -4 would definitely be another rune.
var off int
for max := min(npLen, 3); off < max; off++ {
if i := npLen - off; utf8.RuneStart(oldPath[i]) {
// read rune from cached path
rv, _ = utf8.DecodeRuneInString(oldPath[i:])
break
}
}
// Calculate lowercase bytes of current rune
lo := unicode.ToLower(rv)
utf8.EncodeRune(rb[:], lo)
// Skip already processed bytes
rb = shiftNRuneBytes(rb, off)
idxc := rb[0]
for i, c := range []byte(n.indices) {
// Lowercase matches
if c == idxc {
// must use a recursive approach since both the
// uppercase byte and the lowercase byte might exist
// as an index
if out := n.children[i].findCaseInsensitivePathRec(
path, ciPath, rb, fixTrailingSlash,
); out != nil {
return out
}
break
}
}
// If we found no match, the same for the uppercase rune,
// if it differs
if up := unicode.ToUpper(rv); up != lo {
utf8.EncodeRune(rb[:], up)
rb = shiftNRuneBytes(rb, off)
idxc := rb[0]
for i, c := range []byte(n.indices) {
// Uppercase matches
if c == idxc {
// Continue with child node
n = n.children[i]
npLen = len(n.path)
continue walk
}
}
}
}
// Nothing found. We can recommend to redirect to the same URL
// without a trailing slash if a leaf exists for that path
if fixTrailingSlash && path == "/" && n.handlers != nil {
return ciPath
}
return nil
}
n = n.children[0]
switch n.nType {
case param:
// Find param end (either '/' or path end)
end := 0
for end < len(path) && path[end] != '/' {
end++
}
// Add param value to case insensitive path
ciPath = append(ciPath, path[:end]...)
// We need to go deeper!
if end < len(path) {
if len(n.children) > 0 {
// Continue with child node
n = n.children[0]
npLen = len(n.path)
path = path[end:]
continue
}
// ... but we can't
if fixTrailingSlash && len(path) == end+1 {
return ciPath
}
return nil
}
if n.handlers != nil {
return ciPath
}
if fixTrailingSlash && len(n.children) == 1 {
// No handle found. Check if a handle for this path + a
// trailing slash exists
n = n.children[0]
if n.path == "/" && n.handlers != nil {
return append(ciPath, '/')
}
}
return nil
case catchAll:
return append(ciPath, path...)
default:
panic("invalid node type")
}
}
// Nothing found.
// Try to fix the path by adding / removing a trailing slash
if fixTrailingSlash {
if path == "/" {
return ciPath
}
if len(path)+1 == npLen && n.path[len(path)] == '/' &&
strings.EqualFold(path[1:], n.path[1:len(path)]) && n.handlers != nil {
return append(ciPath, n.path...)
}
}
return nil
} }

158
util.go
View File

@ -65,3 +65,161 @@ func parseToStruct(aliasTag string, out interface{}, data map[string][]string) e
return nil return nil
} }
func elsePanic(guard bool, text string) {
if !guard {
panic(text)
}
}
func cleanPath(p string) string {
const stackBufSize = 128
// Turn empty string into "/"
if p == "" {
return "/"
}
// Reasonably sized buffer on stack to avoid allocations in the common case.
// If a larger buffer is required, it gets allocated dynamically.
buf := make([]byte, 0, stackBufSize)
n := len(p)
// Invariants:
// reading from path; r is index of next byte to process.
// writing to buf; w is index of next byte to write.
// path must start with '/'
r := 1
w := 1
if p[0] != '/' {
r = 0
if n+1 > stackBufSize {
buf = make([]byte, n+1)
} else {
buf = buf[:n+1]
}
buf[0] = '/'
}
trailing := n > 1 && p[n-1] == '/'
// A bit more clunky without a 'lazybuf' like the path package, but the loop
// gets completely inlined (bufApp calls).
// loop has no expensive function calls (except 1x make) // So in contrast to the path package this loop has no expensive function
// calls (except make, if needed).
for r < n {
switch {
case p[r] == '/':
// empty path element, trailing slash is added after the end
r++
case p[r] == '.' && r+1 == n:
trailing = true
r++
case p[r] == '.' && p[r+1] == '/':
// . element
r += 2
case p[r] == '.' && p[r+1] == '.' && (r+2 == n || p[r+2] == '/'):
// .. element: remove to last /
r += 3
if w > 1 {
// can backtrack
w--
if len(buf) == 0 {
for w > 1 && p[w] != '/' {
w--
}
} else {
for w > 1 && buf[w] != '/' {
w--
}
}
}
default:
// Real path element.
// Add slash if needed
if w > 1 {
bufApp(&buf, p, w, '/')
w++
}
// Copy element
for r < n && p[r] != '/' {
bufApp(&buf, p, w, p[r])
w++
r++
}
}
}
// Re-append trailing slash
if trailing && w > 1 {
bufApp(&buf, p, w, '/')
w++
}
// If the original string was not modified (or only shortened at the end),
// return the respective substring of the original string.
// Otherwise return a new string from the buffer.
if len(buf) == 0 {
return p[:w]
}
return string(buf[:w])
}
// Internal helper to lazily create a buffer if necessary.
// Calls to this function get inlined.
func bufApp(buf *[]byte, s string, w int, c byte) {
b := *buf
if len(b) == 0 {
// No modification of the original string so far.
// If the next character is the same as in the original string, we do
// not yet have to allocate a buffer.
if s[w] == c {
return
}
// Otherwise use either the stack buffer, if it is large enough, or
// allocate a new buffer on the heap, and copy all previous characters.
length := len(s)
if length > cap(b) {
*buf = make([]byte, length)
} else {
*buf = (*buf)[:length]
}
b = *buf
copy(b, s[:w])
}
b[w] = c
}
func HumanDuration(nano int64) string {
duration := float64(nano)
unit := "ns"
if duration >= 1000 {
duration /= 1000
unit = "us"
}
if duration >= 1000 {
duration /= 1000
unit = "ms"
}
if duration >= 1000 {
duration /= 1000
unit = " s"
}
return fmt.Sprintf("%6.2f%s", duration, unit)
}

View File

@ -1,6 +0,0 @@
### basic - get
GET http://127.0.0.1/hello/nf
### test resp error
GET http://127.0.0.1/error

View File

@ -1,31 +0,0 @@
package main
import (
"errors"
"github.com/loveuer/nf"
"github.com/loveuer/nf/nft/resp"
"log"
"net"
"time"
)
func main() {
app := nf.New(nf.Config{EnableNotImplementHandler: true})
app.Get("/hello/:name", func(c *nf.Ctx) error {
name := c.Param("name")
return c.JSON(nf.Map{"status": 200, "data": "hello, " + name})
})
app.Get("/not_impl")
app.Patch("/world", func(c *nf.Ctx) error {
time.Sleep(5 * time.Second)
c.Status(404)
return c.JSON(nf.Map{"method": c.Method, "status": c.StatusCode})
})
app.Get("/error", func(c *nf.Ctx) error {
return resp.RespError(c, resp.NewError(404, "not found", errors.New("NNNot Found"), nil))
})
ln, _ := net.Listen("tcp", ":80")
log.Fatal(app.RunListener(ln))
}

View File

@ -1,9 +0,0 @@
### body_limit
POST http://127.0.0.1/data
Content-Type: application/json
{
"name": "zyp",
"age": 19,
"likes": ["2233"]
}

View File

@ -1,50 +0,0 @@
package main
import (
"github.com/loveuer/nf"
"log"
)
func main() {
app := nf.New(nf.Config{BodyLimit: -1})
app.Post("/data", func(c *nf.Ctx) error {
type Req struct {
Name string `json:"name"`
Age int `json:"age"`
Likes []string `json:"likes"`
}
var (
err error
req = new(Req)
)
if err = c.BodyParser(req); err != nil {
return c.JSON(nf.Map{"status": 400, "err": err.Error()})
}
return c.JSON(nf.Map{"status": 200, "data": req})
})
app.Post("/url", func(c *nf.Ctx) error {
type Req struct {
Name string `form:"name"`
Age int `form:"age"`
Likes []string `form:"likes"`
}
var (
err error
req = new(Req)
)
if err = c.BodyParser(req); err != nil {
return c.JSON(nf.Map{"status": 400, "err": err.Error()})
}
return c.JSON(nf.Map{"status": 200, "data": req})
})
log.Fatal(app.Run("0.0.0.0:80"))
}

View File

@ -1,24 +0,0 @@
package main
import (
"github.com/loveuer/nf"
"log"
)
func main() {
app := nf.New(nf.Config{
DisableRecover: true,
})
app.Get("/hello/:name", func(c *nf.Ctx) error {
name := c.Param("name")
if name == "nf" {
panic("name is nf")
}
return c.JSON("nice")
})
log.Fatal(app.Run("0.0.0.0:80"))
}

View File

@ -1,5 +0,0 @@
### panic test
GET http://127.0.0.1/hello/nf
### if covered?
GET http://127.0.0.1/hello/world

View File

@ -1,36 +0,0 @@
package main
import (
"github.com/loveuer/nf"
"log"
)
func main() {
app := nf.New()
app.Get("/hello", func(c *nf.Ctx) error {
type Req struct {
Name string `query:"name"`
Age int `query:"age"`
Likes []string `query:"likes"`
}
var (
err error
req = new(Req)
rm = make(map[string]interface{})
)
//if err = c.QueryParser(req); err != nil {
// return nf.NewNFError(400, "1:"+err.Error())
//}
if err = c.QueryParser(&rm); err != nil {
return nf.NewNFError(400, "2:"+err.Error())
}
return c.JSON(nf.Map{"status": 200, "data": req, "map": rm})
})
log.Fatal(app.Run("0.0.0.0:80"))
}

View File

@ -1,52 +0,0 @@
package main
import (
"context"
"github.com/loveuer/nf"
"log"
"time"
)
var (
app = nf.New()
quit = make(chan bool)
)
func main() {
app.Get("/name", handleGet)
go func() {
err := app.Run(":80")
log.Print("run with err=", err)
quit <- true
}()
<-quit
}
func handleGet(c *nf.Ctx) error {
type Req struct {
Name string `query:"name"`
Addr []string `query:"addr"`
}
var (
err error
req = Req{}
)
if err = c.QueryParser(&req); err != nil {
return nf.NewNFError(400, err.Error())
}
if req.Name == "quit" {
go func() {
time.Sleep(2 * time.Second)
log.Print("app quit = ", app.Shutdown(context.TODO()))
}()
}
return c.JSON(nf.Map{"req_map": req})
}

View File

@ -1,119 +0,0 @@
package main
import (
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"github.com/loveuer/nf"
"log"
"math/big"
"net"
"time"
)
func main() {
app := nf.New(nf.Config{
DisableHttpErrorLog: true,
})
app.Get("/hello/:name", func(c *nf.Ctx) error {
return c.SendString("hello, " + c.Param("name"))
})
st, _, _ := GenerateTlsConfig()
log.Fatal(app.RunTLS(":443", st))
}
func GenerateTlsConfig() (serverTLSConf *tls.Config, clientTLSConf *tls.Config, err error) {
ca := &x509.Certificate{
SerialNumber: big.NewInt(2019),
Subject: pkix.Name{
Organization: []string{"Company, INC."},
Country: []string{"US"},
Province: []string{""},
Locality: []string{"San Francisco"},
StreetAddress: []string{"Golden Gate Bridge"},
PostalCode: []string{"94016"},
},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(99, 0, 0),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
// create our private and public key
caPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
return nil, nil, err
}
// create the CA
caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPrivKey.PublicKey, caPrivKey)
if err != nil {
return nil, nil, err
}
// pem encode
caPEM := new(bytes.Buffer)
pem.Encode(caPEM, &pem.Block{
Type: "CERTIFICATE",
Bytes: caBytes,
})
caPrivKeyPEM := new(bytes.Buffer)
pem.Encode(caPrivKeyPEM, &pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(caPrivKey),
})
// set up our server certificate
cert := &x509.Certificate{
SerialNumber: big.NewInt(2019),
Subject: pkix.Name{
Organization: []string{"Company, INC."},
Country: []string{"US"},
Province: []string{""},
Locality: []string{"San Francisco"},
StreetAddress: []string{"Golden Gate Bridge"},
PostalCode: []string{"94016"},
},
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(1, 0, 0),
SubjectKeyId: []byte{1, 2, 3, 4, 6},
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature,
}
certPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
return nil, nil, err
}
certBytes, err := x509.CreateCertificate(rand.Reader, cert, ca, &certPrivKey.PublicKey, caPrivKey)
if err != nil {
return nil, nil, err
}
certPEM := new(bytes.Buffer)
pem.Encode(certPEM, &pem.Block{
Type: "CERTIFICATE",
Bytes: certBytes,
})
certPrivKeyPEM := new(bytes.Buffer)
pem.Encode(certPrivKeyPEM, &pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey),
})
serverCert, err := tls.X509KeyPair(certPEM.Bytes(), certPrivKeyPEM.Bytes())
if err != nil {
return nil, nil, err
}
serverTLSConf = &tls.Config{
Certificates: []tls.Certificate{serverCert},
}
certpool := x509.NewCertPool()
certpool.AppendCertsFromPEM(caPEM.Bytes())
clientTLSConf = &tls.Config{
RootCAs: certpool,
}
return
}