diff --git a/middlewares/dump/req.go b/middlewares/dump/req.go new file mode 100644 index 0000000..f4488ff --- /dev/null +++ b/middlewares/dump/req.go @@ -0,0 +1,147 @@ +package dump + +import ( + "bytes" + "context" + "gitea.loveuer.com/yizhisec/packages/logger" + "gitea.loveuer.com/yizhisec/packages/tool" + "github.com/gin-gonic/gin" + "io" + "net/http" + "os" + "strconv" + "strings" + "time" +) + +type RequestHandler func(r *http.Request, body []byte) string + +func Request(ctx context.Context, handler RequestHandler, writers ...io.Writer) gin.HandlerFunc { + + var ( + out io.Writer = os.Stdout + ch = make(chan string, 128) + builder = strings.Builder{} + buf = make([]string, 0, 16) + ) + + if len(writers) > 0 && writers[0] != nil { + out = writers[0] + } + + do := func() { + for _, item := range buf { + builder.WriteString(item) + builder.WriteRune('\n') + } + + _, _ = out.Write(tool.StringToBytes(builder.String())) + + builder.Reset() + buf = buf[:0] + } + + go func() { + + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + if len(buf) > 0 { + do() + } + return + case <-ticker.C: + if len(buf) > 0 { + do() + } + case msg, _ := <-ch: + buf = append(buf, msg) + if len(buf) >= 10 { + do() + } + } + } + }() + + return func(c *gin.Context) { + var ( + err error + contentType = c.GetHeader("Content-Type") + contentLength = c.GetHeader("Content-Length") + cl int + ) + + if contentLength == "" && (c.Request.Method == "GET" || c.Request.Method == "HEAD") { + goto DUMP + } + + if cl, err = strconv.Atoi(contentLength); err != nil { + logger.WarnCtx(c.Request.Context(), "Request: convert Content-Length failed, err = %s", err.Error()) + c.Next() + return + } + + if cl > 0 && !strings.Contains(contentType, "application/json") { + c.Next() + return + } + + DUMP: + + bodyBytes, err := io.ReadAll(c.Request.Body) + if err != nil { + logger.WarnCtx(c.Request.Context(), "读取请求体错误: %v", err) + c.Next() + return + } + + c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + + dumped := handler(c.Request, bodyBytes) + + ch <- dumped + + c.Next() + } +} + +func RequestHandlerCurl(r *http.Request, body []byte) string { + var builder strings.Builder + + // 添加 curl 基础命令和方法 + builder.WriteString("curl -X " + r.Method) + + // 添加请求 URL + url := getFullURL(r) + builder.WriteString(" '" + url + "'") + + // 添加请求头 + for key, values := range r.Header { + if strings.EqualFold(key, "Host") { + continue // 跳过 Host 头 + } + for _, value := range values { + builder.WriteString(" -H '" + key + ": " + value + "'") + } + } + + // 添加 JSON 数据 + if len(body) > 0 { + // 转义单引号防止命令中断 + escapedBody := strings.ReplaceAll(string(body), "'", `'\''`) + builder.WriteString(" -d '" + escapedBody + "'") + } + + return builder.String() +} + +func getFullURL(r *http.Request) string { + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + return scheme + "://" + r.Host + r.URL.RequestURI() +} diff --git a/middlewares/dump/req_test.go b/middlewares/dump/req_test.go new file mode 100644 index 0000000..1f1a3ac --- /dev/null +++ b/middlewares/dump/req_test.go @@ -0,0 +1,74 @@ +package dump + +import ( + "bytes" + "encoding/json" + "gitea.loveuer.com/yizhisec/packages/logger" + "gitea.loveuer.com/yizhisec/packages/tool" + "github.com/gin-gonic/gin" + "net/http" + "testing" + "time" +) + +func TestRequest(t *testing.T) { + ready := make(chan struct{}) + + go func() { + app := gin.Default() + app.Use(Request(t.Context(), RequestHandlerCurl)) + + app.GET("/hello", func(c *gin.Context) { + c.JSON(200, gin.H{"name": c.Query("name")}) + }) + + app.POST("/hello", func(c *gin.Context) { + type Req struct { + Id int `json:"id"` + Name string `json:"name"` + } + + var ( + err error + req = new(Req) + ) + + if err = c.BindJSON(req); err != nil { + c.JSON(200, gin.H{"err": err}) + } + + c.JSON(200, gin.H{"id": req.Id, "name": req.Name}) + }) + + logger.Fatal(app.Run(":18080").Error()) + }() + + go func() { + time.Sleep(1 * time.Second) + for _ = range 10 { + _, err := http.Get("http://localhost:18080/hello?name=" + tool.RandomName()) + if err != nil { + t.Error(err.Error()) + } + } + + for _ = range 5 { + bs, _ := json.Marshal(map[string]interface{}{"id": tool.RandomInt(30), "name": tool.RandomName()}) + req, err := http.NewRequest(http.MethodPost, "http://localhost:18080/hello", bytes.NewReader(bs)) + req.Header.Set("Content-Type", "application/json") + if err != nil { + t.Fatal(err.Error()) + } + + _, err = http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err.Error()) + } + } + + ready <- struct{}{} + }() + + <-ready + time.Sleep(1 * time.Second) +}