package nf import ( "bytes" "context" "encoding/json" "errors" "fmt" "html/template" "io" "mime/multipart" "net" "net/http" "strings" "sync" "github.com/google/uuid" "github.com/loveuer/nf/internal/sse" ) var forwardHeaders = []string{"CF-Connecting-IP", "X-Forwarded-For", "X-Real-Ip"} type Ctx struct { lock sync.Mutex writermem responseWriter Writer ResponseWriter Request *http.Request path string method string StatusCode int app *App params *Params index int handlers []HandlerFunc locals map[string]interface{} skippedNodes *[]skippedNode fullPath string } func (c *Ctx) reset(w http.ResponseWriter, r *http.Request) { traceId := r.Header.Get(TraceKey) if traceId == "" { traceId = uuid.Must(uuid.NewV7()).String() } c.writermem.reset(w) c.Request = r.WithContext(context.WithValue(r.Context(), TraceKey, traceId)) c.Writer = &c.writermem c.handlers = nil c.index = -1 c.path = r.URL.Path c.method = r.Method c.StatusCode = 200 c.fullPath = "" *c.params = (*c.params)[:0] *c.skippedNodes = (*c.skippedNodes)[:0] for key := range c.locals { delete(c.locals, key) } c.writermem.Header().Set(TraceKey, traceId) } func (c *Ctx) Locals(key string, value ...interface{}) interface{} { data := c.locals[key] if len(value) > 0 { c.locals[key] = value[0] } return data } func (c *Ctx) Method(overWrite ...string) string { method := c.Request.Method if len(overWrite) > 0 && overWrite[0] != "" { c.Request.Method = overWrite[0] } return method } func (c *Ctx) Path(overWrite ...string) string { path := c.Request.URL.Path if len(overWrite) > 0 && overWrite[0] != "" { c.Request.URL.Path = overWrite[0] } return path } func (c *Ctx) Cookies(key string, defaultValue ...string) string { dv := "" if len(defaultValue) > 0 { dv = defaultValue[0] } cookie, err := c.Request.Cookie(key) if err != nil || cookie.Value == "" { return dv } return cookie.Value } func (c *Ctx) Context() context.Context { return c.Request.Context() } func (c *Ctx) Next() error { c.index++ if c.index >= len(c.handlers) { return nil } var ( err error handler = c.handlers[c.index] ) if handler != nil { if err = handler(c); err != nil { return err } } c.index++ return nil } /* =============================================================== || Handle Ctx Request Part =============================================================== */ func (c *Ctx) verify() error { // 验证 body size if c.app.config.BodyLimit != -1 && c.Request.ContentLength > c.app.config.BodyLimit { return NewNFError(413, "Content Too Large") } return nil } func (c *Ctx) Param(key string) string { return c.params.ByName(key) } func (c *Ctx) SetParam(key, value string) { c.lock.Lock() defer c.lock.Unlock() params := append(*c.params, Param{Key: key, Value: value}) c.params = ¶ms } func (c *Ctx) Form(key string) string { return c.Request.FormValue(key) } // FormValue fiber ctx function func (c *Ctx) FormValue(key string) string { return c.Request.FormValue(key) } func (c *Ctx) FormFile(key string) (*multipart.FileHeader, error) { _, fh, err := c.Request.FormFile(key) return fh, err } func (c *Ctx) MultipartForm() (*multipart.Form, error) { if err := c.Request.ParseMultipartForm(c.app.config.BodyLimit); err != nil { return nil, err } return c.Request.MultipartForm, nil } func (c *Ctx) Query(key string) string { return c.Request.URL.Query().Get(key) } func (c *Ctx) Get(key string, defaultValue ...string) string { value := c.Request.Header.Get(key) if value == "" && len(defaultValue) > 0 { return defaultValue[0] } return value } func (c *Ctx) IP(useProxyHeader ...bool) string { ip, _, err := net.SplitHostPort(strings.TrimSpace(c.Request.RemoteAddr)) if err != nil { return "" } if len(useProxyHeader) > 0 && useProxyHeader[0] { for _, h := range forwardHeaders { for _, rip := range strings.Split(c.Request.Header.Get(h), ",") { realIP := net.ParseIP(strings.Replace(rip, " ", "", -1)) if check := net.ParseIP(realIP.String()); check != nil { ip = realIP.String() break } } } } return ip } func (c *Ctx) BodyParser(out interface{}) error { var ( err error ctype = strings.ToLower(c.Request.Header.Get("Content-Type")) ) ctype = parseVendorSpecificContentType(ctype) ctypeEnd := strings.IndexByte(ctype, ';') if ctypeEnd != -1 { ctype = ctype[:ctypeEnd] } if strings.HasSuffix(ctype, "json") { bs, err := io.ReadAll(c.Request.Body) if err != nil { return err } _ = c.Request.Body.Close() c.Request.Body = io.NopCloser(bytes.NewReader(bs)) return json.Unmarshal(bs, out) } if strings.HasPrefix(ctype, MIMEApplicationForm) { if err = c.Request.ParseForm(); err != nil { return NewNFError(400, err.Error()) } return parseToStruct("form", out, c.Request.Form) } if strings.HasPrefix(ctype, MIMEMultipartForm) { if err = c.Request.ParseMultipartForm(c.app.config.BodyLimit); err != nil { return NewNFError(400, err.Error()) } return parseToStruct("form", out, c.Request.PostForm) } return NewNFError(422, "Unprocessable Content") } func (c *Ctx) QueryParser(out interface{}) error { return parseToStruct("query", out, c.Request.URL.Query()) } /* =============================================================== || Handle Ctx Response Part =============================================================== */ func (c *Ctx) Status(code int) *Ctx { c.lock.Lock() defer c.lock.Unlock() c.Writer.WriteHeader(code) c.StatusCode = c.writermem.status return c } // Set set response header func (c *Ctx) Set(key string, value string) { c.Writer.Header().Set(key, value) } // AddHeader add response header func (c *Ctx) AddHeader(key string, value string) { c.Writer.Header().Add(key, value) } // SetHeader set response header func (c *Ctx) SetHeader(key string, value string) { c.Writer.Header().Set(key, value) } func (c *Ctx) SendStatus(code int) error { c.Status(code) c.Writer.WriteHeaderNow() return nil } func (c *Ctx) SendString(data string) error { c.SetHeader("Content-Type", "text/plain") _, err := c.Write([]byte(data)) return err } func (c *Ctx) Writef(format string, values ...interface{}) (int, error) { c.SetHeader("Content-Type", "text/plain") return c.Write([]byte(fmt.Sprintf(format, values...))) } func (c *Ctx) JSON(data interface{}) error { c.SetHeader("Content-Type", MIMEApplicationJSON) encoder := json.NewEncoder(c.Writer) if err := encoder.Encode(data); err != nil { return err } return nil } func (c *Ctx) SSEvent(event string, data interface{}) error { c.Set("Content-Type", "text/event-stream") c.Set("Cache-Control", "no-cache") c.Set("Transfer-Encoding", "chunked") return sse.Encode(c.Writer, sse.Event{Event: event, Data: data}) } func (c *Ctx) Flush() error { if f, ok := c.Writer.(http.Flusher); ok { f.Flush() return nil } return errors.New("http.Flusher is not implemented") } func (c *Ctx) HTML(html string) error { c.SetHeader("Content-Type", "text/html") _, err := c.Write([]byte(html)) return err } func (c *Ctx) RenderHTML(name, html string, obj any) error { c.SetHeader("Content-Type", "text/html") t, err := template.New(name).Parse(html) if err != nil { return err } return t.Execute(c.Writer, obj) } func (c *Ctx) Redirect(url string, code int) error { http.Redirect(c.Writer, c.Request, url, code) return nil } func (c *Ctx) Write(data []byte) (int, error) { return c.Writer.Write(data) }