From d1e1a32eedd2b02e9b59b96ffabc8e77caf44c26 Mon Sep 17 00:00:00 2001 From: loveuer Date: Fri, 12 Jan 2024 19:18:33 +0800 Subject: [PATCH] alpha: v0.0.1 --- .gitignore | 3 + app.go | 44 +++ ctx.go | 156 ++++++++++ error.go | 16 + go.mod | 3 + group.go | 61 ++++ handler.go | 3 + internal/schema/LICENSE | 27 ++ internal/schema/cache.go | 305 ++++++++++++++++++ internal/schema/converter.go | 145 +++++++++ internal/schema/decoder.go | 534 ++++++++++++++++++++++++++++++++ internal/schema/doc.go | 148 +++++++++ internal/schema/encoder.go | 202 ++++++++++++ middleware.go | 23 ++ nf.go | 45 +++ resp.go | 49 +++ router.go | 99 ++++++ tree.go | 76 +++++ util.go | 81 +++++ xtest/basic/basic.http | 2 + xtest/basic/main.go | 17 + xtest/bodyLimit/body_limit.http | 9 + xtest/bodyLimit/main.go | 50 +++ xtest/panic/main.go | 24 ++ xtest/panic/panic.http | 5 + xtest/queryParser/main.go | 31 ++ 26 files changed, 2158 insertions(+) create mode 100644 .gitignore create mode 100644 app.go create mode 100644 ctx.go create mode 100644 error.go create mode 100644 go.mod create mode 100644 group.go create mode 100644 handler.go create mode 100644 internal/schema/LICENSE create mode 100644 internal/schema/cache.go create mode 100644 internal/schema/converter.go create mode 100644 internal/schema/decoder.go create mode 100644 internal/schema/doc.go create mode 100644 internal/schema/encoder.go create mode 100644 middleware.go create mode 100644 nf.go create mode 100644 resp.go create mode 100644 router.go create mode 100644 tree.go create mode 100644 util.go create mode 100644 xtest/basic/basic.http create mode 100644 xtest/basic/main.go create mode 100644 xtest/bodyLimit/body_limit.http create mode 100644 xtest/bodyLimit/main.go create mode 100644 xtest/panic/main.go create mode 100644 xtest/panic/panic.http create mode 100644 xtest/queryParser/main.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..d4d9525 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +.idea +.vscode +.DS_Store diff --git a/app.go b/app.go new file mode 100644 index 0000000..169cadd --- /dev/null +++ b/app.go @@ -0,0 +1,44 @@ +package nf + +import ( + "errors" + "fmt" + "net/http" + "strings" +) + +type App struct { + *RouterGroup + config *Config + router *router + groups []*RouterGroup +} + +func (a *App) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + c := newContext(a, writer, request) + + for _, group := range a.groups { + if strings.HasPrefix(request.URL.Path, group.prefix) { + c.handlers = append(c.handlers, group.middlewares...) + } + } + + if err := a.router.handle(c); err != nil { + var ne = &Err{} + + if errors.As(err, ne) { + writer.WriteHeader(ne.Status) + } else { + writer.WriteHeader(500) + } + + _, _ = writer.Write([]byte(err.Error())) + } +} + +func (a *App) Run(address string) error { + if !a.config.DisableBanner { + fmt.Println(banner + "nf serve at: " + address + "\n") + } + return http.ListenAndServe(address, a) +} diff --git a/ctx.go b/ctx.go new file mode 100644 index 0000000..d4284ec --- /dev/null +++ b/ctx.go @@ -0,0 +1,156 @@ +package nf + +import ( + "bytes" + "encoding/json" + "io" + "log" + "net/http" + "strings" +) + +type Ctx struct { + // origin objects + Writer http.ResponseWriter + Request *http.Request + // request info + path string + Method string + // response info + StatusCode int + + app *App + params map[string]string + index int + handlers []HandlerFunc + locals map[string]any +} + +func newContext(app *App, writer http.ResponseWriter, request *http.Request) *Ctx { + return &Ctx{ + Writer: writer, + Request: request, + path: request.URL.Path, + Method: request.Method, + + app: app, + index: -1, + locals: map[string]any{}, + handlers: make([]HandlerFunc, 0), + } +} + +func (c *Ctx) Locals(key string, value ...any) any { + data := c.locals[key] + if len(value) > 0 { + c.locals[key] = value[0] + } + + return data +} + +func (c *Ctx) Path(overWrite ...string) string { + path := c.Request.URL.Path + if len(overWrite) > 0 && overWrite[0] != "" { + c.Request.URL.Path = overWrite[0] + } + + return path +} + +func (c *Ctx) Next() error { + c.index++ + s := len(c.handlers) + for ; c.index < s; c.index++ { + if err := c.handlers[c.index](c); err != nil { + return err + } + } + + return nil +} + +/* =============================================================== +|| Handle Ctx Request Part +=============================================================== */ + +func (c *Ctx) verify() error { + // 验证 body size + if c.app.config.BodyLimit != -1 && c.Request.ContentLength > c.app.config.BodyLimit { + return NewNFError(413, "Content Too Large") + } + + return nil +} + +func (c *Ctx) Param(key string) string { + return c.params[key] +} + +func (c *Ctx) Form(key string) string { + return c.Request.FormValue(key) +} + +func (c *Ctx) Query(key string) string { + return c.Request.URL.Query().Get(key) +} + +func (c *Ctx) Get(key string, defaultValue ...string) string { + value := c.Request.Header.Get(key) + if value == "" && len(defaultValue) > 0 { + return defaultValue[0] + } + + return value +} + +func (c *Ctx) BodyParser(out interface{}) error { + var ( + err error + ctype = strings.ToLower(c.Request.Header.Get("Content-Type")) + ) + + log.Printf("BodyParser: Content-Type=%s", ctype) + + ctype = parseVendorSpecificContentType(ctype) + + ctypeEnd := strings.IndexByte(ctype, ';') + if ctypeEnd != -1 { + ctype = ctype[:ctypeEnd] + } + + if strings.HasSuffix(ctype, "json") { + bs, err := io.ReadAll(c.Request.Body) + if err != nil { + log.Printf("BodyParser: read all err=%v", err) + return err + } + + c.Request.Body = io.NopCloser(bytes.NewReader(bs)) + + return json.Unmarshal(bs, out) + } + + if strings.HasPrefix(ctype, MIMEApplicationForm) { + + if err = c.Request.ParseForm(); err != nil { + return NewNFError(400, err.Error()) + } + + return parseToStruct("form", out, c.Request.Form) + } + + if strings.HasPrefix(ctype, MIMEMultipartForm) { + if err = c.Request.ParseMultipartForm(c.app.config.BodyLimit); err != nil { + return NewNFError(400, err.Error()) + } + + return parseToStruct("form", out, c.Request.PostForm) + } + + return NewNFError(422, "Unprocessable Content") +} + +func (c *Ctx) QueryParser(out interface{}) error { + return parseToStruct("query", out, c.Request.URL.Query()) +} diff --git a/error.go b/error.go new file mode 100644 index 0000000..2d408bf --- /dev/null +++ b/error.go @@ -0,0 +1,16 @@ +package nf + +import "strconv" + +type Err struct { + Status int + Msg string +} + +func (n Err) Error() string { + return strconv.Itoa(n.Status) + " " + n.Msg +} + +func NewNFError(status int, msg string) Err { + return Err{Status: status, Msg: msg} +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..8b5aac3 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module github.com/loveuer/nf + +go 1.20 diff --git a/group.go b/group.go new file mode 100644 index 0000000..2facd20 --- /dev/null +++ b/group.go @@ -0,0 +1,61 @@ +package nf + +import ( + "log" + "net/http" +) + +type RouterGroup struct { + prefix string + middlewares []HandlerFunc // support middleware + parent *RouterGroup // support nesting + app *App // all groups share a Engine instance +} + +// Group is defined to create a new RouterGroup +// remember all groups share the same Engine instance +func (group *RouterGroup) Group(prefix string) *RouterGroup { + app := group.app + newGroup := &RouterGroup{ + prefix: group.prefix + prefix, + parent: group, + app: app, + } + app.groups = append(app.groups, newGroup) + return newGroup +} + +func (group *RouterGroup) addRoute(method string, comp string, handlers ...HandlerFunc) { + verifyHandlers(comp, handlers...) + pattern := group.prefix + comp + log.Printf("Add Route %4s - %s", method, pattern) + group.app.router.addRoute(method, pattern, handlers...) +} + +func (group *RouterGroup) Get(pattern string, handlers ...HandlerFunc) { + group.addRoute(http.MethodGet, pattern, handlers...) +} + +func (group *RouterGroup) Post(pattern string, handlers ...HandlerFunc) { + group.addRoute(http.MethodPost, pattern, handlers...) +} + +func (group *RouterGroup) Put(pattern string, handlers ...HandlerFunc) { + group.addRoute(http.MethodPut, pattern, handlers...) +} + +func (group *RouterGroup) Delete(pattern string, handlers ...HandlerFunc) { + group.addRoute(http.MethodDelete, pattern, handlers...) +} + +func (group *RouterGroup) Patch(pattern string, handlers ...HandlerFunc) { + group.addRoute(http.MethodPatch, pattern, handlers...) +} + +func (group *RouterGroup) Head(pattern string, handlers ...HandlerFunc) { + group.addRoute(http.MethodHead, pattern, handlers...) +} + +func (group *RouterGroup) Use(middlewares ...HandlerFunc) { + group.middlewares = append(group.middlewares, middlewares...) +} diff --git a/handler.go b/handler.go new file mode 100644 index 0000000..64f3a2a --- /dev/null +++ b/handler.go @@ -0,0 +1,3 @@ +package nf + +type HandlerFunc func(*Ctx) error diff --git a/internal/schema/LICENSE b/internal/schema/LICENSE new file mode 100644 index 0000000..0e5fb87 --- /dev/null +++ b/internal/schema/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2012 Rodrigo Moraes. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/internal/schema/cache.go b/internal/schema/cache.go new file mode 100644 index 0000000..bf21697 --- /dev/null +++ b/internal/schema/cache.go @@ -0,0 +1,305 @@ +// Copyright 2012 The Gorilla Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package schema + +import ( + "errors" + "reflect" + "strconv" + "strings" + "sync" +) + +var errInvalidPath = errors.New("schema: invalid path") + +// newCache returns a new cache. +func newCache() *cache { + c := cache{ + m: make(map[reflect.Type]*structInfo), + regconv: make(map[reflect.Type]Converter), + tag: "schema", + } + return &c +} + +// cache caches meta-data about a struct. +type cache struct { + l sync.RWMutex + m map[reflect.Type]*structInfo + regconv map[reflect.Type]Converter + tag string +} + +// registerConverter registers a converter function for a custom type. +func (c *cache) registerConverter(value interface{}, converterFunc Converter) { + c.regconv[reflect.TypeOf(value)] = converterFunc +} + +// parsePath parses a path in dotted notation verifying that it is a valid +// path to a struct field. +// +// It returns "path parts" which contain indices to fields to be used by +// reflect.Value.FieldByString(). Multiple parts are required for slices of +// structs. +func (c *cache) parsePath(p string, t reflect.Type) ([]pathPart, error) { + var struc *structInfo + var field *fieldInfo + var index64 int64 + var err error + parts := make([]pathPart, 0) + path := make([]string, 0) + keys := strings.Split(p, ".") + for i := 0; i < len(keys); i++ { + if t.Kind() != reflect.Struct { + return nil, errInvalidPath + } + if struc = c.get(t); struc == nil { + return nil, errInvalidPath + } + if field = struc.get(keys[i]); field == nil { + return nil, errInvalidPath + } + // Valid field. Append index. + path = append(path, field.name) + if field.isSliceOfStructs && (!field.unmarshalerInfo.IsValid || (field.unmarshalerInfo.IsValid && field.unmarshalerInfo.IsSliceElement)) { + // Parse a special case: slices of structs. + // i+1 must be the slice index. + // + // Now that struct can implements TextUnmarshaler interface, + // we don't need to force the struct's fields to appear in the path. + // So checking i+2 is not necessary anymore. + i++ + if i+1 > len(keys) { + return nil, errInvalidPath + } + if index64, err = strconv.ParseInt(keys[i], 10, 0); err != nil { + return nil, errInvalidPath + } + parts = append(parts, pathPart{ + path: path, + field: field, + index: int(index64), + }) + path = make([]string, 0) + + // Get the next struct type, dropping ptrs. + if field.typ.Kind() == reflect.Ptr { + t = field.typ.Elem() + } else { + t = field.typ + } + if t.Kind() == reflect.Slice { + t = t.Elem() + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + } + } else if field.typ.Kind() == reflect.Ptr { + t = field.typ.Elem() + } else { + t = field.typ + } + } + // Add the remaining. + parts = append(parts, pathPart{ + path: path, + field: field, + index: -1, + }) + return parts, nil +} + +// get returns a cached structInfo, creating it if necessary. +func (c *cache) get(t reflect.Type) *structInfo { + c.l.RLock() + info := c.m[t] + c.l.RUnlock() + if info == nil { + info = c.create(t, "") + c.l.Lock() + c.m[t] = info + c.l.Unlock() + } + return info +} + +// create creates a structInfo with meta-data about a struct. +func (c *cache) create(t reflect.Type, parentAlias string) *structInfo { + info := &structInfo{} + var anonymousInfos []*structInfo + for i := 0; i < t.NumField(); i++ { + if f := c.createField(t.Field(i), parentAlias); f != nil { + info.fields = append(info.fields, f) + if ft := indirectType(f.typ); ft.Kind() == reflect.Struct && f.isAnonymous { + anonymousInfos = append(anonymousInfos, c.create(ft, f.canonicalAlias)) + } + } + } + for i, a := range anonymousInfos { + others := []*structInfo{info} + others = append(others, anonymousInfos[:i]...) + others = append(others, anonymousInfos[i+1:]...) + for _, f := range a.fields { + if !containsAlias(others, f.alias) { + info.fields = append(info.fields, f) + } + } + } + return info +} + +// createField creates a fieldInfo for the given field. +func (c *cache) createField(field reflect.StructField, parentAlias string) *fieldInfo { + alias, options := fieldAlias(field, c.tag) + if alias == "-" { + // Ignore this field. + return nil + } + canonicalAlias := alias + if parentAlias != "" { + canonicalAlias = parentAlias + "." + alias + } + // Check if the type is supported and don't cache it if not. + // First let's get the basic type. + isSlice, isStruct := false, false + ft := field.Type + m := isTextUnmarshaler(reflect.Zero(ft)) + if ft.Kind() == reflect.Ptr { + ft = ft.Elem() + } + if isSlice = ft.Kind() == reflect.Slice; isSlice { + ft = ft.Elem() + if ft.Kind() == reflect.Ptr { + ft = ft.Elem() + } + } + if ft.Kind() == reflect.Array { + ft = ft.Elem() + if ft.Kind() == reflect.Ptr { + ft = ft.Elem() + } + } + if isStruct = ft.Kind() == reflect.Struct; !isStruct { + if c.converter(ft) == nil && builtinConverters[ft.Kind()] == nil { + // Type is not supported. + return nil + } + } + + return &fieldInfo{ + typ: field.Type, + name: field.Name, + alias: alias, + canonicalAlias: canonicalAlias, + unmarshalerInfo: m, + isSliceOfStructs: isSlice && isStruct, + isAnonymous: field.Anonymous, + isRequired: options.Contains("required"), + } +} + +// converter returns the converter for a type. +func (c *cache) converter(t reflect.Type) Converter { + return c.regconv[t] +} + +// ---------------------------------------------------------------------------- + +type structInfo struct { + fields []*fieldInfo +} + +func (i *structInfo) get(alias string) *fieldInfo { + for _, field := range i.fields { + if strings.EqualFold(field.alias, alias) { + return field + } + } + return nil +} + +func containsAlias(infos []*structInfo, alias string) bool { + for _, info := range infos { + if info.get(alias) != nil { + return true + } + } + return false +} + +type fieldInfo struct { + typ reflect.Type + // name is the field name in the struct. + name string + alias string + // canonicalAlias is almost the same as the alias, but is prefixed with + // an embedded struct field alias in dotted notation if this field is + // promoted from the struct. + // For instance, if the alias is "N" and this field is an embedded field + // in a struct "X", canonicalAlias will be "X.N". + canonicalAlias string + // unmarshalerInfo contains information regarding the + // encoding.TextUnmarshaler implementation of the field type. + unmarshalerInfo unmarshaler + // isSliceOfStructs indicates if the field type is a slice of structs. + isSliceOfStructs bool + // isAnonymous indicates whether the field is embedded in the struct. + isAnonymous bool + isRequired bool +} + +func (f *fieldInfo) paths(prefix string) []string { + if f.alias == f.canonicalAlias { + return []string{prefix + f.alias} + } + return []string{prefix + f.alias, prefix + f.canonicalAlias} +} + +type pathPart struct { + field *fieldInfo + path []string // path to the field: walks structs using field names. + index int // struct index in slices of structs. +} + +// ---------------------------------------------------------------------------- + +func indirectType(typ reflect.Type) reflect.Type { + if typ.Kind() == reflect.Ptr { + return typ.Elem() + } + return typ +} + +// fieldAlias parses a field tag to get a field alias. +func fieldAlias(field reflect.StructField, tagName string) (alias string, options tagOptions) { + if tag := field.Tag.Get(tagName); tag != "" { + alias, options = parseTag(tag) + } + if alias == "" { + alias = field.Name + } + return alias, options +} + +// tagOptions is the string following a comma in a struct field's tag, or +// the empty string. It does not include the leading comma. +type tagOptions []string + +// parseTag splits a struct field's url tag into its name and comma-separated +// options. +func parseTag(tag string) (string, tagOptions) { + s := strings.Split(tag, ",") + return s[0], s[1:] +} + +// Contains checks whether the tagOptions contains the specified option. +func (o tagOptions) Contains(option string) bool { + for _, s := range o { + if s == option { + return true + } + } + return false +} diff --git a/internal/schema/converter.go b/internal/schema/converter.go new file mode 100644 index 0000000..4f2116a --- /dev/null +++ b/internal/schema/converter.go @@ -0,0 +1,145 @@ +// Copyright 2012 The Gorilla Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package schema + +import ( + "reflect" + "strconv" +) + +type Converter func(string) reflect.Value + +var ( + invalidValue = reflect.Value{} + boolType = reflect.Bool + float32Type = reflect.Float32 + float64Type = reflect.Float64 + intType = reflect.Int + int8Type = reflect.Int8 + int16Type = reflect.Int16 + int32Type = reflect.Int32 + int64Type = reflect.Int64 + stringType = reflect.String + uintType = reflect.Uint + uint8Type = reflect.Uint8 + uint16Type = reflect.Uint16 + uint32Type = reflect.Uint32 + uint64Type = reflect.Uint64 +) + +// Default converters for basic types. +var builtinConverters = map[reflect.Kind]Converter{ + boolType: convertBool, + float32Type: convertFloat32, + float64Type: convertFloat64, + intType: convertInt, + int8Type: convertInt8, + int16Type: convertInt16, + int32Type: convertInt32, + int64Type: convertInt64, + stringType: convertString, + uintType: convertUint, + uint8Type: convertUint8, + uint16Type: convertUint16, + uint32Type: convertUint32, + uint64Type: convertUint64, +} + +func convertBool(value string) reflect.Value { + if value == "on" { + return reflect.ValueOf(true) + } else if v, err := strconv.ParseBool(value); err == nil { + return reflect.ValueOf(v) + } + return invalidValue +} + +func convertFloat32(value string) reflect.Value { + if v, err := strconv.ParseFloat(value, 32); err == nil { + return reflect.ValueOf(float32(v)) + } + return invalidValue +} + +func convertFloat64(value string) reflect.Value { + if v, err := strconv.ParseFloat(value, 64); err == nil { + return reflect.ValueOf(v) + } + return invalidValue +} + +func convertInt(value string) reflect.Value { + if v, err := strconv.ParseInt(value, 10, 0); err == nil { + return reflect.ValueOf(int(v)) + } + return invalidValue +} + +func convertInt8(value string) reflect.Value { + if v, err := strconv.ParseInt(value, 10, 8); err == nil { + return reflect.ValueOf(int8(v)) + } + return invalidValue +} + +func convertInt16(value string) reflect.Value { + if v, err := strconv.ParseInt(value, 10, 16); err == nil { + return reflect.ValueOf(int16(v)) + } + return invalidValue +} + +func convertInt32(value string) reflect.Value { + if v, err := strconv.ParseInt(value, 10, 32); err == nil { + return reflect.ValueOf(int32(v)) + } + return invalidValue +} + +func convertInt64(value string) reflect.Value { + if v, err := strconv.ParseInt(value, 10, 64); err == nil { + return reflect.ValueOf(v) + } + return invalidValue +} + +func convertString(value string) reflect.Value { + return reflect.ValueOf(value) +} + +func convertUint(value string) reflect.Value { + if v, err := strconv.ParseUint(value, 10, 0); err == nil { + return reflect.ValueOf(uint(v)) + } + return invalidValue +} + +func convertUint8(value string) reflect.Value { + if v, err := strconv.ParseUint(value, 10, 8); err == nil { + return reflect.ValueOf(uint8(v)) + } + return invalidValue +} + +func convertUint16(value string) reflect.Value { + if v, err := strconv.ParseUint(value, 10, 16); err == nil { + return reflect.ValueOf(uint16(v)) + } + return invalidValue +} + +func convertUint32(value string) reflect.Value { + if v, err := strconv.ParseUint(value, 10, 32); err == nil { + return reflect.ValueOf(uint32(v)) + } + return invalidValue +} + +func convertUint64(value string) reflect.Value { + if v, err := strconv.ParseUint(value, 10, 64); err == nil { + return reflect.ValueOf(v) + } + return invalidValue +} diff --git a/internal/schema/decoder.go b/internal/schema/decoder.go new file mode 100644 index 0000000..b63c45e --- /dev/null +++ b/internal/schema/decoder.go @@ -0,0 +1,534 @@ +// Copyright 2012 The Gorilla Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package schema + +import ( + "encoding" + "errors" + "fmt" + "reflect" + "strings" +) + +// NewDecoder returns a new Decoder. +func NewDecoder() *Decoder { + return &Decoder{cache: newCache()} +} + +// Decoder decodes values from a map[string][]string to a struct. +type Decoder struct { + cache *cache + zeroEmpty bool + ignoreUnknownKeys bool +} + +// SetAliasTag changes the tag used to locate custom field aliases. +// The default tag is "schema". +func (d *Decoder) SetAliasTag(tag string) { + d.cache.tag = tag +} + +// ZeroEmpty controls the behaviour when the decoder encounters empty values +// in a map. +// If z is true and a key in the map has the empty string as a value +// then the corresponding struct field is set to the zero value. +// If z is false then empty strings are ignored. +// +// The default value is false, that is empty values do not change +// the value of the struct field. +func (d *Decoder) ZeroEmpty(z bool) { + d.zeroEmpty = z +} + +// IgnoreUnknownKeys controls the behaviour when the decoder encounters unknown +// keys in the map. +// If i is true and an unknown field is encountered, it is ignored. This is +// similar to how unknown keys are handled by encoding/json. +// If i is false then Decode will return an error. Note that any valid keys +// will still be decoded in to the target struct. +// +// To preserve backwards compatibility, the default value is false. +func (d *Decoder) IgnoreUnknownKeys(i bool) { + d.ignoreUnknownKeys = i +} + +// RegisterConverter registers a converter function for a custom type. +func (d *Decoder) RegisterConverter(value interface{}, converterFunc Converter) { + d.cache.registerConverter(value, converterFunc) +} + +// Decode decodes a map[string][]string to a struct. +// +// The first parameter must be a pointer to a struct. +// +// The second parameter is a map, typically url.Values from an HTTP request. +// Keys are "paths" in dotted notation to the struct fields and nested structs. +// +// See the package documentation for a full explanation of the mechanics. +func (d *Decoder) Decode(dst interface{}, src map[string][]string) error { + v := reflect.ValueOf(dst) + if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct { + return errors.New("schema: interface must be a pointer to struct") + } + v = v.Elem() + t := v.Type() + multiError := MultiError{} + for path, values := range src { + if parts, err := d.cache.parsePath(path, t); err == nil { + if err = d.decode(v, path, parts, values); err != nil { + multiError[path] = err + } + } else if !d.ignoreUnknownKeys { + multiError[path] = UnknownKeyError{Key: path} + } + } + multiError.merge(d.checkRequired(t, src)) + if len(multiError) > 0 { + return multiError + } + return nil +} + +// checkRequired checks whether required fields are empty +// +// check type t recursively if t has struct fields. +// +// src is the source map for decoding, we use it here to see if those required fields are included in src +func (d *Decoder) checkRequired(t reflect.Type, src map[string][]string) MultiError { + m, errs := d.findRequiredFields(t, "", "") + for key, fields := range m { + if isEmptyFields(fields, src) { + errs[key] = EmptyFieldError{Key: key} + } + } + return errs +} + +// findRequiredFields recursively searches the struct type t for required fields. +// +// canonicalPrefix and searchPrefix are used to resolve full paths in dotted notation +// for nested struct fields. canonicalPrefix is a complete path which never omits +// any embedded struct fields. searchPrefix is a user-friendly path which may omit +// some embedded struct fields to point promoted fields. +func (d *Decoder) findRequiredFields(t reflect.Type, canonicalPrefix, searchPrefix string) (map[string][]fieldWithPrefix, MultiError) { + struc := d.cache.get(t) + if struc == nil { + // unexpect, cache.get never return nil + return nil, MultiError{canonicalPrefix + "*": errors.New("cache fail")} + } + + m := map[string][]fieldWithPrefix{} + errs := MultiError{} + for _, f := range struc.fields { + if f.typ.Kind() == reflect.Struct { + fcprefix := canonicalPrefix + f.canonicalAlias + "." + for _, fspath := range f.paths(searchPrefix) { + fm, ferrs := d.findRequiredFields(f.typ, fcprefix, fspath+".") + for key, fields := range fm { + m[key] = append(m[key], fields...) + } + errs.merge(ferrs) + } + } + if f.isRequired { + key := canonicalPrefix + f.canonicalAlias + m[key] = append(m[key], fieldWithPrefix{ + fieldInfo: f, + prefix: searchPrefix, + }) + } + } + return m, errs +} + +type fieldWithPrefix struct { + *fieldInfo + prefix string +} + +// isEmptyFields returns true if all of specified fields are empty. +func isEmptyFields(fields []fieldWithPrefix, src map[string][]string) bool { + for _, f := range fields { + for _, path := range f.paths(f.prefix) { + v, ok := src[path] + if ok && !isEmpty(f.typ, v) { + return false + } + for key := range src { + // issue references: + // https://github.com/gofiber/fiber/issues/1414 + // https://github.com/gorilla/schema/issues/176 + nested := strings.IndexByte(key, '.') != -1 + + // for non required nested structs + c1 := strings.HasSuffix(f.prefix, ".") && key == path + + // for required nested structs + c2 := f.prefix == "" && nested && strings.HasPrefix(key, path) + + // for non nested fields + c3 := f.prefix == "" && !nested && key == path + if !isEmpty(f.typ, src[key]) && (c1 || c2 || c3) { + return false + } + } + } + } + return true +} + +// isEmpty returns true if value is empty for specific type +func isEmpty(t reflect.Type, value []string) bool { + if len(value) == 0 { + return true + } + switch t.Kind() { + case boolType, float32Type, float64Type, intType, int8Type, int32Type, int64Type, stringType, uint8Type, uint16Type, uint32Type, uint64Type: + return len(value[0]) == 0 + } + return false +} + +// decode fills a struct field using a parsed path. +func (d *Decoder) decode(v reflect.Value, path string, parts []pathPart, values []string) error { + // Get the field walking the struct fields by index. + for _, name := range parts[0].path { + if v.Type().Kind() == reflect.Ptr { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + v = v.Elem() + } + + // alloc embedded structs + if v.Type().Kind() == reflect.Struct { + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + if field.Type().Kind() == reflect.Ptr && field.IsNil() && v.Type().Field(i).Anonymous { + field.Set(reflect.New(field.Type().Elem())) + } + } + } + + v = v.FieldByName(name) + } + // Don't even bother for unexported fields. + if !v.CanSet() { + return nil + } + + // Dereference if needed. + t := v.Type() + if t.Kind() == reflect.Ptr { + t = t.Elem() + if v.IsNil() { + v.Set(reflect.New(t)) + } + v = v.Elem() + } + + // Slice of structs. Let's go recursive. + if len(parts) > 1 { + idx := parts[0].index + if v.IsNil() || v.Len() < idx+1 { + value := reflect.MakeSlice(t, idx+1, idx+1) + if v.Len() < idx+1 { + // Resize it. + reflect.Copy(value, v) + } + v.Set(value) + } + return d.decode(v.Index(idx), path, parts[1:], values) + } + + // Get the converter early in case there is one for a slice type. + conv := d.cache.converter(t) + m := isTextUnmarshaler(v) + if conv == nil && t.Kind() == reflect.Slice && m.IsSliceElement { + var items []reflect.Value + elemT := t.Elem() + isPtrElem := elemT.Kind() == reflect.Ptr + if isPtrElem { + elemT = elemT.Elem() + } + + // Try to get a converter for the element type. + conv := d.cache.converter(elemT) + if conv == nil { + conv = builtinConverters[elemT.Kind()] + if conv == nil { + // As we are not dealing with slice of structs here, we don't need to check if the type + // implements TextUnmarshaler interface + return fmt.Errorf("schema: converter not found for %v", elemT) + } + } + + for key, value := range values { + if value == "" { + if d.zeroEmpty { + items = append(items, reflect.Zero(elemT)) + } + } else if m.IsValid { + u := reflect.New(elemT) + if m.IsSliceElementPtr { + u = reflect.New(reflect.PtrTo(elemT).Elem()) + } + if err := u.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(value)); err != nil { + return ConversionError{ + Key: path, + Type: t, + Index: key, + Err: err, + } + } + if m.IsSliceElementPtr { + items = append(items, u.Elem().Addr()) + } else if u.Kind() == reflect.Ptr { + items = append(items, u.Elem()) + } else { + items = append(items, u) + } + } else if item := conv(value); item.IsValid() { + if isPtrElem { + ptr := reflect.New(elemT) + ptr.Elem().Set(item) + item = ptr + } + if item.Type() != elemT && !isPtrElem { + item = item.Convert(elemT) + } + items = append(items, item) + } else { + if strings.Contains(value, ",") { + values := strings.Split(value, ",") + for _, value := range values { + if value == "" { + if d.zeroEmpty { + items = append(items, reflect.Zero(elemT)) + } + } else if item := conv(value); item.IsValid() { + if isPtrElem { + ptr := reflect.New(elemT) + ptr.Elem().Set(item) + item = ptr + } + if item.Type() != elemT && !isPtrElem { + item = item.Convert(elemT) + } + items = append(items, item) + } else { + return ConversionError{ + Key: path, + Type: elemT, + Index: key, + } + } + } + } else { + return ConversionError{ + Key: path, + Type: elemT, + Index: key, + } + } + } + } + value := reflect.Append(reflect.MakeSlice(t, 0, 0), items...) + v.Set(value) + } else { + val := "" + // Use the last value provided if any values were provided + if len(values) > 0 { + val = values[len(values)-1] + } + + if conv != nil { + if value := conv(val); value.IsValid() { + v.Set(value.Convert(t)) + } else { + return ConversionError{ + Key: path, + Type: t, + Index: -1, + } + } + } else if m.IsValid { + if m.IsPtr { + u := reflect.New(v.Type()) + if err := u.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(val)); err != nil { + return ConversionError{ + Key: path, + Type: t, + Index: -1, + Err: err, + } + } + v.Set(reflect.Indirect(u)) + } else { + // If the value implements the encoding.TextUnmarshaler interface + // apply UnmarshalText as the converter + if err := m.Unmarshaler.UnmarshalText([]byte(val)); err != nil { + return ConversionError{ + Key: path, + Type: t, + Index: -1, + Err: err, + } + } + } + } else if val == "" { + if d.zeroEmpty { + v.Set(reflect.Zero(t)) + } + } else if conv := builtinConverters[t.Kind()]; conv != nil { + if value := conv(val); value.IsValid() { + v.Set(value.Convert(t)) + } else { + return ConversionError{ + Key: path, + Type: t, + Index: -1, + } + } + } else { + return fmt.Errorf("schema: converter not found for %v", t) + } + } + return nil +} + +func isTextUnmarshaler(v reflect.Value) unmarshaler { + // Create a new unmarshaller instance + m := unmarshaler{} + if m.Unmarshaler, m.IsValid = v.Interface().(encoding.TextUnmarshaler); m.IsValid { + return m + } + // As the UnmarshalText function should be applied to the pointer of the + // type, we check that type to see if it implements the necessary + // method. + if m.Unmarshaler, m.IsValid = reflect.New(v.Type()).Interface().(encoding.TextUnmarshaler); m.IsValid { + m.IsPtr = true + return m + } + + // if v is []T or *[]T create new T + t := v.Type() + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + if t.Kind() == reflect.Slice { + // Check if the slice implements encoding.TextUnmarshaller + if m.Unmarshaler, m.IsValid = v.Interface().(encoding.TextUnmarshaler); m.IsValid { + return m + } + // If t is a pointer slice, check if its elements implement + // encoding.TextUnmarshaler + m.IsSliceElement = true + if t = t.Elem(); t.Kind() == reflect.Ptr { + t = reflect.PtrTo(t.Elem()) + v = reflect.Zero(t) + m.IsSliceElementPtr = true + m.Unmarshaler, m.IsValid = v.Interface().(encoding.TextUnmarshaler) + return m + } + } + + v = reflect.New(t) + m.Unmarshaler, m.IsValid = v.Interface().(encoding.TextUnmarshaler) + return m +} + +// TextUnmarshaler helpers ---------------------------------------------------- +// unmarshaller contains information about a TextUnmarshaler type +type unmarshaler struct { + Unmarshaler encoding.TextUnmarshaler + // IsValid indicates whether the resolved type indicated by the other + // flags implements the encoding.TextUnmarshaler interface. + IsValid bool + // IsPtr indicates that the resolved type is the pointer of the original + // type. + IsPtr bool + // IsSliceElement indicates that the resolved type is a slice element of + // the original type. + IsSliceElement bool + // IsSliceElementPtr indicates that the resolved type is a pointer to a + // slice element of the original type. + IsSliceElementPtr bool +} + +// Errors --------------------------------------------------------------------- + +// ConversionError stores information about a failed conversion. +type ConversionError struct { + Key string // key from the source map. + Type reflect.Type // expected type of elem + Index int // index for multi-value fields; -1 for single-value fields. + Err error // low-level error (when it exists) +} + +func (e ConversionError) Error() string { + var output string + + if e.Index < 0 { + output = fmt.Sprintf("schema: error converting value for %q", e.Key) + } else { + output = fmt.Sprintf("schema: error converting value for index %d of %q", + e.Index, e.Key) + } + + if e.Err != nil { + output = fmt.Sprintf("%s. Details: %s", output, e.Err) + } + + return output +} + +// UnknownKeyError stores information about an unknown key in the source map. +type UnknownKeyError struct { + Key string // key from the source map. +} + +func (e UnknownKeyError) Error() string { + return fmt.Sprintf("schema: invalid path %q", e.Key) +} + +// EmptyFieldError stores information about an empty required field. +type EmptyFieldError struct { + Key string // required key in the source map. +} + +func (e EmptyFieldError) Error() string { + return fmt.Sprintf("%v is empty", e.Key) +} + +// MultiError stores multiple decoding errors. +// +// Borrowed from the App Engine SDK. +type MultiError map[string]error + +func (e MultiError) Error() string { + s := "" + for _, err := range e { + s = err.Error() + break + } + switch len(e) { + case 0: + return "(0 errors)" + case 1: + return s + case 2: + return s + " (and 1 other error)" + } + return fmt.Sprintf("%s (and %d other errors)", s, len(e)-1) +} + +func (e MultiError) merge(errors MultiError) { + for key, err := range errors { + if e[key] == nil { + e[key] = err + } + } +} diff --git a/internal/schema/doc.go b/internal/schema/doc.go new file mode 100644 index 0000000..fff0fe7 --- /dev/null +++ b/internal/schema/doc.go @@ -0,0 +1,148 @@ +// Copyright 2012 The Gorilla Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* +Package gorilla/schema fills a struct with form values. + +The basic usage is really simple. Given this struct: + + type Person struct { + Name string + Phone string + } + +...we can fill it passing a map to the Decode() function: + + values := map[string][]string{ + "Name": {"John"}, + "Phone": {"999-999-999"}, + } + person := new(Person) + decoder := schema.NewDecoder() + decoder.Decode(person, values) + +This is just a simple example and it doesn't make a lot of sense to create +the map manually. Typically it will come from a http.Request object and +will be of type url.Values, http.Request.Form, or http.Request.MultipartForm: + + func MyHandler(w http.ResponseWriter, r *http.Request) { + err := r.ParseForm() + + if err != nil { + // Handle error + } + + decoder := schema.NewDecoder() + // r.PostForm is a map of our POST form values + err := decoder.Decode(person, r.PostForm) + + if err != nil { + // Handle error + } + + // Do something with person.Name or person.Phone + } + +Note: it is a good idea to set a Decoder instance as a package global, +because it caches meta-data about structs, and an instance can be shared safely: + + var decoder = schema.NewDecoder() + +To define custom names for fields, use a struct tag "schema". To not populate +certain fields, use a dash for the name and it will be ignored: + + type Person struct { + Name string `schema:"name"` // custom name + Phone string `schema:"phone"` // custom name + Admin bool `schema:"-"` // this field is never set + } + +The supported field types in the destination struct are: + + - bool + - float variants (float32, float64) + - int variants (int, int8, int16, int32, int64) + - string + - uint variants (uint, uint8, uint16, uint32, uint64) + - struct + - a pointer to one of the above types + - a slice or a pointer to a slice of one of the above types + +Non-supported types are simply ignored, however custom types can be registered +to be converted. + +To fill nested structs, keys must use a dotted notation as the "path" for the +field. So for example, to fill the struct Person below: + + type Phone struct { + Label string + Number string + } + + type Person struct { + Name string + Phone Phone + } + +...the source map must have the keys "Name", "Phone.Label" and "Phone.Number". +This means that an HTML form to fill a Person struct must look like this: + +
+ + + +
+ +Single values are filled using the first value for a key from the source map. +Slices are filled using all values for a key from the source map. So to fill +a Person with multiple Phone values, like: + + type Person struct { + Name string + Phones []Phone + } + +...an HTML form that accepts three Phone values would look like this: + +
+ + + + + + + +
+ +Notice that only for slices of structs the slice index is required. +This is needed for disambiguation: if the nested struct also had a slice +field, we could not translate multiple values to it if we did not use an +index for the parent struct. + +There's also the possibility to create a custom type that implements the +TextUnmarshaler interface, and in this case there's no need to register +a converter, like: + + type Person struct { + Emails []Email + } + + type Email struct { + *mail.Address + } + + func (e *Email) UnmarshalText(text []byte) (err error) { + e.Address, err = mail.ParseAddress(string(text)) + return + } + +...an HTML form that accepts three Email values would look like this: + +
+ + + +
+*/ +package schema diff --git a/internal/schema/encoder.go b/internal/schema/encoder.go new file mode 100644 index 0000000..c01de00 --- /dev/null +++ b/internal/schema/encoder.go @@ -0,0 +1,202 @@ +package schema + +import ( + "errors" + "fmt" + "reflect" + "strconv" +) + +type encoderFunc func(reflect.Value) string + +// Encoder encodes values from a struct into url.Values. +type Encoder struct { + cache *cache + regenc map[reflect.Type]encoderFunc +} + +// NewEncoder returns a new Encoder with defaults. +func NewEncoder() *Encoder { + return &Encoder{cache: newCache(), regenc: make(map[reflect.Type]encoderFunc)} +} + +// Encode encodes a struct into map[string][]string. +// +// Intended for use with url.Values. +func (e *Encoder) Encode(src interface{}, dst map[string][]string) error { + v := reflect.ValueOf(src) + + return e.encode(v, dst) +} + +// RegisterEncoder registers a converter for encoding a custom type. +func (e *Encoder) RegisterEncoder(value interface{}, encoder func(reflect.Value) string) { + e.regenc[reflect.TypeOf(value)] = encoder +} + +// SetAliasTag changes the tag used to locate custom field aliases. +// The default tag is "schema". +func (e *Encoder) SetAliasTag(tag string) { + e.cache.tag = tag +} + +// isValidStructPointer test if input value is a valid struct pointer. +func isValidStructPointer(v reflect.Value) bool { + return v.Type().Kind() == reflect.Ptr && v.Elem().IsValid() && v.Elem().Type().Kind() == reflect.Struct +} + +func isZero(v reflect.Value) bool { + switch v.Kind() { + case reflect.Func: + case reflect.Map, reflect.Slice: + return v.IsNil() || v.Len() == 0 + case reflect.Array: + z := true + for i := 0; i < v.Len(); i++ { + z = z && isZero(v.Index(i)) + } + return z + case reflect.Struct: + type zero interface { + IsZero() bool + } + if v.Type().Implements(reflect.TypeOf((*zero)(nil)).Elem()) { + iz := v.MethodByName("IsZero").Call([]reflect.Value{})[0] + return iz.Interface().(bool) + } + z := true + for i := 0; i < v.NumField(); i++ { + z = z && isZero(v.Field(i)) + } + return z + } + // Compare other types directly: + z := reflect.Zero(v.Type()) + return v.Interface() == z.Interface() +} + +func (e *Encoder) encode(v reflect.Value, dst map[string][]string) error { + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + if v.Kind() != reflect.Struct { + return errors.New("schema: interface must be a struct") + } + t := v.Type() + + errors := MultiError{} + + for i := 0; i < v.NumField(); i++ { + name, opts := fieldAlias(t.Field(i), e.cache.tag) + if name == "-" { + continue + } + + // Encode struct pointer types if the field is a valid pointer and a struct. + if isValidStructPointer(v.Field(i)) { + _ = e.encode(v.Field(i).Elem(), dst) + continue + } + + encFunc := typeEncoder(v.Field(i).Type(), e.regenc) + + // Encode non-slice types and custom implementations immediately. + if encFunc != nil { + value := encFunc(v.Field(i)) + if opts.Contains("omitempty") && isZero(v.Field(i)) { + continue + } + + dst[name] = append(dst[name], value) + continue + } + + if v.Field(i).Type().Kind() == reflect.Struct { + _ = e.encode(v.Field(i), dst) + continue + } + + if v.Field(i).Type().Kind() == reflect.Slice { + encFunc = typeEncoder(v.Field(i).Type().Elem(), e.regenc) + } + + if encFunc == nil { + errors[v.Field(i).Type().String()] = fmt.Errorf("schema: encoder not found for %v", v.Field(i)) + continue + } + + // Encode a slice. + if v.Field(i).Len() == 0 && opts.Contains("omitempty") { + continue + } + + dst[name] = []string{} + for j := 0; j < v.Field(i).Len(); j++ { + dst[name] = append(dst[name], encFunc(v.Field(i).Index(j))) + } + } + + if len(errors) > 0 { + return errors + } + return nil +} + +func typeEncoder(t reflect.Type, reg map[reflect.Type]encoderFunc) encoderFunc { + if f, ok := reg[t]; ok { + return f + } + + switch t.Kind() { + case reflect.Bool: + return encodeBool + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return encodeInt + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return encodeUint + case reflect.Float32: + return encodeFloat32 + case reflect.Float64: + return encodeFloat64 + case reflect.Ptr: + f := typeEncoder(t.Elem(), reg) + return func(v reflect.Value) string { + if v.IsNil() { + return "null" + } + return f(v.Elem()) + } + case reflect.String: + return encodeString + default: + return nil + } +} + +func encodeBool(v reflect.Value) string { + return strconv.FormatBool(v.Bool()) +} + +func encodeInt(v reflect.Value) string { + return strconv.FormatInt(int64(v.Int()), 10) +} + +func encodeUint(v reflect.Value) string { + return strconv.FormatUint(uint64(v.Uint()), 10) +} + +func encodeFloat(v reflect.Value, bits int) string { + return strconv.FormatFloat(v.Float(), 'f', 6, bits) +} + +func encodeFloat32(v reflect.Value) string { + return encodeFloat(v, 32) +} + +func encodeFloat64(v reflect.Value) string { + return encodeFloat(v, 64) +} + +func encodeString(v reflect.Value) string { + return v.String() +} diff --git a/middleware.go b/middleware.go new file mode 100644 index 0000000..b4d14f2 --- /dev/null +++ b/middleware.go @@ -0,0 +1,23 @@ +package nf + +import ( + "fmt" + "os" + "runtime/debug" +) + +func NewRecover(enableStackTrace bool) HandlerFunc { + return func(c *Ctx) error { + defer func() { + if r := recover(); r != nil { + if enableStackTrace { + os.Stderr.WriteString(fmt.Sprintf("recovered from panic: %v\n%s\n", r, debug.Stack())) + } else { + os.Stderr.WriteString(fmt.Sprintf("recovered from panic: %v\n", r)) + } + } + }() + + return c.Next() + } +} diff --git a/nf.go b/nf.go new file mode 100644 index 0000000..527541f --- /dev/null +++ b/nf.go @@ -0,0 +1,45 @@ +package nf + +const ( + banner = " _ _ _ ___ _ \n | \\| |___| |_ | __|__ _ _ _ _ __| |\n | .` / _ \\ _| | _/ _ \\ || | ' \\/ _` |\n |_|\\_\\___/\\__| |_|\\___/\\_,_|_||_\\__,_|\n " +) + +type Map map[string]interface{} + +type Config struct { + // Default: 4 * 1024 * 1024 + BodyLimit int64 `json:"-"` + DisableBanner bool `json:"-"` + DisableLogger bool `json:"-"` + DisableRecover bool `json:"-"` +} + +var ( + defaultConfig = &Config{ + BodyLimit: 4 * 1024 * 1024, + } +) + +func New(config ...Config) *App { + app := &App{ + router: newRouter(), + } + + if len(config) > 0 { + app.config = &config[0] + if app.config.BodyLimit == 0 { + app.config.BodyLimit = defaultConfig.BodyLimit + } + } else { + app.config = defaultConfig + } + + app.RouterGroup = &RouterGroup{app: app} + app.groups = []*RouterGroup{app.RouterGroup} + + if !app.config.DisableRecover { + app.Use(NewRecover(true)) + } + + return app +} diff --git a/resp.go b/resp.go new file mode 100644 index 0000000..7c9a03a --- /dev/null +++ b/resp.go @@ -0,0 +1,49 @@ +package nf + +import ( + "encoding/json" + "fmt" +) + +func (c *Ctx) Status(code int) *Ctx { + c.StatusCode = code + c.Writer.WriteHeader(code) + return c +} + +func (c *Ctx) SetHeader(key string, value string) { + c.Writer.Header().Set(key, value) +} + +func (c *Ctx) SendString(data string) error { + c.SetHeader("Content-Type", "text/plain") + _, err := c.Write([]byte(data)) + return err +} + +func (c *Ctx) Writef(format string, values ...interface{}) (int, error) { + c.SetHeader("Content-Type", "text/plain") + return c.Writer.Write([]byte(fmt.Sprintf(format, values...))) +} + +func (c *Ctx) JSON(data interface{}) error { + c.SetHeader("Content-Type", "application/json") + + encoder := json.NewEncoder(c.Writer) + + if err := encoder.Encode(data); err != nil { + return err + } + + return nil +} + +func (c *Ctx) Write(data []byte) (int, error) { + return c.Writer.Write(data) +} + +func (c *Ctx) HTML(html string) error { + c.SetHeader("Content-Type", "text/html") + _, err := c.Writer.Write([]byte(html)) + return err +} diff --git a/router.go b/router.go new file mode 100644 index 0000000..27cd344 --- /dev/null +++ b/router.go @@ -0,0 +1,99 @@ +package nf + +import "strings" + +type router struct { + roots map[string]*_node + handlers map[string][]HandlerFunc +} + +func newRouter() *router { + return &router{ + roots: make(map[string]*_node), + handlers: make(map[string][]HandlerFunc), + } +} + +// Only one * is allowed +func parsePattern(pattern string) []string { + vs := strings.Split(pattern, "/") + + parts := make([]string, 0) + for _, item := range vs { + if item != "" { + parts = append(parts, item) + if item[0] == '*' { + break + } + } + } + return parts +} + +func (r *router) addRoute(method string, pattern string, handlers ...HandlerFunc) { + parts := parsePattern(pattern) + + key := method + "-" + pattern + _, ok := r.roots[method] + if !ok { + r.roots[method] = &_node{} + } + r.roots[method].insert(pattern, parts, 0) + r.handlers[key] = handlers +} + +func (r *router) getRoute(method string, path string) (*_node, map[string]string) { + searchParts := parsePattern(path) + params := make(map[string]string) + root, ok := r.roots[method] + + if !ok { + return nil, nil + } + + n := root.search(searchParts, 0) + + if n != nil { + parts := parsePattern(n.pattern) + for index, part := range parts { + if part[0] == ':' { + params[part[1:]] = searchParts[index] + } + if part[0] == '*' && len(part) > 1 { + params[part[1:]] = strings.Join(searchParts[index:], "/") + break + } + } + return n, params + } + + return nil, nil +} + +func (r *router) getRoutes(method string) []*_node { + root, ok := r.roots[method] + if !ok { + return nil + } + nodes := make([]*_node, 0) + root.travel(&nodes) + return nodes +} + +func (r *router) handle(c *Ctx) error { + if err := c.verify(); err != nil { + return err + } + + node, params := r.getRoute(c.Method, c.path) + if node != nil { + c.params = params + key := c.Method + "-" + node.pattern + c.handlers = append(c.handlers, r.handlers[key]...) + } else { + _, err := c.Writef("404 NOT FOUND: %s\n", c.path) + return err + } + + return c.Next() +} diff --git a/tree.go b/tree.go new file mode 100644 index 0000000..49aaa85 --- /dev/null +++ b/tree.go @@ -0,0 +1,76 @@ +package nf + +import ( + "strings" +) + +type _node struct { + pattern string + part string + children []*_node + isWild bool +} + +func (n *_node) insert(pattern string, parts []string, height int) { + if len(parts) == height { + n.pattern = pattern + return + } + + part := parts[height] + child := n.matchChild(part) + if child == nil { + child = &_node{part: part, isWild: part[0] == ':' || part[0] == '*'} + n.children = append(n.children, child) + } + child.insert(pattern, parts, height+1) +} + +func (n *_node) search(parts []string, height int) *_node { + if len(parts) == height || strings.HasPrefix(n.part, "*") { + if n.pattern == "" { + return nil + } + return n + } + + part := parts[height] + children := n.matchChildren(part) + + for _, child := range children { + result := child.search(parts, height+1) + if result != nil { + return result + } + } + + return nil +} + +func (n *_node) travel(list *([]*_node)) { + if n.pattern != "" { + *list = append(*list, n) + } + for _, child := range n.children { + child.travel(list) + } +} + +func (n *_node) matchChild(part string) *_node { + for _, child := range n.children { + if child.part == part || child.isWild { + return child + } + } + return nil +} + +func (n *_node) matchChildren(part string) []*_node { + nodes := make([]*_node, 0) + for _, child := range n.children { + if child.part == part || child.isWild { + nodes = append(nodes, child) + } + } + return nodes +} diff --git a/util.go b/util.go new file mode 100644 index 0000000..c336ded --- /dev/null +++ b/util.go @@ -0,0 +1,81 @@ +package nf + +import ( + "fmt" + "github.com/loveuer/nf/internal/schema" + "strings" +) + +const ( + MIMETextXML = "text/xml" + MIMETextHTML = "text/html" + MIMETextPlain = "text/plain" + MIMETextJavaScript = "text/javascript" + MIMEApplicationXML = "application/xml" + MIMEApplicationJSON = "application/json" + // Deprecated: use MIMETextJavaScript instead + MIMEApplicationJavaScript = "application/javascript" + MIMEApplicationForm = "application/x-www-form-urlencoded" + MIMEOctetStream = "application/octet-stream" + MIMEMultipartForm = "multipart/form-data" + + MIMETextXMLCharsetUTF8 = "text/xml; charset=utf-8" + MIMETextHTMLCharsetUTF8 = "text/html; charset=utf-8" + MIMETextPlainCharsetUTF8 = "text/plain; charset=utf-8" + MIMETextJavaScriptCharsetUTF8 = "text/javascript; charset=utf-8" + MIMEApplicationXMLCharsetUTF8 = "application/xml; charset=utf-8" + MIMEApplicationJSONCharsetUTF8 = "application/json; charset=utf-8" + // Deprecated: use MIMETextJavaScriptCharsetUTF8 instead + MIMEApplicationJavaScriptCharsetUTF8 = "application/javascript; charset=utf-8" +) + +func verifyHandlers(path string, handlers ...HandlerFunc) { + if len(handlers) == 0 { + panic(fmt.Sprintf("missing handler in route: %s", path)) + } + + for _, handler := range handlers { + if handler == nil { + panic(fmt.Sprintf("nil handler found in route: %s", path)) + } + } +} + +// parseVendorSpecificContentType check if content type is vendor specific and +// if it is parsable to any known types. If it's not vendor specific then returns +// the original content type. +func parseVendorSpecificContentType(cType string) string { + plusIndex := strings.Index(cType, "+") + + if plusIndex == -1 { + return cType + } + + var parsableType string + if semiColonIndex := strings.Index(cType, ";"); semiColonIndex == -1 { + parsableType = cType[plusIndex+1:] + } else if plusIndex < semiColonIndex { + parsableType = cType[plusIndex+1 : semiColonIndex] + } else { + return cType[:semiColonIndex] + } + + slashIndex := strings.Index(cType, "/") + + if slashIndex == -1 { + return cType + } + + return cType[0:slashIndex+1] + parsableType +} + +func parseToStruct(aliasTag string, out interface{}, data map[string][]string) error { + schemaDecoder := schema.NewDecoder() + schemaDecoder.SetAliasTag(aliasTag) + + if err := schemaDecoder.Decode(out, data); err != nil { + return fmt.Errorf("failed to decode: %w", err) + } + + return nil +} diff --git a/xtest/basic/basic.http b/xtest/basic/basic.http new file mode 100644 index 0000000..daff45c --- /dev/null +++ b/xtest/basic/basic.http @@ -0,0 +1,2 @@ +### basic - get +GET http://127.0.0.1/hello/nf \ No newline at end of file diff --git a/xtest/basic/main.go b/xtest/basic/main.go new file mode 100644 index 0000000..9e45b3a --- /dev/null +++ b/xtest/basic/main.go @@ -0,0 +1,17 @@ +package main + +import ( + "github.com/loveuer/nf" + "log" +) + +func main() { + app := nf.New() + + app.Get("/hello/:name", func(c *nf.Ctx) error { + name := c.Param("name") + return c.JSON(nf.Map{"status": 200, "data": "hello, " + name}) + }) + + log.Fatal(app.Run("0.0.0.0:80")) +} diff --git a/xtest/bodyLimit/body_limit.http b/xtest/bodyLimit/body_limit.http new file mode 100644 index 0000000..cbc7673 --- /dev/null +++ b/xtest/bodyLimit/body_limit.http @@ -0,0 +1,9 @@ +### body_limit +POST http://127.0.0.1/data +Content-Type: application/json + +{ + "name": "zyp", + "age": 19, + "likes": ["2233"] +} \ No newline at end of file diff --git a/xtest/bodyLimit/main.go b/xtest/bodyLimit/main.go new file mode 100644 index 0000000..ecaf0eb --- /dev/null +++ b/xtest/bodyLimit/main.go @@ -0,0 +1,50 @@ +package main + +import ( + "github.com/loveuer/nf" + "log" +) + +func main() { + app := nf.New(nf.Config{BodyLimit: -1}) + + app.Post("/data", func(c *nf.Ctx) error { + type Req struct { + Name string `json:"name"` + Age int `json:"age"` + Likes []string `json:"likes"` + } + + var ( + err error + req = new(Req) + ) + + if err = c.BodyParser(req); err != nil { + return c.JSON(nf.Map{"status": 400, "err": err.Error()}) + } + + return c.JSON(nf.Map{"status": 200, "data": req}) + }) + + app.Post("/url", func(c *nf.Ctx) error { + type Req struct { + Name string `form:"name"` + Age int `form:"age"` + Likes []string `form:"likes"` + } + + var ( + err error + req = new(Req) + ) + + if err = c.BodyParser(req); err != nil { + return c.JSON(nf.Map{"status": 400, "err": err.Error()}) + } + + return c.JSON(nf.Map{"status": 200, "data": req}) + }) + + log.Fatal(app.Run("0.0.0.0:80")) +} diff --git a/xtest/panic/main.go b/xtest/panic/main.go new file mode 100644 index 0000000..a675ec6 --- /dev/null +++ b/xtest/panic/main.go @@ -0,0 +1,24 @@ +package main + +import ( + "github.com/loveuer/nf" + "log" +) + +func main() { + app := nf.New(nf.Config{ + DisableRecover: true, + }) + + app.Get("/hello/:name", func(c *nf.Ctx) error { + name := c.Param("name") + + if name == "nf" { + panic("name is nf") + } + + return c.JSON("nice") + }) + + log.Fatal(app.Run("0.0.0.0:80")) +} diff --git a/xtest/panic/panic.http b/xtest/panic/panic.http new file mode 100644 index 0000000..8092822 --- /dev/null +++ b/xtest/panic/panic.http @@ -0,0 +1,5 @@ +### panic test +GET http://127.0.0.1/hello/nf + +### if covered? +GET http://127.0.0.1/hello/world \ No newline at end of file diff --git a/xtest/queryParser/main.go b/xtest/queryParser/main.go new file mode 100644 index 0000000..bc9d9b4 --- /dev/null +++ b/xtest/queryParser/main.go @@ -0,0 +1,31 @@ +package main + +import ( + "github.com/loveuer/nf" + "log" +) + +func main() { + app := nf.New() + + app.Get("/hello", func(c *nf.Ctx) error { + type Req struct { + Name string `query:"name"` + Age int `query:"age"` + Likes []string `query:"likes"` + } + + var ( + err error + req = new(Req) + ) + + if err = c.QueryParser(req); err != nil { + return nf.NewNFError(400, err.Error()) + } + + return c.JSON(nf.Map{"status": 200, "data": req}) + }) + + log.Fatal(app.Run("0.0.0.0:80")) +}