commit fb65e32239780ea470e7de742549aba121dfd544 Author: loveuer Date: Mon Dec 30 15:09:02 2024 +0800 wip: 初版 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..340a9d3 --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +xtest +.vscode +.idea +*.db +*.sqlite +.DS_Store \ No newline at end of file diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..a0d36a8 --- /dev/null +++ b/go.mod @@ -0,0 +1,48 @@ +module github.com/loveuer/upp + +go 1.23.4 + +require ( + gitea.com/loveuer/gredis v1.0.0 + github.com/elastic/go-elasticsearch/v7 v7.17.10 + github.com/fatih/color v1.18.0 + github.com/glebarez/sqlite v1.11.0 + github.com/go-redis/redis/v8 v8.11.5 + github.com/google/uuid v1.6.0 + github.com/hashicorp/golang-lru/v2 v2.0.7 + github.com/jedib0t/go-pretty/v6 v6.6.5 + github.com/loveuer/nf v0.3.1 + github.com/samber/lo v1.47.0 + github.com/spf13/cast v1.7.1 + golang.org/x/crypto v0.25.0 + gorm.io/driver/mysql v1.5.7 + gorm.io/driver/postgres v1.5.11 + gorm.io/gorm v1.25.12 +) + +require ( + filippo.io/edwards25519 v1.1.0 // indirect + github.com/cespare/xxhash/v2 v2.1.2 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/glebarez/go-sqlite v1.21.2 // indirect + github.com/go-sql-driver/mysql v1.8.1 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/jackc/pgx/v5 v5.5.5 // indirect + github.com/jackc/puddle/v2 v2.2.1 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-runewidth v0.0.15 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + github.com/rivo/uniseg v0.2.0 // indirect + golang.org/x/sync v0.10.0 // indirect + golang.org/x/sys v0.25.0 // indirect + golang.org/x/text v0.21.0 // indirect + modernc.org/libc v1.22.5 // indirect + modernc.org/mathutil v1.5.0 // indirect + modernc.org/memory v1.5.0 // indirect + modernc.org/sqlite v1.23.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..131bc1a --- /dev/null +++ b/go.sum @@ -0,0 +1,124 @@ +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +gitea.com/loveuer/gredis v1.0.0 h1:fbRS8YZObcp1KV1KGj8pDpIj1WrI0W8pwU9Ny/2fJys= +gitea.com/loveuer/gredis v1.0.0/go.mod h1:TQlubgDiyNTRXqASd/XIUrqPBLj9NZRR2DmV3V2ZyMY= +github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= +github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/elastic/go-elasticsearch/v7 v7.17.10 h1:TCQ8i4PmIJuBunvBS6bwT2ybzVFxxUhhltAs3Gyu1yo= +github.com/elastic/go-elasticsearch/v7 v7.17.10/go.mod h1:OJ4wdbtDNk5g503kvlHLyErCgQwwzmDtaFC4XyOxXA4= +github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= +github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= +github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= +github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9gAXWo= +github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k= +github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw= +github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ= +github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= +github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= +github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= +github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ= +github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw= +github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= +github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= +github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jedib0t/go-pretty/v6 v6.6.5 h1:9PgMJOVBedpgYLI56jQRJYqngxYAAzfEUua+3NgSqAo= +github.com/jedib0t/go-pretty/v6 v6.6.5/go.mod h1:Uq/HrbhuFty5WSVNfjpQQe47x16RwVGXIveNGEyGtHs= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/loveuer/nf v0.3.1 h1:FTmyAC9LQF06BVGeGwrwaYfbC6MIQMqr+GoZUQQPvXU= +github.com/loveuer/nf v0.3.1/go.mod h1:aApO+2cSP0ULczkfS4OVw8zfWM3rY8gQrzc5PnVV7lY= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= +github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= +github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= +github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= +github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= +github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= +github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= +github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/samber/lo v1.47.0 h1:z7RynLwP5nbyRscyvcD043DWYoOcYRv3mV8lBeqOCLc= +github.com/samber/lo v1.47.0/go.mod h1:RmDH9Ct32Qy3gduHQuKJ3gW1fMHAnE/fAzQuf6He5cU= +github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= +github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= +golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= +golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= +golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= +golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/mysql v1.5.7 h1:MndhOPYOfEp2rHKgkZIhJ16eVUIRf2HmzgoPmh7FCWo= +gorm.io/driver/mysql v1.5.7/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM= +gorm.io/driver/postgres v1.5.11 h1:ubBVAfbKEUld/twyKZ0IYn9rSQh448EdelLYk9Mv314= +gorm.io/driver/postgres v1.5.11/go.mod h1:DX3GReXH+3FPWGrrgffdvCk3DQ1dwDPdmbenSkweRGI= +gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= +gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8= +gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ= +modernc.org/libc v1.22.5 h1:91BNch/e5B0uPbJFgqbxXuOnxBQjlS//icfQEGmvyjE= +modernc.org/libc v1.22.5/go.mod h1:jj+Z7dTNX8fBScMVNRAYZ/jF91K8fdT2hYMThc3YjBY= +modernc.org/mathutil v1.5.0 h1:rV0Ko/6SfM+8G+yKiyI830l3Wuz1zRutdslNoQ0kfiQ= +modernc.org/mathutil v1.5.0/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E= +modernc.org/memory v1.5.0 h1:N+/8c5rE6EqugZwHii4IFsaJ7MUhoWX07J5tC/iI5Ds= +modernc.org/memory v1.5.0/go.mod h1:PkUhL0Mugw21sHPeskwZW4D6VscE/GQJOnIpCnW6pSU= +modernc.org/sqlite v1.23.1 h1:nrSBg4aRQQwq59JpvGEQ15tNxoO5pX/kUjcRNwSAGQM= +modernc.org/sqlite v1.23.1/go.mod h1:OrDj17Mggn6MhE+iPbBNf7RGKODDE9NFT0f3EwDzJqk= diff --git a/internal/bytesconv/bytesconv_1.19.go b/internal/bytesconv/bytesconv_1.19.go new file mode 100644 index 0000000..669c9c9 --- /dev/null +++ b/internal/bytesconv/bytesconv_1.19.go @@ -0,0 +1,26 @@ +// Copyright 2020 Gin Core Team. All rights reserved. +// Use of this source code is governed by a MIT style +// license that can be found in the LICENSE file. + +//go:build !go1.20 + +package bytesconv + +import ( + "unsafe" +) + +// StringToBytes converts string to byte slice without a memory allocation. +func StringToBytes(s string) []byte { + return *(*[]byte)(unsafe.Pointer( + &struct { + string + Cap int + }{s, len(s)}, + )) +} + +// BytesToString converts byte slice to string without a memory allocation. +func BytesToString(b []byte) string { + return *(*string)(unsafe.Pointer(&b)) +} diff --git a/internal/bytesconv/bytesconv_1.20.go b/internal/bytesconv/bytesconv_1.20.go new file mode 100644 index 0000000..5b6040a --- /dev/null +++ b/internal/bytesconv/bytesconv_1.20.go @@ -0,0 +1,23 @@ +// Copyright 2023 Gin Core Team. All rights reserved. +// Use of this source code is governed by a MIT style +// license that can be found in the LICENSE file. + +//go:build go1.20 + +package bytesconv + +import ( + "unsafe" +) + +// StringToBytes converts string to byte slice without a memory allocation. +// For more details, see https://github.com/golang/go/issues/53003#issuecomment-1140276077. +func StringToBytes(s string) []byte { + return unsafe.Slice(unsafe.StringData(s), len(s)) +} + +// BytesToString converts byte slice to string without a memory allocation. +// For more details, see https://github.com/golang/go/issues/53003#issuecomment-1140276077. +func BytesToString(b []byte) string { + return unsafe.String(unsafe.SliceData(b), len(b)) +} diff --git a/internal/bytesconv/bytesconv_test.go b/internal/bytesconv/bytesconv_test.go new file mode 100644 index 0000000..eeaad5e --- /dev/null +++ b/internal/bytesconv/bytesconv_test.go @@ -0,0 +1,99 @@ +// Copyright 2020 Gin Core Team. All rights reserved. +// Use of this source code is governed by a MIT style +// license that can be found in the LICENSE file. + +package bytesconv + +import ( + "bytes" + "math/rand" + "strings" + "testing" + "time" +) + +var testString = "Albert Einstein: Logic will get you from A to B. Imagination will take you everywhere." +var testBytes = []byte(testString) + +func rawBytesToStr(b []byte) string { + return string(b) +} + +func rawStrToBytes(s string) []byte { + return []byte(s) +} + +// go test -v + +func TestBytesToString(t *testing.T) { + data := make([]byte, 1024) + for i := 0; i < 100; i++ { + rand.Read(data) + if rawBytesToStr(data) != BytesToString(data) { + t.Fatal("don't match") + } + } +} + +const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" +const ( + letterIdxBits = 6 // 6 bits to represent a letter index + letterIdxMask = 1<= 0; { + if remain == 0 { + cache, remain = src.Int63(), letterIdxMax + } + if idx := int(cache & letterIdxMask); idx < len(letterBytes) { + sb.WriteByte(letterBytes[idx]) + i-- + } + cache >>= letterIdxBits + remain-- + } + + return sb.String() +} + +func TestStringToBytes(t *testing.T) { + for i := 0; i < 100; i++ { + s := RandStringBytesMaskImprSrcSB(64) + if !bytes.Equal(rawStrToBytes(s), StringToBytes(s)) { + t.Fatal("don't match") + } + } +} + +// go test -v -run=none -bench=^BenchmarkBytesConv -benchmem=true + +func BenchmarkBytesConvBytesToStrRaw(b *testing.B) { + for i := 0; i < b.N; i++ { + rawBytesToStr(testBytes) + } +} + +func BenchmarkBytesConvBytesToStr(b *testing.B) { + for i := 0; i < b.N; i++ { + BytesToString(testBytes) + } +} + +func BenchmarkBytesConvStrToBytesRaw(b *testing.B) { + for i := 0; i < b.N; i++ { + rawStrToBytes(testString) + } +} + +func BenchmarkBytesConvStrToBytes(b *testing.B) { + for i := 0; i < b.N; i++ { + StringToBytes(testString) + } +} diff --git a/internal/schema/LICENSE b/internal/schema/LICENSE new file mode 100644 index 0000000..0e5fb87 --- /dev/null +++ b/internal/schema/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2012 Rodrigo Moraes. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/internal/schema/cache.go b/internal/schema/cache.go new file mode 100644 index 0000000..bf21697 --- /dev/null +++ b/internal/schema/cache.go @@ -0,0 +1,305 @@ +// Copyright 2012 The Gorilla Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package schema + +import ( + "errors" + "reflect" + "strconv" + "strings" + "sync" +) + +var errInvalidPath = errors.New("schema: invalid path") + +// newCache returns a new cache. +func newCache() *cache { + c := cache{ + m: make(map[reflect.Type]*structInfo), + regconv: make(map[reflect.Type]Converter), + tag: "schema", + } + return &c +} + +// cache caches meta-data about a struct. +type cache struct { + l sync.RWMutex + m map[reflect.Type]*structInfo + regconv map[reflect.Type]Converter + tag string +} + +// registerConverter registers a converter function for a custom type. +func (c *cache) registerConverter(value interface{}, converterFunc Converter) { + c.regconv[reflect.TypeOf(value)] = converterFunc +} + +// parsePath parses a path in dotted notation verifying that it is a valid +// path to a struct field. +// +// It returns "path parts" which contain indices to fields to be used by +// reflect.Value.FieldByString(). Multiple parts are required for slices of +// structs. +func (c *cache) parsePath(p string, t reflect.Type) ([]pathPart, error) { + var struc *structInfo + var field *fieldInfo + var index64 int64 + var err error + parts := make([]pathPart, 0) + path := make([]string, 0) + keys := strings.Split(p, ".") + for i := 0; i < len(keys); i++ { + if t.Kind() != reflect.Struct { + return nil, errInvalidPath + } + if struc = c.get(t); struc == nil { + return nil, errInvalidPath + } + if field = struc.get(keys[i]); field == nil { + return nil, errInvalidPath + } + // Valid field. Append index. + path = append(path, field.name) + if field.isSliceOfStructs && (!field.unmarshalerInfo.IsValid || (field.unmarshalerInfo.IsValid && field.unmarshalerInfo.IsSliceElement)) { + // Parse a special case: slices of structs. + // i+1 must be the slice index. + // + // Now that struct can implements TextUnmarshaler interface, + // we don't need to force the struct's fields to appear in the path. + // So checking i+2 is not necessary anymore. + i++ + if i+1 > len(keys) { + return nil, errInvalidPath + } + if index64, err = strconv.ParseInt(keys[i], 10, 0); err != nil { + return nil, errInvalidPath + } + parts = append(parts, pathPart{ + path: path, + field: field, + index: int(index64), + }) + path = make([]string, 0) + + // Get the next struct type, dropping ptrs. + if field.typ.Kind() == reflect.Ptr { + t = field.typ.Elem() + } else { + t = field.typ + } + if t.Kind() == reflect.Slice { + t = t.Elem() + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + } + } else if field.typ.Kind() == reflect.Ptr { + t = field.typ.Elem() + } else { + t = field.typ + } + } + // Add the remaining. + parts = append(parts, pathPart{ + path: path, + field: field, + index: -1, + }) + return parts, nil +} + +// get returns a cached structInfo, creating it if necessary. +func (c *cache) get(t reflect.Type) *structInfo { + c.l.RLock() + info := c.m[t] + c.l.RUnlock() + if info == nil { + info = c.create(t, "") + c.l.Lock() + c.m[t] = info + c.l.Unlock() + } + return info +} + +// create creates a structInfo with meta-data about a struct. +func (c *cache) create(t reflect.Type, parentAlias string) *structInfo { + info := &structInfo{} + var anonymousInfos []*structInfo + for i := 0; i < t.NumField(); i++ { + if f := c.createField(t.Field(i), parentAlias); f != nil { + info.fields = append(info.fields, f) + if ft := indirectType(f.typ); ft.Kind() == reflect.Struct && f.isAnonymous { + anonymousInfos = append(anonymousInfos, c.create(ft, f.canonicalAlias)) + } + } + } + for i, a := range anonymousInfos { + others := []*structInfo{info} + others = append(others, anonymousInfos[:i]...) + others = append(others, anonymousInfos[i+1:]...) + for _, f := range a.fields { + if !containsAlias(others, f.alias) { + info.fields = append(info.fields, f) + } + } + } + return info +} + +// createField creates a fieldInfo for the given field. +func (c *cache) createField(field reflect.StructField, parentAlias string) *fieldInfo { + alias, options := fieldAlias(field, c.tag) + if alias == "-" { + // Ignore this field. + return nil + } + canonicalAlias := alias + if parentAlias != "" { + canonicalAlias = parentAlias + "." + alias + } + // Check if the type is supported and don't cache it if not. + // First let's get the basic type. + isSlice, isStruct := false, false + ft := field.Type + m := isTextUnmarshaler(reflect.Zero(ft)) + if ft.Kind() == reflect.Ptr { + ft = ft.Elem() + } + if isSlice = ft.Kind() == reflect.Slice; isSlice { + ft = ft.Elem() + if ft.Kind() == reflect.Ptr { + ft = ft.Elem() + } + } + if ft.Kind() == reflect.Array { + ft = ft.Elem() + if ft.Kind() == reflect.Ptr { + ft = ft.Elem() + } + } + if isStruct = ft.Kind() == reflect.Struct; !isStruct { + if c.converter(ft) == nil && builtinConverters[ft.Kind()] == nil { + // Type is not supported. + return nil + } + } + + return &fieldInfo{ + typ: field.Type, + name: field.Name, + alias: alias, + canonicalAlias: canonicalAlias, + unmarshalerInfo: m, + isSliceOfStructs: isSlice && isStruct, + isAnonymous: field.Anonymous, + isRequired: options.Contains("required"), + } +} + +// converter returns the converter for a type. +func (c *cache) converter(t reflect.Type) Converter { + return c.regconv[t] +} + +// ---------------------------------------------------------------------------- + +type structInfo struct { + fields []*fieldInfo +} + +func (i *structInfo) get(alias string) *fieldInfo { + for _, field := range i.fields { + if strings.EqualFold(field.alias, alias) { + return field + } + } + return nil +} + +func containsAlias(infos []*structInfo, alias string) bool { + for _, info := range infos { + if info.get(alias) != nil { + return true + } + } + return false +} + +type fieldInfo struct { + typ reflect.Type + // name is the field name in the struct. + name string + alias string + // canonicalAlias is almost the same as the alias, but is prefixed with + // an embedded struct field alias in dotted notation if this field is + // promoted from the struct. + // For instance, if the alias is "N" and this field is an embedded field + // in a struct "X", canonicalAlias will be "X.N". + canonicalAlias string + // unmarshalerInfo contains information regarding the + // encoding.TextUnmarshaler implementation of the field type. + unmarshalerInfo unmarshaler + // isSliceOfStructs indicates if the field type is a slice of structs. + isSliceOfStructs bool + // isAnonymous indicates whether the field is embedded in the struct. + isAnonymous bool + isRequired bool +} + +func (f *fieldInfo) paths(prefix string) []string { + if f.alias == f.canonicalAlias { + return []string{prefix + f.alias} + } + return []string{prefix + f.alias, prefix + f.canonicalAlias} +} + +type pathPart struct { + field *fieldInfo + path []string // path to the field: walks structs using field names. + index int // struct index in slices of structs. +} + +// ---------------------------------------------------------------------------- + +func indirectType(typ reflect.Type) reflect.Type { + if typ.Kind() == reflect.Ptr { + return typ.Elem() + } + return typ +} + +// fieldAlias parses a field tag to get a field alias. +func fieldAlias(field reflect.StructField, tagName string) (alias string, options tagOptions) { + if tag := field.Tag.Get(tagName); tag != "" { + alias, options = parseTag(tag) + } + if alias == "" { + alias = field.Name + } + return alias, options +} + +// tagOptions is the string following a comma in a struct field's tag, or +// the empty string. It does not include the leading comma. +type tagOptions []string + +// parseTag splits a struct field's url tag into its name and comma-separated +// options. +func parseTag(tag string) (string, tagOptions) { + s := strings.Split(tag, ",") + return s[0], s[1:] +} + +// Contains checks whether the tagOptions contains the specified option. +func (o tagOptions) Contains(option string) bool { + for _, s := range o { + if s == option { + return true + } + } + return false +} diff --git a/internal/schema/converter.go b/internal/schema/converter.go new file mode 100644 index 0000000..4f2116a --- /dev/null +++ b/internal/schema/converter.go @@ -0,0 +1,145 @@ +// Copyright 2012 The Gorilla Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package schema + +import ( + "reflect" + "strconv" +) + +type Converter func(string) reflect.Value + +var ( + invalidValue = reflect.Value{} + boolType = reflect.Bool + float32Type = reflect.Float32 + float64Type = reflect.Float64 + intType = reflect.Int + int8Type = reflect.Int8 + int16Type = reflect.Int16 + int32Type = reflect.Int32 + int64Type = reflect.Int64 + stringType = reflect.String + uintType = reflect.Uint + uint8Type = reflect.Uint8 + uint16Type = reflect.Uint16 + uint32Type = reflect.Uint32 + uint64Type = reflect.Uint64 +) + +// Default converters for basic types. +var builtinConverters = map[reflect.Kind]Converter{ + boolType: convertBool, + float32Type: convertFloat32, + float64Type: convertFloat64, + intType: convertInt, + int8Type: convertInt8, + int16Type: convertInt16, + int32Type: convertInt32, + int64Type: convertInt64, + stringType: convertString, + uintType: convertUint, + uint8Type: convertUint8, + uint16Type: convertUint16, + uint32Type: convertUint32, + uint64Type: convertUint64, +} + +func convertBool(value string) reflect.Value { + if value == "on" { + return reflect.ValueOf(true) + } else if v, err := strconv.ParseBool(value); err == nil { + return reflect.ValueOf(v) + } + return invalidValue +} + +func convertFloat32(value string) reflect.Value { + if v, err := strconv.ParseFloat(value, 32); err == nil { + return reflect.ValueOf(float32(v)) + } + return invalidValue +} + +func convertFloat64(value string) reflect.Value { + if v, err := strconv.ParseFloat(value, 64); err == nil { + return reflect.ValueOf(v) + } + return invalidValue +} + +func convertInt(value string) reflect.Value { + if v, err := strconv.ParseInt(value, 10, 0); err == nil { + return reflect.ValueOf(int(v)) + } + return invalidValue +} + +func convertInt8(value string) reflect.Value { + if v, err := strconv.ParseInt(value, 10, 8); err == nil { + return reflect.ValueOf(int8(v)) + } + return invalidValue +} + +func convertInt16(value string) reflect.Value { + if v, err := strconv.ParseInt(value, 10, 16); err == nil { + return reflect.ValueOf(int16(v)) + } + return invalidValue +} + +func convertInt32(value string) reflect.Value { + if v, err := strconv.ParseInt(value, 10, 32); err == nil { + return reflect.ValueOf(int32(v)) + } + return invalidValue +} + +func convertInt64(value string) reflect.Value { + if v, err := strconv.ParseInt(value, 10, 64); err == nil { + return reflect.ValueOf(v) + } + return invalidValue +} + +func convertString(value string) reflect.Value { + return reflect.ValueOf(value) +} + +func convertUint(value string) reflect.Value { + if v, err := strconv.ParseUint(value, 10, 0); err == nil { + return reflect.ValueOf(uint(v)) + } + return invalidValue +} + +func convertUint8(value string) reflect.Value { + if v, err := strconv.ParseUint(value, 10, 8); err == nil { + return reflect.ValueOf(uint8(v)) + } + return invalidValue +} + +func convertUint16(value string) reflect.Value { + if v, err := strconv.ParseUint(value, 10, 16); err == nil { + return reflect.ValueOf(uint16(v)) + } + return invalidValue +} + +func convertUint32(value string) reflect.Value { + if v, err := strconv.ParseUint(value, 10, 32); err == nil { + return reflect.ValueOf(uint32(v)) + } + return invalidValue +} + +func convertUint64(value string) reflect.Value { + if v, err := strconv.ParseUint(value, 10, 64); err == nil { + return reflect.ValueOf(v) + } + return invalidValue +} diff --git a/internal/schema/decoder.go b/internal/schema/decoder.go new file mode 100644 index 0000000..b63c45e --- /dev/null +++ b/internal/schema/decoder.go @@ -0,0 +1,534 @@ +// Copyright 2012 The Gorilla Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package schema + +import ( + "encoding" + "errors" + "fmt" + "reflect" + "strings" +) + +// NewDecoder returns a new Decoder. +func NewDecoder() *Decoder { + return &Decoder{cache: newCache()} +} + +// Decoder decodes values from a map[string][]string to a struct. +type Decoder struct { + cache *cache + zeroEmpty bool + ignoreUnknownKeys bool +} + +// SetAliasTag changes the tag used to locate custom field aliases. +// The default tag is "schema". +func (d *Decoder) SetAliasTag(tag string) { + d.cache.tag = tag +} + +// ZeroEmpty controls the behaviour when the decoder encounters empty values +// in a map. +// If z is true and a key in the map has the empty string as a value +// then the corresponding struct field is set to the zero value. +// If z is false then empty strings are ignored. +// +// The default value is false, that is empty values do not change +// the value of the struct field. +func (d *Decoder) ZeroEmpty(z bool) { + d.zeroEmpty = z +} + +// IgnoreUnknownKeys controls the behaviour when the decoder encounters unknown +// keys in the map. +// If i is true and an unknown field is encountered, it is ignored. This is +// similar to how unknown keys are handled by encoding/json. +// If i is false then Decode will return an error. Note that any valid keys +// will still be decoded in to the target struct. +// +// To preserve backwards compatibility, the default value is false. +func (d *Decoder) IgnoreUnknownKeys(i bool) { + d.ignoreUnknownKeys = i +} + +// RegisterConverter registers a converter function for a custom type. +func (d *Decoder) RegisterConverter(value interface{}, converterFunc Converter) { + d.cache.registerConverter(value, converterFunc) +} + +// Decode decodes a map[string][]string to a struct. +// +// The first parameter must be a pointer to a struct. +// +// The second parameter is a map, typically url.Values from an HTTP request. +// Keys are "paths" in dotted notation to the struct fields and nested structs. +// +// See the package documentation for a full explanation of the mechanics. +func (d *Decoder) Decode(dst interface{}, src map[string][]string) error { + v := reflect.ValueOf(dst) + if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct { + return errors.New("schema: interface must be a pointer to struct") + } + v = v.Elem() + t := v.Type() + multiError := MultiError{} + for path, values := range src { + if parts, err := d.cache.parsePath(path, t); err == nil { + if err = d.decode(v, path, parts, values); err != nil { + multiError[path] = err + } + } else if !d.ignoreUnknownKeys { + multiError[path] = UnknownKeyError{Key: path} + } + } + multiError.merge(d.checkRequired(t, src)) + if len(multiError) > 0 { + return multiError + } + return nil +} + +// checkRequired checks whether required fields are empty +// +// check type t recursively if t has struct fields. +// +// src is the source map for decoding, we use it here to see if those required fields are included in src +func (d *Decoder) checkRequired(t reflect.Type, src map[string][]string) MultiError { + m, errs := d.findRequiredFields(t, "", "") + for key, fields := range m { + if isEmptyFields(fields, src) { + errs[key] = EmptyFieldError{Key: key} + } + } + return errs +} + +// findRequiredFields recursively searches the struct type t for required fields. +// +// canonicalPrefix and searchPrefix are used to resolve full paths in dotted notation +// for nested struct fields. canonicalPrefix is a complete path which never omits +// any embedded struct fields. searchPrefix is a user-friendly path which may omit +// some embedded struct fields to point promoted fields. +func (d *Decoder) findRequiredFields(t reflect.Type, canonicalPrefix, searchPrefix string) (map[string][]fieldWithPrefix, MultiError) { + struc := d.cache.get(t) + if struc == nil { + // unexpect, cache.get never return nil + return nil, MultiError{canonicalPrefix + "*": errors.New("cache fail")} + } + + m := map[string][]fieldWithPrefix{} + errs := MultiError{} + for _, f := range struc.fields { + if f.typ.Kind() == reflect.Struct { + fcprefix := canonicalPrefix + f.canonicalAlias + "." + for _, fspath := range f.paths(searchPrefix) { + fm, ferrs := d.findRequiredFields(f.typ, fcprefix, fspath+".") + for key, fields := range fm { + m[key] = append(m[key], fields...) + } + errs.merge(ferrs) + } + } + if f.isRequired { + key := canonicalPrefix + f.canonicalAlias + m[key] = append(m[key], fieldWithPrefix{ + fieldInfo: f, + prefix: searchPrefix, + }) + } + } + return m, errs +} + +type fieldWithPrefix struct { + *fieldInfo + prefix string +} + +// isEmptyFields returns true if all of specified fields are empty. +func isEmptyFields(fields []fieldWithPrefix, src map[string][]string) bool { + for _, f := range fields { + for _, path := range f.paths(f.prefix) { + v, ok := src[path] + if ok && !isEmpty(f.typ, v) { + return false + } + for key := range src { + // issue references: + // https://github.com/gofiber/fiber/issues/1414 + // https://github.com/gorilla/schema/issues/176 + nested := strings.IndexByte(key, '.') != -1 + + // for non required nested structs + c1 := strings.HasSuffix(f.prefix, ".") && key == path + + // for required nested structs + c2 := f.prefix == "" && nested && strings.HasPrefix(key, path) + + // for non nested fields + c3 := f.prefix == "" && !nested && key == path + if !isEmpty(f.typ, src[key]) && (c1 || c2 || c3) { + return false + } + } + } + } + return true +} + +// isEmpty returns true if value is empty for specific type +func isEmpty(t reflect.Type, value []string) bool { + if len(value) == 0 { + return true + } + switch t.Kind() { + case boolType, float32Type, float64Type, intType, int8Type, int32Type, int64Type, stringType, uint8Type, uint16Type, uint32Type, uint64Type: + return len(value[0]) == 0 + } + return false +} + +// decode fills a struct field using a parsed path. +func (d *Decoder) decode(v reflect.Value, path string, parts []pathPart, values []string) error { + // Get the field walking the struct fields by index. + for _, name := range parts[0].path { + if v.Type().Kind() == reflect.Ptr { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + v = v.Elem() + } + + // alloc embedded structs + if v.Type().Kind() == reflect.Struct { + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + if field.Type().Kind() == reflect.Ptr && field.IsNil() && v.Type().Field(i).Anonymous { + field.Set(reflect.New(field.Type().Elem())) + } + } + } + + v = v.FieldByName(name) + } + // Don't even bother for unexported fields. + if !v.CanSet() { + return nil + } + + // Dereference if needed. + t := v.Type() + if t.Kind() == reflect.Ptr { + t = t.Elem() + if v.IsNil() { + v.Set(reflect.New(t)) + } + v = v.Elem() + } + + // Slice of structs. Let's go recursive. + if len(parts) > 1 { + idx := parts[0].index + if v.IsNil() || v.Len() < idx+1 { + value := reflect.MakeSlice(t, idx+1, idx+1) + if v.Len() < idx+1 { + // Resize it. + reflect.Copy(value, v) + } + v.Set(value) + } + return d.decode(v.Index(idx), path, parts[1:], values) + } + + // Get the converter early in case there is one for a slice type. + conv := d.cache.converter(t) + m := isTextUnmarshaler(v) + if conv == nil && t.Kind() == reflect.Slice && m.IsSliceElement { + var items []reflect.Value + elemT := t.Elem() + isPtrElem := elemT.Kind() == reflect.Ptr + if isPtrElem { + elemT = elemT.Elem() + } + + // Try to get a converter for the element type. + conv := d.cache.converter(elemT) + if conv == nil { + conv = builtinConverters[elemT.Kind()] + if conv == nil { + // As we are not dealing with slice of structs here, we don't need to check if the type + // implements TextUnmarshaler interface + return fmt.Errorf("schema: converter not found for %v", elemT) + } + } + + for key, value := range values { + if value == "" { + if d.zeroEmpty { + items = append(items, reflect.Zero(elemT)) + } + } else if m.IsValid { + u := reflect.New(elemT) + if m.IsSliceElementPtr { + u = reflect.New(reflect.PtrTo(elemT).Elem()) + } + if err := u.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(value)); err != nil { + return ConversionError{ + Key: path, + Type: t, + Index: key, + Err: err, + } + } + if m.IsSliceElementPtr { + items = append(items, u.Elem().Addr()) + } else if u.Kind() == reflect.Ptr { + items = append(items, u.Elem()) + } else { + items = append(items, u) + } + } else if item := conv(value); item.IsValid() { + if isPtrElem { + ptr := reflect.New(elemT) + ptr.Elem().Set(item) + item = ptr + } + if item.Type() != elemT && !isPtrElem { + item = item.Convert(elemT) + } + items = append(items, item) + } else { + if strings.Contains(value, ",") { + values := strings.Split(value, ",") + for _, value := range values { + if value == "" { + if d.zeroEmpty { + items = append(items, reflect.Zero(elemT)) + } + } else if item := conv(value); item.IsValid() { + if isPtrElem { + ptr := reflect.New(elemT) + ptr.Elem().Set(item) + item = ptr + } + if item.Type() != elemT && !isPtrElem { + item = item.Convert(elemT) + } + items = append(items, item) + } else { + return ConversionError{ + Key: path, + Type: elemT, + Index: key, + } + } + } + } else { + return ConversionError{ + Key: path, + Type: elemT, + Index: key, + } + } + } + } + value := reflect.Append(reflect.MakeSlice(t, 0, 0), items...) + v.Set(value) + } else { + val := "" + // Use the last value provided if any values were provided + if len(values) > 0 { + val = values[len(values)-1] + } + + if conv != nil { + if value := conv(val); value.IsValid() { + v.Set(value.Convert(t)) + } else { + return ConversionError{ + Key: path, + Type: t, + Index: -1, + } + } + } else if m.IsValid { + if m.IsPtr { + u := reflect.New(v.Type()) + if err := u.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(val)); err != nil { + return ConversionError{ + Key: path, + Type: t, + Index: -1, + Err: err, + } + } + v.Set(reflect.Indirect(u)) + } else { + // If the value implements the encoding.TextUnmarshaler interface + // apply UnmarshalText as the converter + if err := m.Unmarshaler.UnmarshalText([]byte(val)); err != nil { + return ConversionError{ + Key: path, + Type: t, + Index: -1, + Err: err, + } + } + } + } else if val == "" { + if d.zeroEmpty { + v.Set(reflect.Zero(t)) + } + } else if conv := builtinConverters[t.Kind()]; conv != nil { + if value := conv(val); value.IsValid() { + v.Set(value.Convert(t)) + } else { + return ConversionError{ + Key: path, + Type: t, + Index: -1, + } + } + } else { + return fmt.Errorf("schema: converter not found for %v", t) + } + } + return nil +} + +func isTextUnmarshaler(v reflect.Value) unmarshaler { + // Create a new unmarshaller instance + m := unmarshaler{} + if m.Unmarshaler, m.IsValid = v.Interface().(encoding.TextUnmarshaler); m.IsValid { + return m + } + // As the UnmarshalText function should be applied to the pointer of the + // type, we check that type to see if it implements the necessary + // method. + if m.Unmarshaler, m.IsValid = reflect.New(v.Type()).Interface().(encoding.TextUnmarshaler); m.IsValid { + m.IsPtr = true + return m + } + + // if v is []T or *[]T create new T + t := v.Type() + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + if t.Kind() == reflect.Slice { + // Check if the slice implements encoding.TextUnmarshaller + if m.Unmarshaler, m.IsValid = v.Interface().(encoding.TextUnmarshaler); m.IsValid { + return m + } + // If t is a pointer slice, check if its elements implement + // encoding.TextUnmarshaler + m.IsSliceElement = true + if t = t.Elem(); t.Kind() == reflect.Ptr { + t = reflect.PtrTo(t.Elem()) + v = reflect.Zero(t) + m.IsSliceElementPtr = true + m.Unmarshaler, m.IsValid = v.Interface().(encoding.TextUnmarshaler) + return m + } + } + + v = reflect.New(t) + m.Unmarshaler, m.IsValid = v.Interface().(encoding.TextUnmarshaler) + return m +} + +// TextUnmarshaler helpers ---------------------------------------------------- +// unmarshaller contains information about a TextUnmarshaler type +type unmarshaler struct { + Unmarshaler encoding.TextUnmarshaler + // IsValid indicates whether the resolved type indicated by the other + // flags implements the encoding.TextUnmarshaler interface. + IsValid bool + // IsPtr indicates that the resolved type is the pointer of the original + // type. + IsPtr bool + // IsSliceElement indicates that the resolved type is a slice element of + // the original type. + IsSliceElement bool + // IsSliceElementPtr indicates that the resolved type is a pointer to a + // slice element of the original type. + IsSliceElementPtr bool +} + +// Errors --------------------------------------------------------------------- + +// ConversionError stores information about a failed conversion. +type ConversionError struct { + Key string // key from the source map. + Type reflect.Type // expected type of elem + Index int // index for multi-value fields; -1 for single-value fields. + Err error // low-level error (when it exists) +} + +func (e ConversionError) Error() string { + var output string + + if e.Index < 0 { + output = fmt.Sprintf("schema: error converting value for %q", e.Key) + } else { + output = fmt.Sprintf("schema: error converting value for index %d of %q", + e.Index, e.Key) + } + + if e.Err != nil { + output = fmt.Sprintf("%s. Details: %s", output, e.Err) + } + + return output +} + +// UnknownKeyError stores information about an unknown key in the source map. +type UnknownKeyError struct { + Key string // key from the source map. +} + +func (e UnknownKeyError) Error() string { + return fmt.Sprintf("schema: invalid path %q", e.Key) +} + +// EmptyFieldError stores information about an empty required field. +type EmptyFieldError struct { + Key string // required key in the source map. +} + +func (e EmptyFieldError) Error() string { + return fmt.Sprintf("%v is empty", e.Key) +} + +// MultiError stores multiple decoding errors. +// +// Borrowed from the App Engine SDK. +type MultiError map[string]error + +func (e MultiError) Error() string { + s := "" + for _, err := range e { + s = err.Error() + break + } + switch len(e) { + case 0: + return "(0 errors)" + case 1: + return s + case 2: + return s + " (and 1 other error)" + } + return fmt.Sprintf("%s (and %d other errors)", s, len(e)-1) +} + +func (e MultiError) merge(errors MultiError) { + for key, err := range errors { + if e[key] == nil { + e[key] = err + } + } +} diff --git a/internal/schema/doc.go b/internal/schema/doc.go new file mode 100644 index 0000000..fff0fe7 --- /dev/null +++ b/internal/schema/doc.go @@ -0,0 +1,148 @@ +// Copyright 2012 The Gorilla Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* +Package gorilla/schema fills a struct with form values. + +The basic usage is really simple. Given this struct: + + type Person struct { + Name string + Phone string + } + +...we can fill it passing a map to the Decode() function: + + values := map[string][]string{ + "Name": {"John"}, + "Phone": {"999-999-999"}, + } + person := new(Person) + decoder := schema.NewDecoder() + decoder.Decode(person, values) + +This is just a simple example and it doesn't make a lot of sense to create +the map manually. Typically it will come from a http.Request object and +will be of type url.Values, http.Request.Form, or http.Request.MultipartForm: + + func MyHandler(w http.ResponseWriter, r *http.Request) { + err := r.ParseForm() + + if err != nil { + // Handle error + } + + decoder := schema.NewDecoder() + // r.PostForm is a map of our POST form values + err := decoder.Decode(person, r.PostForm) + + if err != nil { + // Handle error + } + + // Do something with person.Name or person.Phone + } + +Note: it is a good idea to set a Decoder instance as a package global, +because it caches meta-data about structs, and an instance can be shared safely: + + var decoder = schema.NewDecoder() + +To define custom names for fields, use a struct tag "schema". To not populate +certain fields, use a dash for the name and it will be ignored: + + type Person struct { + Name string `schema:"name"` // custom name + Phone string `schema:"phone"` // custom name + Admin bool `schema:"-"` // this field is never set + } + +The supported field types in the destination struct are: + + - bool + - float variants (float32, float64) + - int variants (int, int8, int16, int32, int64) + - string + - uint variants (uint, uint8, uint16, uint32, uint64) + - struct + - a pointer to one of the above types + - a slice or a pointer to a slice of one of the above types + +Non-supported types are simply ignored, however custom types can be registered +to be converted. + +To fill nested structs, keys must use a dotted notation as the "path" for the +field. So for example, to fill the struct Person below: + + type Phone struct { + Label string + Number string + } + + type Person struct { + Name string + Phone Phone + } + +...the source map must have the keys "Name", "Phone.Label" and "Phone.Number". +This means that an HTML form to fill a Person struct must look like this: + +
+ + + +
+ +Single values are filled using the first value for a key from the source map. +Slices are filled using all values for a key from the source map. So to fill +a Person with multiple Phone values, like: + + type Person struct { + Name string + Phones []Phone + } + +...an HTML form that accepts three Phone values would look like this: + +
+ + + + + + + +
+ +Notice that only for slices of structs the slice index is required. +This is needed for disambiguation: if the nested struct also had a slice +field, we could not translate multiple values to it if we did not use an +index for the parent struct. + +There's also the possibility to create a custom type that implements the +TextUnmarshaler interface, and in this case there's no need to register +a converter, like: + + type Person struct { + Emails []Email + } + + type Email struct { + *mail.Address + } + + func (e *Email) UnmarshalText(text []byte) (err error) { + e.Address, err = mail.ParseAddress(string(text)) + return + } + +...an HTML form that accepts three Email values would look like this: + +
+ + + +
+*/ +package schema diff --git a/internal/schema/encoder.go b/internal/schema/encoder.go new file mode 100644 index 0000000..c01de00 --- /dev/null +++ b/internal/schema/encoder.go @@ -0,0 +1,202 @@ +package schema + +import ( + "errors" + "fmt" + "reflect" + "strconv" +) + +type encoderFunc func(reflect.Value) string + +// Encoder encodes values from a struct into url.Values. +type Encoder struct { + cache *cache + regenc map[reflect.Type]encoderFunc +} + +// NewEncoder returns a new Encoder with defaults. +func NewEncoder() *Encoder { + return &Encoder{cache: newCache(), regenc: make(map[reflect.Type]encoderFunc)} +} + +// Encode encodes a struct into map[string][]string. +// +// Intended for use with url.Values. +func (e *Encoder) Encode(src interface{}, dst map[string][]string) error { + v := reflect.ValueOf(src) + + return e.encode(v, dst) +} + +// RegisterEncoder registers a converter for encoding a custom type. +func (e *Encoder) RegisterEncoder(value interface{}, encoder func(reflect.Value) string) { + e.regenc[reflect.TypeOf(value)] = encoder +} + +// SetAliasTag changes the tag used to locate custom field aliases. +// The default tag is "schema". +func (e *Encoder) SetAliasTag(tag string) { + e.cache.tag = tag +} + +// isValidStructPointer test if input value is a valid struct pointer. +func isValidStructPointer(v reflect.Value) bool { + return v.Type().Kind() == reflect.Ptr && v.Elem().IsValid() && v.Elem().Type().Kind() == reflect.Struct +} + +func isZero(v reflect.Value) bool { + switch v.Kind() { + case reflect.Func: + case reflect.Map, reflect.Slice: + return v.IsNil() || v.Len() == 0 + case reflect.Array: + z := true + for i := 0; i < v.Len(); i++ { + z = z && isZero(v.Index(i)) + } + return z + case reflect.Struct: + type zero interface { + IsZero() bool + } + if v.Type().Implements(reflect.TypeOf((*zero)(nil)).Elem()) { + iz := v.MethodByName("IsZero").Call([]reflect.Value{})[0] + return iz.Interface().(bool) + } + z := true + for i := 0; i < v.NumField(); i++ { + z = z && isZero(v.Field(i)) + } + return z + } + // Compare other types directly: + z := reflect.Zero(v.Type()) + return v.Interface() == z.Interface() +} + +func (e *Encoder) encode(v reflect.Value, dst map[string][]string) error { + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + if v.Kind() != reflect.Struct { + return errors.New("schema: interface must be a struct") + } + t := v.Type() + + errors := MultiError{} + + for i := 0; i < v.NumField(); i++ { + name, opts := fieldAlias(t.Field(i), e.cache.tag) + if name == "-" { + continue + } + + // Encode struct pointer types if the field is a valid pointer and a struct. + if isValidStructPointer(v.Field(i)) { + _ = e.encode(v.Field(i).Elem(), dst) + continue + } + + encFunc := typeEncoder(v.Field(i).Type(), e.regenc) + + // Encode non-slice types and custom implementations immediately. + if encFunc != nil { + value := encFunc(v.Field(i)) + if opts.Contains("omitempty") && isZero(v.Field(i)) { + continue + } + + dst[name] = append(dst[name], value) + continue + } + + if v.Field(i).Type().Kind() == reflect.Struct { + _ = e.encode(v.Field(i), dst) + continue + } + + if v.Field(i).Type().Kind() == reflect.Slice { + encFunc = typeEncoder(v.Field(i).Type().Elem(), e.regenc) + } + + if encFunc == nil { + errors[v.Field(i).Type().String()] = fmt.Errorf("schema: encoder not found for %v", v.Field(i)) + continue + } + + // Encode a slice. + if v.Field(i).Len() == 0 && opts.Contains("omitempty") { + continue + } + + dst[name] = []string{} + for j := 0; j < v.Field(i).Len(); j++ { + dst[name] = append(dst[name], encFunc(v.Field(i).Index(j))) + } + } + + if len(errors) > 0 { + return errors + } + return nil +} + +func typeEncoder(t reflect.Type, reg map[reflect.Type]encoderFunc) encoderFunc { + if f, ok := reg[t]; ok { + return f + } + + switch t.Kind() { + case reflect.Bool: + return encodeBool + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return encodeInt + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return encodeUint + case reflect.Float32: + return encodeFloat32 + case reflect.Float64: + return encodeFloat64 + case reflect.Ptr: + f := typeEncoder(t.Elem(), reg) + return func(v reflect.Value) string { + if v.IsNil() { + return "null" + } + return f(v.Elem()) + } + case reflect.String: + return encodeString + default: + return nil + } +} + +func encodeBool(v reflect.Value) string { + return strconv.FormatBool(v.Bool()) +} + +func encodeInt(v reflect.Value) string { + return strconv.FormatInt(int64(v.Int()), 10) +} + +func encodeUint(v reflect.Value) string { + return strconv.FormatUint(uint64(v.Uint()), 10) +} + +func encodeFloat(v reflect.Value, bits int) string { + return strconv.FormatFloat(v.Float(), 'f', 6, bits) +} + +func encodeFloat32(v reflect.Value) string { + return encodeFloat(v, 32) +} + +func encodeFloat64(v reflect.Value) string { + return encodeFloat(v, 64) +} + +func encodeString(v reflect.Value) string { + return v.String() +} diff --git a/internal/sse/sse-encoder.go b/internal/sse/sse-encoder.go new file mode 100644 index 0000000..26100e7 --- /dev/null +++ b/internal/sse/sse-encoder.go @@ -0,0 +1,106 @@ +package sse + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "reflect" + "strconv" + "strings" +) + +// Server-Sent Events +// W3C Working Draft 29 October 2009 +// http://www.w3.org/TR/2009/WD-eventsource-20091029/ + +const ContentType = "text/event-stream" + +var contentType = []string{ContentType} +var noCache = []string{"no-cache"} + +var fieldReplacer = strings.NewReplacer( + "\n", "\\n", + "\r", "\\r") + +var dataReplacer = strings.NewReplacer( + "\n", "\ndata:", + "\r", "\\r") + +type Event struct { + Event string + Id string + Retry uint + Data interface{} +} + +func Encode(writer io.Writer, event Event) error { + w := checkWriter(writer) + writeId(w, event.Id) + writeEvent(w, event.Event) + writeRetry(w, event.Retry) + return writeData(w, event.Data) +} + +func writeId(w stringWriter, id string) { + if len(id) > 0 { + w.WriteString("id:") + fieldReplacer.WriteString(w, id) + w.WriteString("\n") + } +} + +func writeEvent(w stringWriter, event string) { + if len(event) > 0 { + w.WriteString("event:") + fieldReplacer.WriteString(w, event) + w.WriteString("\n") + } +} + +func writeRetry(w stringWriter, retry uint) { + if retry > 0 { + w.WriteString("retry:") + w.WriteString(strconv.FormatUint(uint64(retry), 10)) + w.WriteString("\n") + } +} + +func writeData(w stringWriter, data interface{}) error { + w.WriteString("data:") + switch kindOfData(data) { + case reflect.Struct, reflect.Slice, reflect.Map: + err := json.NewEncoder(w).Encode(data) + if err != nil { + return err + } + w.WriteString("\n") + default: + dataReplacer.WriteString(w, fmt.Sprint(data)) + w.WriteString("\n\n") + } + return nil +} + +func (r Event) Render(w http.ResponseWriter) error { + r.WriteContentType(w) + return Encode(w, r) +} + +func (r Event) WriteContentType(w http.ResponseWriter) { + header := w.Header() + header["Content-Type"] = contentType + + if _, exist := header["Cache-Control"]; !exist { + header["Cache-Control"] = noCache + } +} + +func kindOfData(data interface{}) reflect.Kind { + value := reflect.ValueOf(data) + valueType := value.Kind() + if valueType == reflect.Ptr { + valueType = value.Elem().Kind() + } + return valueType +} diff --git a/internal/sse/writer.go b/internal/sse/writer.go new file mode 100644 index 0000000..6f9806c --- /dev/null +++ b/internal/sse/writer.go @@ -0,0 +1,24 @@ +package sse + +import "io" + +type stringWriter interface { + io.Writer + WriteString(string) (int, error) +} + +type stringWrapper struct { + io.Writer +} + +func (w stringWrapper) WriteString(str string) (int, error) { + return w.Writer.Write([]byte(str)) +} + +func checkWriter(writer io.Writer) stringWriter { + if w, ok := writer.(stringWriter); ok { + return w + } else { + return stringWrapper{writer} + } +} diff --git a/pkg/api/api.go b/pkg/api/api.go new file mode 100644 index 0000000..5ce52bc --- /dev/null +++ b/pkg/api/api.go @@ -0,0 +1,97 @@ +package api + +import "sync" + +const ( + _404 = `404 Not Found` + _405 = `405 Method Not Allowed` + _500 = `500 Internal Server Error` + TraceKey = "X-Trace-Id" +) + +type Map map[string]interface{} + +type Config struct { + DisableMessagePrint bool `json:"-"` + // Default: 4 * 1024 * 1024 + BodyLimit int64 `json:"-"` + + // if report http.ErrServerClosed as run err + ErrServeClose bool `json:"-"` + + DisableLogger bool `json:"-"` + DisableRecover bool `json:"-"` + DisableHttpErrorLog bool `json:"-"` + + // EnableNotImplementHandler bool `json:"-"` + NotFoundHandler HandlerFunc `json:"-"` + MethodNotAllowedHandler HandlerFunc `json:"-"` +} + +var defaultConfig = &Config{ + BodyLimit: 4 * 1024 * 1024, + NotFoundHandler: func(c *Ctx) error { + c.Set("Content-Type", MIMETextPlain) + _, err := c.Status(404).Write([]byte(_404)) + return err + }, + MethodNotAllowedHandler: func(c *Ctx) error { + c.Set("Content-Type", MIMETextPlain) + _, err := c.Status(405).Write([]byte(_405)) + return err + }, +} + +func New(config ...Config) *App { + app := &App{ + RouterGroup: RouterGroup{ + Handlers: nil, + basePath: "/", + root: true, + }, + + pool: &sync.Pool{}, + + redirectTrailingSlash: true, // true + redirectFixedPath: false, // false + handleMethodNotAllowed: true, // false + useRawPath: false, // false + unescapePathValues: true, // true + removeExtraSlash: false, // false + } + + if len(config) > 0 { + app.config = &config[0] + + if app.config.BodyLimit == 0 { + app.config.BodyLimit = defaultConfig.BodyLimit + } + + if app.config.NotFoundHandler == nil { + app.config.NotFoundHandler = defaultConfig.NotFoundHandler + } + + if app.config.MethodNotAllowedHandler == nil { + app.config.MethodNotAllowedHandler = defaultConfig.MethodNotAllowedHandler + } + + } else { + app.config = defaultConfig + } + + app.RouterGroup.app = app + + if !app.config.DisableLogger { + app.Use(NewLogger()) + } + + if !app.config.DisableRecover { + app.Use(NewRecover(true)) + } + + app.pool.New = func() any { + return app.allocateContext() + } + + return app +} diff --git a/pkg/api/app.go b/pkg/api/app.go new file mode 100644 index 0000000..dbd05f2 --- /dev/null +++ b/pkg/api/app.go @@ -0,0 +1,300 @@ +package api + +import ( + "context" + "crypto/tls" + "errors" + "io" + "log" + "net" + "net/http" + "path" + "regexp" + "sync" + + "github.com/loveuer/upp/internal/bytesconv" + "github.com/loveuer/upp/pkg/interfaces" +) + +var ( + _ IRouter = (*App)(nil) + + regSafePrefix = regexp.MustCompile("[^a-zA-Z0-9/-]+") + regRemoveRepeatedChar = regexp.MustCompile("/{2,}") +) + +type App struct { + RouterGroup + Upp interfaces.Upp + config *Config + groups []*RouterGroup + server *http.Server + + trees methodTrees + + pool *sync.Pool + + maxParams uint16 + maxSections uint16 + + redirectTrailingSlash bool // true + redirectFixedPath bool // false + handleMethodNotAllowed bool // false + useRawPath bool // false + unescapePathValues bool // true + removeExtraSlash bool // false +} + +func (a *App) allocateContext() *Ctx { + var ( + skippedNodes = make([]skippedNode, 0, a.maxSections) + v = make(Params, 0, a.maxParams) + ) + + ctx := Ctx{ + lock: sync.Mutex{}, + app: a, + index: -1, + locals: make(map[string]any), + handlers: make([]HandlerFunc, 0), + skippedNodes: &skippedNodes, + params: &v, + } + + return &ctx +} + +func (a *App) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + var ( + err error + c = a.pool.Get().(*Ctx) + nfe = new(Err) + ) + + c.reset(writer, request) + + if err = c.verify(); err != nil { + if errors.As(err, nfe) { + _ = c.Status(nfe.Status).SendString(nfe.Msg) + return + } + + _ = c.Status(500).SendString(err.Error()) + return + } + + a.handleHTTPRequest(c) + + a.pool.Put(c) +} + +func (a *App) run(ln net.Listener) error { + srv := &http.Server{Handler: a} + + if a.config.DisableHttpErrorLog { + srv.ErrorLog = log.New(io.Discard, "", 0) + } + + a.server = srv + + err := a.server.Serve(ln) + if !errors.Is(err, http.ErrServerClosed) || a.config.ErrServeClose { + return err + } + + return nil +} + +func (a *App) Run(address string) error { + ln, err := net.Listen("tcp", address) + if err != nil { + return err + } + + return a.run(ln) +} + +func (a *App) RunTLS(address string, tlsConfig *tls.Config) error { + ln, err := tls.Listen("tcp", address, tlsConfig) + if err != nil { + return err + } + + return a.run(ln) +} + +func (a *App) RunListener(ln net.Listener) error { + a.server = &http.Server{Addr: ln.Addr().String()} + + return a.run(ln) +} + +func (a *App) Shutdown(ctx context.Context) error { + return a.server.Shutdown(ctx) +} + +func (a *App) addRoute(method, path string, handlers ...HandlerFunc) { + elsePanic(path[0] == '/', "path must begin with '/'") + elsePanic(method != "", "HTTP method can not be empty") + elsePanic(len(handlers) > 0, "without enable not implement, there must be at least one handler") + + // if !a.config.DisableMessagePrint { + // fmt.Printf("[NF] Add Route: %-8s - %-25s (%2d handlers)\n", method, path, len(handlers)) + // } + + root := a.trees.get(method) + if root == nil { + root = new(node) + root.fullPath = "/" + a.trees = append(a.trees, methodTree{method: method, root: root}) + } + + root.addRoute(path, handlers...) + + if paramsCount := countParams(path); paramsCount > a.maxParams { + a.maxParams = paramsCount + } + + if sectionsCount := countSections(path); sectionsCount > a.maxSections { + a.maxSections = sectionsCount + } +} + +func (a *App) handleHTTPRequest(c *Ctx) { + var err error + + httpMethod := c.Request.Method + rPath := c.Request.URL.Path + unescape := false + if a.useRawPath && len(c.Request.URL.RawPath) > 0 { + rPath = c.Request.URL.RawPath + unescape = a.unescapePathValues + } + + if a.removeExtraSlash { + rPath = cleanPath(rPath) + } + + // Find root of the tree for the given HTTP method + t := a.trees + for i, tl := 0, len(t); i < tl; i++ { + if t[i].method != httpMethod { + continue + } + root := t[i].root + // Find route in tree + value := root.getValue(rPath, c.params, c.skippedNodes, unescape) + if value.params != nil { + c.params = value.params + } + + if value.handlers != nil { + c.handlers = value.handlers + c.fullPath = value.fullPath + + if err = c.Next(); err != nil { + serveError(c, errorHandler) + } + + return + } + if httpMethod != http.MethodConnect && rPath != "/" { + if value.tsr && a.redirectTrailingSlash { + redirectTrailingSlash(c) + return + } + if a.redirectFixedPath && redirectFixedPath(c, root, a.redirectFixedPath) { + return + } + } + break + } + + if a.handleMethodNotAllowed { + // According to RFC 7231 section 6.5.5, MUST generate an Allow header field in response + // containing a list of the target resource's currently supported methods. + allowed := make([]string, 0, len(t)-1) + for _, tree := range a.trees { + if tree.method == httpMethod { + continue + } + if value := tree.root.getValue(rPath, nil, c.skippedNodes, unescape); value.handlers != nil { + allowed = append(allowed, tree.method) + } + } + + if len(allowed) > 0 { + c.handlers = a.combineHandlers(a.config.MethodNotAllowedHandler) + + _ = c.Next() + + return + } + } + + c.handlers = a.combineHandlers(a.config.NotFoundHandler) + + _ = c.Next() + + return +} + +func errorHandler(c *Ctx) error { + return c.Status(500).SendString(_500) +} + +func serveError(c *Ctx, handler HandlerFunc) { + err := c.Next() + + if c.writermem.Written() { + return + } + + _ = handler(c) + _ = err +} + +func redirectTrailingSlash(c *Ctx) { + req := c.Request + p := req.URL.Path + if prefix := path.Clean(c.Request.Header.Get("X-Forwarded-Prefix")); prefix != "." { + prefix = regSafePrefix.ReplaceAllString(prefix, "") + prefix = regRemoveRepeatedChar.ReplaceAllString(prefix, "/") + + p = prefix + "/" + req.URL.Path + } + req.URL.Path = p + "/" + if length := len(p); length > 1 && p[length-1] == '/' { + req.URL.Path = p[:length-1] + } + + redirectRequest(c) +} + +func redirectFixedPath(c *Ctx, root *node, trailingSlash bool) bool { + req := c.Request + rPath := req.URL.Path + + if fixedPath, ok := root.findCaseInsensitivePath(cleanPath(rPath), trailingSlash); ok { + req.URL.Path = bytesconv.BytesToString(fixedPath) + redirectRequest(c) + return true + } + return false +} + +func redirectRequest(c *Ctx) { + req := c.Request + // rPath := req.URL.Path + rURL := req.URL.String() + + code := http.StatusMovedPermanently // Permanent redirect, request with GET method + if req.Method != http.MethodGet { + code = http.StatusTemporaryRedirect + } + + // debugPrint("redirecting request %d: %s --> %s", code, rPath, rURL) + + http.Redirect(c.Writer, req, rURL, code) + c.writermem.WriteHeaderNow() +} diff --git a/pkg/api/ctx.go b/pkg/api/ctx.go new file mode 100644 index 0000000..d5a506e --- /dev/null +++ b/pkg/api/ctx.go @@ -0,0 +1,369 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "html/template" + "io" + "mime/multipart" + "net" + "net/http" + "strings" + "sync" + + "github.com/google/uuid" + "github.com/loveuer/upp/internal/sse" + "gorm.io/gorm" +) + +var forwardHeaders = []string{"CF-Connecting-IP", "X-Forwarded-For", "X-Real-Ip"} + +type Ctx struct { + lock sync.Mutex + writermem responseWriter + Writer ResponseWriter + Request *http.Request + path string + method string + StatusCode int + + app *App + params *Params + index int + handlers []HandlerFunc + locals map[string]interface{} + skippedNodes *[]skippedNode + fullPath string +} + +func (c *Ctx) UseDB() *gorm.DB { + return c.app.Upp.UseDB() +} + +func (c *Ctx) reset(w http.ResponseWriter, r *http.Request) { + traceId := r.Header.Get(TraceKey) + if traceId == "" { + traceId = uuid.Must(uuid.NewV7()).String() + } + + c.writermem.reset(w) + + c.Request = r.WithContext(context.WithValue(r.Context(), TraceKey, traceId)) + c.Writer = &c.writermem + c.handlers = nil + c.index = -1 + c.path = r.URL.Path + c.method = r.Method + c.StatusCode = 200 + + c.fullPath = "" + *c.params = (*c.params)[:0] + *c.skippedNodes = (*c.skippedNodes)[:0] + for key := range c.locals { + delete(c.locals, key) + } + c.writermem.Header().Set(TraceKey, traceId) +} + +func (c *Ctx) Locals(key string, value ...interface{}) interface{} { + data := c.locals[key] + if len(value) > 0 { + c.locals[key] = value[0] + } + + return data +} + +func (c *Ctx) Method(overWrite ...string) string { + method := c.Request.Method + + if len(overWrite) > 0 && overWrite[0] != "" { + c.Request.Method = overWrite[0] + } + + return method +} + +func (c *Ctx) Path(overWrite ...string) string { + path := c.Request.URL.Path + if len(overWrite) > 0 && overWrite[0] != "" { + c.Request.URL.Path = overWrite[0] + } + + return path +} + +func (c *Ctx) Cookies(key string, defaultValue ...string) string { + dv := "" + + if len(defaultValue) > 0 { + dv = defaultValue[0] + } + + cookie, err := c.Request.Cookie(key) + if err != nil || cookie.Value == "" { + return dv + } + + return cookie.Value +} + +func (c *Ctx) Context() context.Context { + return c.Request.Context() +} + +func (c *Ctx) Next() error { + c.index++ + + if c.index >= len(c.handlers) { + return nil + } + + var ( + err error + handler = c.handlers[c.index] + ) + + if handler != nil { + if err = handler(c); err != nil { + return err + } + } + + c.index++ + + return nil +} + +/* =============================================================== +|| Handle Ctx Request Part +=============================================================== */ + +func (c *Ctx) verify() error { + // 验证 body size + if c.app.config.BodyLimit != -1 && c.Request.ContentLength > c.app.config.BodyLimit { + return NewNFError(413, "Content Too Large") + } + + return nil +} + +func (c *Ctx) Param(key string) string { + return c.params.ByName(key) +} + +func (c *Ctx) SetParam(key, value string) { + c.lock.Lock() + defer c.lock.Unlock() + + params := append(*c.params, Param{Key: key, Value: value}) + c.params = ¶ms +} + +func (c *Ctx) Form(key string) string { + return c.Request.FormValue(key) +} + +// FormValue fiber ctx function +func (c *Ctx) FormValue(key string) string { + return c.Request.FormValue(key) +} + +func (c *Ctx) FormFile(key string) (*multipart.FileHeader, error) { + _, fh, err := c.Request.FormFile(key) + return fh, err +} + +func (c *Ctx) MultipartForm() (*multipart.Form, error) { + if err := c.Request.ParseMultipartForm(c.app.config.BodyLimit); err != nil { + return nil, err + } + + return c.Request.MultipartForm, nil +} + +func (c *Ctx) Query(key string) string { + return c.Request.URL.Query().Get(key) +} + +func (c *Ctx) Get(key string, defaultValue ...string) string { + value := c.Request.Header.Get(key) + if value == "" && len(defaultValue) > 0 { + return defaultValue[0] + } + + return value +} + +func (c *Ctx) IP(useProxyHeader ...bool) string { + ip, _, err := net.SplitHostPort(strings.TrimSpace(c.Request.RemoteAddr)) + if err != nil { + return "" + } + + if len(useProxyHeader) > 0 && useProxyHeader[0] { + for _, h := range forwardHeaders { + for _, rip := range strings.Split(c.Request.Header.Get(h), ",") { + realIP := net.ParseIP(strings.Replace(rip, " ", "", -1)) + if check := net.ParseIP(realIP.String()); check != nil { + ip = realIP.String() + break + } + } + } + } + + return ip +} + +func (c *Ctx) BodyParser(out interface{}) error { + var ( + err error + ctype = strings.ToLower(c.Request.Header.Get("Content-Type")) + ) + + ctype = parseVendorSpecificContentType(ctype) + + ctypeEnd := strings.IndexByte(ctype, ';') + if ctypeEnd != -1 { + ctype = ctype[:ctypeEnd] + } + + if strings.HasSuffix(ctype, "json") { + bs, err := io.ReadAll(c.Request.Body) + if err != nil { + return err + } + _ = c.Request.Body.Close() + + c.Request.Body = io.NopCloser(bytes.NewReader(bs)) + + return json.Unmarshal(bs, out) + } + + if strings.HasPrefix(ctype, MIMEApplicationForm) { + + if err = c.Request.ParseForm(); err != nil { + return NewNFError(400, err.Error()) + } + + return parseToStruct("form", out, c.Request.Form) + } + + if strings.HasPrefix(ctype, MIMEMultipartForm) { + if err = c.Request.ParseMultipartForm(c.app.config.BodyLimit); err != nil { + return NewNFError(400, err.Error()) + } + + return parseToStruct("form", out, c.Request.PostForm) + } + + return NewNFError(422, "Unprocessable Content") +} + +func (c *Ctx) QueryParser(out interface{}) error { + return parseToStruct("query", out, c.Request.URL.Query()) +} + +/* =============================================================== +|| Handle Ctx Response Part +=============================================================== */ + +func (c *Ctx) Status(code int) *Ctx { + c.lock.Lock() + defer c.lock.Unlock() + + c.Writer.WriteHeader(code) + c.StatusCode = c.writermem.status + + return c +} + +// Set set response header +func (c *Ctx) Set(key string, value string) { + c.Writer.Header().Set(key, value) +} + +// AddHeader add response header +func (c *Ctx) AddHeader(key string, value string) { + c.Writer.Header().Add(key, value) +} + +// SetHeader set response header +func (c *Ctx) SetHeader(key string, value string) { + c.Writer.Header().Set(key, value) +} + +func (c *Ctx) SendStatus(code int) error { + c.Status(code) + c.Writer.WriteHeaderNow() + return nil +} + +func (c *Ctx) SendString(data string) error { + c.SetHeader("Content-Type", "text/plain") + _, err := c.Write([]byte(data)) + return err +} + +func (c *Ctx) Writef(format string, values ...interface{}) (int, error) { + c.SetHeader("Content-Type", "text/plain") + return c.Write([]byte(fmt.Sprintf(format, values...))) +} + +func (c *Ctx) JSON(data interface{}) error { + c.SetHeader("Content-Type", MIMEApplicationJSON) + + encoder := json.NewEncoder(c.Writer) + + if err := encoder.Encode(data); err != nil { + return err + } + + return nil +} + +func (c *Ctx) SSEvent(event string, data interface{}) error { + c.Set("Content-Type", "text/event-stream") + c.Set("Cache-Control", "no-cache") + c.Set("Transfer-Encoding", "chunked") + + return sse.Encode(c.Writer, sse.Event{Event: event, Data: data}) +} + +func (c *Ctx) Flush() error { + if f, ok := c.Writer.(http.Flusher); ok { + f.Flush() + return nil + } + + return errors.New("http.Flusher is not implemented") +} + +func (c *Ctx) HTML(html string) error { + c.SetHeader("Content-Type", "text/html") + _, err := c.Write([]byte(html)) + return err +} + +func (c *Ctx) RenderHTML(name, html string, obj any) error { + c.SetHeader("Content-Type", "text/html") + t, err := template.New(name).Parse(html) + if err != nil { + return err + } + + return t.Execute(c.Writer, obj) +} + +func (c *Ctx) Redirect(url string, code int) error { + http.Redirect(c.Writer, c.Request, url, code) + return nil +} + +func (c *Ctx) Write(data []byte) (int, error) { + return c.Writer.Write(data) +} diff --git a/pkg/api/error.go b/pkg/api/error.go new file mode 100644 index 0000000..e286e1e --- /dev/null +++ b/pkg/api/error.go @@ -0,0 +1,16 @@ +package api + +import "strconv" + +type Err struct { + Status int + Msg string +} + +func (n Err) Error() string { + return strconv.Itoa(n.Status) + " " + n.Msg +} + +func NewNFError(status int, msg string) Err { + return Err{Status: status, Msg: msg} +} diff --git a/pkg/api/handler.go b/pkg/api/handler.go new file mode 100644 index 0000000..c185d90 --- /dev/null +++ b/pkg/api/handler.go @@ -0,0 +1,9 @@ +package api + +import "fmt" + +type HandlerFunc func(*Ctx) error + +func ToDoHandler(c *Ctx) error { + return c.Status(501).SendString(fmt.Sprintf("%s - %s Not Implemented", c.Method(), c.Path())) +} diff --git a/pkg/api/middleware.go b/pkg/api/middleware.go new file mode 100644 index 0000000..c1e6747 --- /dev/null +++ b/pkg/api/middleware.go @@ -0,0 +1,67 @@ +package api + +import ( + "fmt" + "os" + "runtime/debug" + "strconv" + "time" + + "github.com/loveuer/nf" + "github.com/loveuer/nf/nft/log" + "github.com/loveuer/nf/nft/resp" + "github.com/loveuer/upp/pkg/tool" +) + +func NewRecover(enableStackTrace bool) HandlerFunc { + return func(c *Ctx) error { + defer func() { + if r := recover(); r != nil { + if enableStackTrace { + os.Stderr.WriteString(fmt.Sprintf("recovered from panic: %v\n%s\n", r, debug.Stack())) + } else { + os.Stderr.WriteString(fmt.Sprintf("recovered from panic: %v\n", r)) + } + + _ = c.Status(500).SendString(fmt.Sprint(r)) + } + }() + + return c.Next() + } +} + +func NewLogger() HandlerFunc { + return func(c *Ctx) error { + var ( + now = time.Now() + logFn func(msg string, data ...any) + ip = c.IP() + ) + + traceId := c.Context().Value(nf.TraceKey) + c.Locals(nf.TraceKey, traceId) + + err := c.Next() + + c.Writer.Header().Set(nf.TraceKey, fmt.Sprint(traceId)) + + status, _ := strconv.Atoi(c.Writer.Header().Get(resp.RealStatusHeader)) + duration := time.Since(now) + + msg := fmt.Sprintf("%s | %15s | %d[%3d] | %s | %6s | %s", traceId, ip, c.StatusCode, status, tool.HumanDuration(duration.Nanoseconds()), c.Method(), c.Path()) + + switch { + case status >= 500: + logFn = log.Error + case status >= 400: + logFn = log.Warn + default: + logFn = log.Info + } + + logFn(msg) + + return err + } +} diff --git a/pkg/api/response_writer.go b/pkg/api/response_writer.go new file mode 100644 index 0000000..e7608ac --- /dev/null +++ b/pkg/api/response_writer.go @@ -0,0 +1,134 @@ +package api + +import ( + "bufio" + "fmt" + "io" + "net" + "net/http" +) + +const ( + noWritten = -1 + defaultStatus = http.StatusOK +) + +// ResponseWriter ... +type ResponseWriter interface { + http.ResponseWriter + http.Hijacker + http.Flusher + http.CloseNotifier + + // Status returns the HTTP response status code of the current request. + Status() int + + // Size returns the number of bytes already written into the response http body. + // See Written() + Size() int + + // WriteString writes the string into the response body. + WriteString(string) (int, error) + + // Written returns true if the response body was already written. + Written() bool + + // WriteHeaderNow forces to write the http header (status code + headers). + WriteHeaderNow() + + // Pusher get the http.Pusher for server push + Pusher() http.Pusher +} + +type responseWriter struct { + http.ResponseWriter + written bool + size int + status int +} + +var _ ResponseWriter = (*responseWriter)(nil) + +func (w *responseWriter) Unwrap() http.ResponseWriter { + return w.ResponseWriter +} + +func (w *responseWriter) reset(writer http.ResponseWriter) { + w.ResponseWriter = writer + w.size = noWritten + w.status = defaultStatus +} + +func (w *responseWriter) WriteHeader(code int) { + if code > 0 && w.status != code { + if w.Written() { + fmt.Printf("WARNING: Headers were already written. Wanted to override status code %d with %d", w.status, code) + return + } + w.status = code + } +} + +func (w *responseWriter) WriteHeaderNow() { + if !w.Written() { + w.size = 0 + + if w.status == 0 { + w.status = 200 + } + + w.ResponseWriter.WriteHeader(w.status) + } +} + +func (w *responseWriter) Write(data []byte) (n int, err error) { + w.WriteHeaderNow() + n, err = w.ResponseWriter.Write(data) + w.size += n + return +} + +func (w *responseWriter) WriteString(s string) (n int, err error) { + w.WriteHeaderNow() + n, err = io.WriteString(w.ResponseWriter, s) + w.size += n + return +} + +func (w *responseWriter) Status() int { + return w.status +} + +func (w *responseWriter) Size() int { + return w.size +} + +func (w *responseWriter) Written() bool { + return w.size != noWritten || w.written +} + +// Hijack implements the http.Hijacker interface. +func (w *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if w.size < 0 { + w.size = 0 + } + return w.ResponseWriter.(http.Hijacker).Hijack() +} + +// CloseNotify implements the http.CloseNotifier interface. +func (w *responseWriter) CloseNotify() <-chan bool { + return w.ResponseWriter.(http.CloseNotifier).CloseNotify() +} + +// Flush implements the http.Flusher interface. +func (w *responseWriter) Flush() { + w.WriteHeaderNow() + w.ResponseWriter.(http.Flusher).Flush() +} + +func (w *responseWriter) Pusher() (pusher http.Pusher) { + if pusher, ok := w.ResponseWriter.(http.Pusher); ok { + return pusher + } + return nil +} diff --git a/pkg/api/routergroup.go b/pkg/api/routergroup.go new file mode 100644 index 0000000..548f51d --- /dev/null +++ b/pkg/api/routergroup.go @@ -0,0 +1,155 @@ +package api + +import ( + "math" + "net/http" + "path" + "regexp" +) + +var ( + // regEnLetter matches english letters for http method name + regEnLetter = regexp.MustCompile("^[A-Z]+$") + + // anyMethods for RouterGroup Any method + anyMethods = []string{ + http.MethodGet, http.MethodPost, http.MethodPut, http.MethodPatch, + http.MethodHead, http.MethodOptions, http.MethodDelete, http.MethodConnect, + http.MethodTrace, + } +) + +// IRouter defines all router handle interface includes single and group router. +type IRouter interface { + IRoutes + Group(string, ...HandlerFunc) *RouterGroup +} + +// IRoutes defines all router handle interface. +type IRoutes interface { + Use(...HandlerFunc) IRoutes + + Handle(string, string, ...HandlerFunc) IRoutes + Any(string, ...HandlerFunc) IRoutes + GET(string, ...HandlerFunc) IRoutes + POST(string, ...HandlerFunc) IRoutes + DELETE(string, ...HandlerFunc) IRoutes + PATCH(string, ...HandlerFunc) IRoutes + PUT(string, ...HandlerFunc) IRoutes + OPTIONS(string, ...HandlerFunc) IRoutes + HEAD(string, ...HandlerFunc) IRoutes + Match([]string, string, ...HandlerFunc) IRoutes + + // StaticFile(string, string) IRoutes + // StaticFileFS(string, string, http.FileSystem) IRoutes + // Static(string, string) IRoutes + // StaticFS(string, http.FileSystem) IRoutes +} + +type RouterGroup struct { + Handlers []HandlerFunc + basePath string + app *App + root bool +} + +var _ IRouter = (*RouterGroup)(nil) + +func (group *RouterGroup) Use(middleware ...HandlerFunc) IRoutes { + group.Handlers = append(group.Handlers, middleware...) + return group.returnObj() +} + +func (group *RouterGroup) Group(relativePath string, handlers ...HandlerFunc) *RouterGroup { + return &RouterGroup{ + Handlers: group.combineHandlers(handlers...), + basePath: group.calculateAbsolutePath(relativePath), + app: group.app, + } +} + +func (group *RouterGroup) BasePath() string { + return group.basePath +} + +func (group *RouterGroup) handle(httpMethod, relativePath string, handlers ...HandlerFunc) IRoutes { + absolutePath := group.calculateAbsolutePath(relativePath) + handlers = group.combineHandlers(handlers...) + group.app.addRoute(httpMethod, absolutePath, handlers...) + return group.returnObj() +} + +func (group *RouterGroup) Handle(httpMethod, relativePath string, handlers ...HandlerFunc) IRoutes { + if matched := regEnLetter.MatchString(httpMethod); !matched { + panic("http method " + httpMethod + " is not valid") + } + return group.handle(httpMethod, relativePath, handlers...) +} + +func (group *RouterGroup) POST(relativePath string, handlers ...HandlerFunc) IRoutes { + return group.handle(http.MethodPost, relativePath, handlers...) +} + +func (group *RouterGroup) GET(relativePath string, handlers ...HandlerFunc) IRoutes { + return group.handle(http.MethodGet, relativePath, handlers...) +} + +func (group *RouterGroup) DELETE(relativePath string, handlers ...HandlerFunc) IRoutes { + return group.handle(http.MethodDelete, relativePath, handlers...) +} + +func (group *RouterGroup) PATCH(relativePath string, handlers ...HandlerFunc) IRoutes { + return group.handle(http.MethodPatch, relativePath, handlers...) +} + +func (group *RouterGroup) PUT(relativePath string, handlers ...HandlerFunc) IRoutes { + return group.handle(http.MethodPut, relativePath, handlers...) +} + +func (group *RouterGroup) OPTIONS(relativePath string, handlers ...HandlerFunc) IRoutes { + return group.handle(http.MethodOptions, relativePath, handlers...) +} + +func (group *RouterGroup) HEAD(relativePath string, handlers ...HandlerFunc) IRoutes { + return group.handle(http.MethodHead, relativePath, handlers...) +} + +// Any registers a route that matches all the HTTP methods. +// GET, POST, PUT, PATCH, HEAD, OPTIONS, DELETE, CONNECT, TRACE. +func (group *RouterGroup) Any(relativePath string, handlers ...HandlerFunc) IRoutes { + for _, method := range anyMethods { + group.handle(method, relativePath, handlers...) + } + + return group.returnObj() +} + +func (group *RouterGroup) Match(methods []string, relativePath string, handlers ...HandlerFunc) IRoutes { + for _, method := range methods { + group.handle(method, relativePath, handlers...) + } + + return group.returnObj() +} + +const abortIndex int8 = math.MaxInt8 >> 1 + +func (group *RouterGroup) combineHandlers(handlers ...HandlerFunc) []HandlerFunc { + finalSize := len(group.Handlers) + len(handlers) + elsePanic(finalSize < int(abortIndex), "too many handlers") + mergedHandlers := make([]HandlerFunc, finalSize) + copy(mergedHandlers, group.Handlers) + copy(mergedHandlers[len(group.Handlers):], handlers) + return mergedHandlers +} + +func (group *RouterGroup) calculateAbsolutePath(relativePath string) string { + return path.Join(group.basePath, relativePath) +} + +func (group *RouterGroup) returnObj() IRoutes { + if group.root { + return group.app + } + return group +} diff --git a/pkg/api/tree.go b/pkg/api/tree.go new file mode 100644 index 0000000..542e992 --- /dev/null +++ b/pkg/api/tree.go @@ -0,0 +1,891 @@ +package api + +import ( + "bytes" + "net/url" + "strings" + "unicode" + "unicode/utf8" + + "github.com/loveuer/upp/internal/bytesconv" +) + +var ( + strColon = []byte(":") + strStar = []byte("*") + strSlash = []byte("/") +) + +// Param is a single URL parameter, consisting of a key and a value. +type Param struct { + Key string + Value string +} + +// Params is a Param-slice, as returned by the router. +// The slice is ordered, the first URL parameter is also the first slice value. +// It is therefore safe to read values by the index. +type Params []Param + +// Get returns the value of the first Param which key matches the given name and a boolean true. +// If no matching Param is found, an empty string is returned and a boolean false . +func (ps Params) Get(name string) (string, bool) { + for _, entry := range ps { + if entry.Key == name { + return entry.Value, true + } + } + return "", false +} + +// ByName returns the value of the first Param which key matches the given name. +// If no matching Param is found, an empty string is returned. +func (ps Params) ByName(name string) (va string) { + va, _ = ps.Get(name) + return +} + +type methodTree struct { + method string + root *node +} + +type methodTrees []methodTree + +func (trees methodTrees) get(method string) *node { + for _, tree := range trees { + if tree.method == method { + return tree.root + } + } + return nil +} + +func min(a, b int) int { + if a <= b { + return a + } + return b +} + +func longestCommonPrefix(a, b string) int { + i := 0 + max := min(len(a), len(b)) + for i < max && a[i] == b[i] { + i++ + } + return i +} + +// addChild will add a child node, keeping wildcardChild at the end +func (n *node) addChild(child *node) { + if n.wildChild && len(n.children) > 0 { + wildcardChild := n.children[len(n.children)-1] + n.children = append(n.children[:len(n.children)-1], child, wildcardChild) + } else { + n.children = append(n.children, child) + } +} + +func countParams(path string) uint16 { + var n uint16 + s := bytesconv.StringToBytes(path) + n += uint16(bytes.Count(s, strColon)) + n += uint16(bytes.Count(s, strStar)) + return n +} + +func countSections(path string) uint16 { + s := bytesconv.StringToBytes(path) + return uint16(bytes.Count(s, strSlash)) +} + +type nodeType uint8 + +const ( + static nodeType = iota + root + param + catchAll +) + +type node struct { + path string + indices string + wildChild bool + nType nodeType + priority uint32 + children []*node // child nodes, at most 1 :param style node at the end of the array + handlers []HandlerFunc + fullPath string +} + +// Increments priority of the given child and reorders if necessary +func (n *node) incrementChildPrio(pos int) int { + cs := n.children + cs[pos].priority++ + prio := cs[pos].priority + + // Adjust position (move to front) + newPos := pos + for ; newPos > 0 && cs[newPos-1].priority < prio; newPos-- { + // Swap node positions + cs[newPos-1], cs[newPos] = cs[newPos], cs[newPos-1] + } + + // Build new index char string + if newPos != pos { + n.indices = n.indices[:newPos] + // Unchanged prefix, might be empty + n.indices[pos:pos+1] + // The index char we move + n.indices[newPos:pos] + n.indices[pos+1:] // Rest without char at 'pos' + } + + return newPos +} + +// addRoute adds a node with the given handle to the path. +// Not concurrency-safe! +func (n *node) addRoute(path string, handlers ...HandlerFunc) { + fullPath := path + n.priority++ + + // Empty tree + if len(n.path) == 0 && len(n.children) == 0 { + n.insertChild(path, fullPath, handlers...) + n.nType = root + return + } + + parentFullPathIndex := 0 + +walk: + for { + // Find the longest common prefix. + // This also implies that the common prefix contains no ':' or '*' + // since the existing key can't contain those chars. + i := longestCommonPrefix(path, n.path) + + // Split edge + if i < len(n.path) { + child := node{ + path: n.path[i:], + wildChild: n.wildChild, + nType: static, + indices: n.indices, + children: n.children, + handlers: n.handlers, + priority: n.priority - 1, + fullPath: n.fullPath, + } + + n.children = []*node{&child} + // []byte for proper unicode char conversion, see #65 + n.indices = bytesconv.BytesToString([]byte{n.path[i]}) + n.path = path[:i] + n.handlers = nil + n.wildChild = false + n.fullPath = fullPath[:parentFullPathIndex+i] + } + + // Make new node a child of this node + if i < len(path) { + path = path[i:] + c := path[0] + + // '/' after param + if n.nType == param && c == '/' && len(n.children) == 1 { + parentFullPathIndex += len(n.path) + n = n.children[0] + n.priority++ + continue walk + } + + // Check if a child with the next path byte exists + for i, max := 0, len(n.indices); i < max; i++ { + if c == n.indices[i] { + parentFullPathIndex += len(n.path) + i = n.incrementChildPrio(i) + n = n.children[i] + continue walk + } + } + + // Otherwise insert it + if c != ':' && c != '*' && n.nType != catchAll { + // []byte for proper unicode char conversion, see #65 + n.indices += bytesconv.BytesToString([]byte{c}) + child := &node{ + fullPath: fullPath, + } + n.addChild(child) + n.incrementChildPrio(len(n.indices) - 1) + n = child + } else if n.wildChild { + // inserting a wildcard node, need to check if it conflicts with the existing wildcard + n = n.children[len(n.children)-1] + n.priority++ + + // Check if the wildcard matches + if len(path) >= len(n.path) && n.path == path[:len(n.path)] && + // Adding a child to a catchAll is not possible + n.nType != catchAll && + // Check for longer wildcard, e.g. :name and :names + (len(n.path) >= len(path) || path[len(n.path)] == '/') { + continue walk + } + + // Wildcard conflict + pathSeg := path + if n.nType != catchAll { + pathSeg = strings.SplitN(pathSeg, "/", 2)[0] + } + prefix := fullPath[:strings.Index(fullPath, pathSeg)] + n.path + panic("'" + pathSeg + + "' in new path '" + fullPath + + "' conflicts with existing wildcard '" + n.path + + "' in existing prefix '" + prefix + + "'") + } + + n.insertChild(path, fullPath, handlers...) + return + } + + // Otherwise add handle to current node + if n.handlers != nil { + panic("handlers are already registered for path '" + fullPath + "'") + } + n.handlers = handlers + n.fullPath = fullPath + return + } +} + +// Search for a wildcard segment and check the name for invalid characters. +// Returns -1 as index, if no wildcard was found. +func findWildcard(path string) (wildcard string, i int, valid bool) { + // Find start + for start, c := range []byte(path) { + // A wildcard starts with ':' (param) or '*' (catch-all) + if c != ':' && c != '*' { + continue + } + + // Find end and check for invalid characters + valid = true + for end, c := range []byte(path[start+1:]) { + switch c { + case '/': + return path[start : start+1+end], start, valid + case ':', '*': + valid = false + } + } + return path[start:], start, valid + } + return "", -1, false +} + +func (n *node) insertChild(path string, fullPath string, handlers ...HandlerFunc) { + for { + // Find prefix until first wildcard + wildcard, i, valid := findWildcard(path) + if i < 0 { // No wildcard found + break + } + + // The wildcard name must only contain one ':' or '*' character + if !valid { + panic("only one wildcard per path segment is allowed, has: '" + + wildcard + "' in path '" + fullPath + "'") + } + + // check if the wildcard has a name + if len(wildcard) < 2 { + panic("wildcards must be named with a non-empty name in path '" + fullPath + "'") + } + + if wildcard[0] == ':' { // param + if i > 0 { + // Insert prefix before the current wildcard + n.path = path[:i] + path = path[i:] + } + + child := &node{ + nType: param, + path: wildcard, + fullPath: fullPath, + } + n.addChild(child) + n.wildChild = true + n = child + n.priority++ + + // if the path doesn't end with the wildcard, then there + // will be another subpath starting with '/' + if len(wildcard) < len(path) { + path = path[len(wildcard):] + + child := &node{ + priority: 1, + fullPath: fullPath, + } + n.addChild(child) + n = child + continue + } + + // Otherwise we're done. Insert the handle in the new leaf + n.handlers = handlers + return + } + + // catchAll + if i+len(wildcard) != len(path) { + panic("catch-all routes are only allowed at the end of the path in path '" + fullPath + "'") + } + + if len(n.path) > 0 && n.path[len(n.path)-1] == '/' { + pathSeg := "" + if len(n.children) != 0 { + pathSeg = strings.SplitN(n.children[0].path, "/", 2)[0] + } + panic("catch-all wildcard '" + path + + "' in new path '" + fullPath + + "' conflicts with existing path segment '" + pathSeg + + "' in existing prefix '" + n.path + pathSeg + + "'") + } + + // currently fixed width 1 for '/' + i-- + if path[i] != '/' { + panic("no / before catch-all in path '" + fullPath + "'") + } + + n.path = path[:i] + + // First node: catchAll node with empty path + child := &node{ + wildChild: true, + nType: catchAll, + fullPath: fullPath, + } + + n.addChild(child) + n.indices = string('/') + n = child + n.priority++ + + // second node: node holding the variable + child = &node{ + path: path[i:], + nType: catchAll, + handlers: handlers, + priority: 1, + fullPath: fullPath, + } + n.children = []*node{child} + + return + } + + // If no wildcard was found, simply insert the path and handle + n.path = path + n.handlers = handlers + n.fullPath = fullPath +} + +// nodeValue holds return values of (*Node).getValue method +type nodeValue struct { + handlers []HandlerFunc + params *Params + tsr bool + fullPath string +} + +type skippedNode struct { + path string + node *node + paramsCount int16 +} + +// Returns the handle registered with the given path (key). The values of +// wildcards are saved to a map. +// If no handle can be found, a TSR (trailing slash redirect) recommendation is +// made if a handle exists with an extra (without the) trailing slash for the +// given path. +func (n *node) getValue(path string, params *Params, skippedNodes *[]skippedNode, unescape bool) (value nodeValue) { + var globalParamsCount int16 + +walk: // Outer loop for walking the tree + for { + prefix := n.path + if len(path) > len(prefix) { + if path[:len(prefix)] == prefix { + path = path[len(prefix):] + + // Try all the non-wildcard children first by matching the indices + idxc := path[0] + for i, c := range []byte(n.indices) { + if c == idxc { + // strings.HasPrefix(n.children[len(n.children)-1].path, ":") == n.wildChild + if n.wildChild { + index := len(*skippedNodes) + *skippedNodes = (*skippedNodes)[:index+1] + (*skippedNodes)[index] = skippedNode{ + path: prefix + path, + node: &node{ + path: n.path, + wildChild: n.wildChild, + nType: n.nType, + priority: n.priority, + children: n.children, + handlers: n.handlers, + fullPath: n.fullPath, + }, + paramsCount: globalParamsCount, + } + } + + n = n.children[i] + continue walk + } + } + + if !n.wildChild { + // If the path at the end of the loop is not equal to '/' and the current node has no child nodes + // the current node needs to roll back to last valid skippedNode + if path != "/" { + for length := len(*skippedNodes); length > 0; length-- { + skippedNode := (*skippedNodes)[length-1] + *skippedNodes = (*skippedNodes)[:length-1] + if strings.HasSuffix(skippedNode.path, path) { + path = skippedNode.path + n = skippedNode.node + if value.params != nil { + *value.params = (*value.params)[:skippedNode.paramsCount] + } + globalParamsCount = skippedNode.paramsCount + continue walk + } + } + } + + // Nothing found. + // We can recommend to redirect to the same URL without a + // trailing slash if a leaf exists for that path. + value.tsr = path == "/" && n.handlers != nil + return value + } + + // Handle wildcard child, which is always at the end of the array + n = n.children[len(n.children)-1] + globalParamsCount++ + + switch n.nType { + case param: + // fix truncate the parameter + // tree_test.go line: 204 + + // Find param end (either '/' or path end) + end := 0 + for end < len(path) && path[end] != '/' { + end++ + } + + // Save param value + if params != nil { + // Preallocate capacity if necessary + if cap(*params) < int(globalParamsCount) { + newParams := make(Params, len(*params), globalParamsCount) + copy(newParams, *params) + *params = newParams + } + + if value.params == nil { + value.params = params + } + // Expand slice within preallocated capacity + i := len(*value.params) + *value.params = (*value.params)[:i+1] + val := path[:end] + if unescape { + if v, err := url.QueryUnescape(val); err == nil { + val = v + } + } + (*value.params)[i] = Param{ + Key: n.path[1:], + Value: val, + } + } + + // we need to go deeper! + if end < len(path) { + if len(n.children) > 0 { + path = path[end:] + n = n.children[0] + continue walk + } + + // ... but we can't + value.tsr = len(path) == end+1 + return value + } + + if value.handlers = n.handlers; value.handlers != nil { + value.fullPath = n.fullPath + return value + } + if len(n.children) == 1 { + // No handle found. Check if a handle for this path + a + // trailing slash exists for TSR recommendation + n = n.children[0] + value.tsr = (n.path == "/" && n.handlers != nil) || (n.path == "" && n.indices == "/") + } + return value + + case catchAll: + // Save param value + if params != nil { + // Preallocate capacity if necessary + if cap(*params) < int(globalParamsCount) { + newParams := make(Params, len(*params), globalParamsCount) + copy(newParams, *params) + *params = newParams + } + + if value.params == nil { + value.params = params + } + // Expand slice within preallocated capacity + i := len(*value.params) + *value.params = (*value.params)[:i+1] + val := path + if unescape { + if v, err := url.QueryUnescape(path); err == nil { + val = v + } + } + (*value.params)[i] = Param{ + Key: n.path[2:], + Value: val, + } + } + + value.handlers = n.handlers + value.fullPath = n.fullPath + return value + + default: + panic("invalid node type") + } + } + } + + if path == prefix { + // If the current path does not equal '/' and the node does not have a registered handle and the most recently matched node has a child node + // the current node needs to roll back to last valid skippedNode + if n.handlers == nil && path != "/" { + for length := len(*skippedNodes); length > 0; length-- { + skippedNode := (*skippedNodes)[length-1] + *skippedNodes = (*skippedNodes)[:length-1] + if strings.HasSuffix(skippedNode.path, path) { + path = skippedNode.path + n = skippedNode.node + if value.params != nil { + *value.params = (*value.params)[:skippedNode.paramsCount] + } + globalParamsCount = skippedNode.paramsCount + continue walk + } + } + // n = latestNode.children[len(latestNode.children)-1] + } + // We should have reached the node containing the handle. + // Check if this node has a handle registered. + if value.handlers = n.handlers; value.handlers != nil { + value.fullPath = n.fullPath + return value + } + + // If there is no handle for this route, but this route has a + // wildcard child, there must be a handle for this path with an + // additional trailing slash + if path == "/" && n.wildChild && n.nType != root { + value.tsr = true + return value + } + + if path == "/" && n.nType == static { + value.tsr = true + return value + } + + // No handle found. Check if a handle for this path + a + // trailing slash exists for trailing slash recommendation + for i, c := range []byte(n.indices) { + if c == '/' { + n = n.children[i] + value.tsr = (len(n.path) == 1 && n.handlers != nil) || + (n.nType == catchAll && n.children[0].handlers != nil) + return value + } + } + + return value + } + + // Nothing found. We can recommend to redirect to the same URL with an + // extra trailing slash if a leaf exists for that path + value.tsr = path == "/" || + (len(prefix) == len(path)+1 && prefix[len(path)] == '/' && + path == prefix[:len(prefix)-1] && n.handlers != nil) + + // roll back to last valid skippedNode + if !value.tsr && path != "/" { + for length := len(*skippedNodes); length > 0; length-- { + skippedNode := (*skippedNodes)[length-1] + *skippedNodes = (*skippedNodes)[:length-1] + if strings.HasSuffix(skippedNode.path, path) { + path = skippedNode.path + n = skippedNode.node + if value.params != nil { + *value.params = (*value.params)[:skippedNode.paramsCount] + } + globalParamsCount = skippedNode.paramsCount + continue walk + } + } + } + + return value + } +} + +// Makes a case-insensitive lookup of the given path and tries to find a handler. +// It can optionally also fix trailing slashes. +// It returns the case-corrected path and a bool indicating whether the lookup +// was successful. +func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) ([]byte, bool) { + const stackBufSize = 128 + + // Use a static sized buffer on the stack in the common case. + // If the path is too long, allocate a buffer on the heap instead. + buf := make([]byte, 0, stackBufSize) + if length := len(path) + 1; length > stackBufSize { + buf = make([]byte, 0, length) + } + + ciPath := n.findCaseInsensitivePathRec( + path, + buf, // Preallocate enough memory for new path + [4]byte{}, // Empty rune buffer + fixTrailingSlash, + ) + + return ciPath, ciPath != nil +} + +// Shift bytes in array by n bytes left +func shiftNRuneBytes(rb [4]byte, n int) [4]byte { + switch n { + case 0: + return rb + case 1: + return [4]byte{rb[1], rb[2], rb[3], 0} + case 2: + return [4]byte{rb[2], rb[3]} + case 3: + return [4]byte{rb[3]} + default: + return [4]byte{} + } +} + +// Recursive case-insensitive lookup function used by n.findCaseInsensitivePath +func (n *node) findCaseInsensitivePathRec(path string, ciPath []byte, rb [4]byte, fixTrailingSlash bool) []byte { + npLen := len(n.path) + +walk: // Outer loop for walking the tree + for len(path) >= npLen && (npLen == 0 || strings.EqualFold(path[1:npLen], n.path[1:])) { + // Add common prefix to result + oldPath := path + path = path[npLen:] + ciPath = append(ciPath, n.path...) + + if len(path) == 0 { + // We should have reached the node containing the handle. + // Check if this node has a handle registered. + if n.handlers != nil { + return ciPath + } + + // No handle found. + // Try to fix the path by adding a trailing slash + if fixTrailingSlash { + for i, c := range []byte(n.indices) { + if c == '/' { + n = n.children[i] + if (len(n.path) == 1 && n.handlers != nil) || + (n.nType == catchAll && n.children[0].handlers != nil) { + return append(ciPath, '/') + } + return nil + } + } + } + return nil + } + + // If this node does not have a wildcard (param or catchAll) child, + // we can just look up the next child node and continue to walk down + // the tree + if !n.wildChild { + // Skip rune bytes already processed + rb = shiftNRuneBytes(rb, npLen) + + if rb[0] != 0 { + // Old rune not finished + idxc := rb[0] + for i, c := range []byte(n.indices) { + if c == idxc { + // continue with child node + n = n.children[i] + npLen = len(n.path) + continue walk + } + } + } else { + // Process a new rune + var rv rune + + // Find rune start. + // Runes are up to 4 byte long, + // -4 would definitely be another rune. + var off int + for max := min(npLen, 3); off < max; off++ { + if i := npLen - off; utf8.RuneStart(oldPath[i]) { + // read rune from cached path + rv, _ = utf8.DecodeRuneInString(oldPath[i:]) + break + } + } + + // Calculate lowercase bytes of current rune + lo := unicode.ToLower(rv) + utf8.EncodeRune(rb[:], lo) + + // Skip already processed bytes + rb = shiftNRuneBytes(rb, off) + + idxc := rb[0] + for i, c := range []byte(n.indices) { + // Lowercase matches + if c == idxc { + // must use a recursive approach since both the + // uppercase byte and the lowercase byte might exist + // as an index + if out := n.children[i].findCaseInsensitivePathRec( + path, ciPath, rb, fixTrailingSlash, + ); out != nil { + return out + } + break + } + } + + // If we found no match, the same for the uppercase rune, + // if it differs + if up := unicode.ToUpper(rv); up != lo { + utf8.EncodeRune(rb[:], up) + rb = shiftNRuneBytes(rb, off) + + idxc := rb[0] + for i, c := range []byte(n.indices) { + // Uppercase matches + if c == idxc { + // Continue with child node + n = n.children[i] + npLen = len(n.path) + continue walk + } + } + } + } + + // Nothing found. We can recommend to redirect to the same URL + // without a trailing slash if a leaf exists for that path + if fixTrailingSlash && path == "/" && n.handlers != nil { + return ciPath + } + return nil + } + + n = n.children[0] + switch n.nType { + case param: + // Find param end (either '/' or path end) + end := 0 + for end < len(path) && path[end] != '/' { + end++ + } + + // Add param value to case insensitive path + ciPath = append(ciPath, path[:end]...) + + // We need to go deeper! + if end < len(path) { + if len(n.children) > 0 { + // Continue with child node + n = n.children[0] + npLen = len(n.path) + path = path[end:] + continue + } + + // ... but we can't + if fixTrailingSlash && len(path) == end+1 { + return ciPath + } + return nil + } + + if n.handlers != nil { + return ciPath + } + + if fixTrailingSlash && len(n.children) == 1 { + // No handle found. Check if a handle for this path + a + // trailing slash exists + n = n.children[0] + if n.path == "/" && n.handlers != nil { + return append(ciPath, '/') + } + } + + return nil + + case catchAll: + return append(ciPath, path...) + + default: + panic("invalid node type") + } + } + + // Nothing found. + // Try to fix the path by adding / removing a trailing slash + if fixTrailingSlash { + if path == "/" { + return ciPath + } + if len(path)+1 == npLen && n.path[len(path)] == '/' && + strings.EqualFold(path[1:], n.path[1:len(path)]) && n.handlers != nil { + return append(ciPath, n.path...) + } + } + return nil +} diff --git a/pkg/api/util.go b/pkg/api/util.go new file mode 100644 index 0000000..7718043 --- /dev/null +++ b/pkg/api/util.go @@ -0,0 +1,226 @@ +package api + +import ( + "fmt" + "strings" + + "github.com/loveuer/upp/internal/schema" +) + +const ( + MIMETextXML = "text/xml" + MIMETextHTML = "text/html" + MIMETextPlain = "text/plain" + MIMETextJavaScript = "text/javascript" + MIMEApplicationXML = "application/xml" + MIMEApplicationJSON = "application/json" + MIMEApplicationForm = "application/x-www-form-urlencoded" + MIMEOctetStream = "application/octet-stream" + MIMEMultipartForm = "multipart/form-data" + + MIMETextXMLCharsetUTF8 = "text/xml; charset=utf-8" + MIMETextHTMLCharsetUTF8 = "text/html; charset=utf-8" + MIMETextPlainCharsetUTF8 = "text/plain; charset=utf-8" + MIMETextJavaScriptCharsetUTF8 = "text/javascript; charset=utf-8" + MIMEApplicationXMLCharsetUTF8 = "application/xml; charset=utf-8" + MIMEApplicationJSONCharsetUTF8 = "application/json; charset=utf-8" + // Deprecated: use MIMETextJavaScriptCharsetUTF8 instead + MIMEApplicationJavaScriptCharsetUTF8 = "application/javascript; charset=utf-8" +) + +// parseVendorSpecificContentType check if content type is vendor specific and +// if it is parsable to any known types. If it's not vendor specific then returns +// the original content type. +func parseVendorSpecificContentType(cType string) string { + plusIndex := strings.Index(cType, "+") + + if plusIndex == -1 { + return cType + } + + var parsableType string + if semiColonIndex := strings.Index(cType, ";"); semiColonIndex == -1 { + parsableType = cType[plusIndex+1:] + } else if plusIndex < semiColonIndex { + parsableType = cType[plusIndex+1 : semiColonIndex] + } else { + return cType[:semiColonIndex] + } + + slashIndex := strings.Index(cType, "/") + + if slashIndex == -1 { + return cType + } + + return cType[0:slashIndex+1] + parsableType +} + +func parseToStruct(aliasTag string, out interface{}, data map[string][]string) error { + schemaDecoder := schema.NewDecoder() + schemaDecoder.SetAliasTag(aliasTag) + + if err := schemaDecoder.Decode(out, data); err != nil { + return fmt.Errorf("failed to decode: %w", err) + } + + return nil +} + +func elsePanic(guard bool, text string) { + if !guard { + panic(text) + } +} + +func cleanPath(p string) string { + const stackBufSize = 128 + // Turn empty string into "/" + if p == "" { + return "/" + } + + // Reasonably sized buffer on stack to avoid allocations in the common case. + // If a larger buffer is required, it gets allocated dynamically. + buf := make([]byte, 0, stackBufSize) + + n := len(p) + + // Invariants: + // reading from path; r is index of next byte to process. + // writing to buf; w is index of next byte to write. + + // path must start with '/' + r := 1 + w := 1 + + if p[0] != '/' { + r = 0 + + if n+1 > stackBufSize { + buf = make([]byte, n+1) + } else { + buf = buf[:n+1] + } + buf[0] = '/' + } + + trailing := n > 1 && p[n-1] == '/' + + // A bit more clunky without a 'lazybuf' like the path package, but the loop + // gets completely inlined (bufApp calls). + // loop has no expensive function calls (except 1x make) // So in contrast to the path package this loop has no expensive function + // calls (except make, if needed). + + for r < n { + switch { + case p[r] == '/': + // empty path element, trailing slash is added after the end + r++ + + case p[r] == '.' && r+1 == n: + trailing = true + r++ + + case p[r] == '.' && p[r+1] == '/': + // . element + r += 2 + + case p[r] == '.' && p[r+1] == '.' && (r+2 == n || p[r+2] == '/'): + // .. element: remove to last / + r += 3 + + if w > 1 { + // can backtrack + w-- + + if len(buf) == 0 { + for w > 1 && p[w] != '/' { + w-- + } + } else { + for w > 1 && buf[w] != '/' { + w-- + } + } + } + + default: + // Real path element. + // Add slash if needed + if w > 1 { + bufApp(&buf, p, w, '/') + w++ + } + + // Copy element + for r < n && p[r] != '/' { + bufApp(&buf, p, w, p[r]) + w++ + r++ + } + } + } + + // Re-append trailing slash + if trailing && w > 1 { + bufApp(&buf, p, w, '/') + w++ + } + + // If the original string was not modified (or only shortened at the end), + // return the respective substring of the original string. + // Otherwise return a new string from the buffer. + if len(buf) == 0 { + return p[:w] + } + return string(buf[:w]) +} + +// Internal helper to lazily create a buffer if necessary. +// Calls to this function get inlined. +func bufApp(buf *[]byte, s string, w int, c byte) { + b := *buf + if len(b) == 0 { + // No modification of the original string so far. + // If the next character is the same as in the original string, we do + // not yet have to allocate a buffer. + if s[w] == c { + return + } + + // Otherwise use either the stack buffer, if it is large enough, or + // allocate a new buffer on the heap, and copy all previous characters. + length := len(s) + if length > cap(b) { + *buf = make([]byte, length) + } else { + *buf = (*buf)[:length] + } + b = *buf + + copy(b, s[:w]) + } + b[w] = c +} + +func HumanDuration(nano int64) string { + duration := float64(nano) + unit := "ns" + if duration >= 1000 { + duration /= 1000 + unit = "us" + } + + if duration >= 1000 { + duration /= 1000 + unit = "ms" + } + + if duration >= 1000 { + duration /= 1000 + unit = " s" + } + + return fmt.Sprintf("%6.2f%s", duration, unit) +} diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go new file mode 100644 index 0000000..f6d90b4 --- /dev/null +++ b/pkg/cache/cache.go @@ -0,0 +1,60 @@ +package cache + +import ( + "context" + "encoding/json" + "errors" + "time" +) + +type Cache interface { + Get(ctx context.Context, key string) ([]byte, error) + Gets(ctx context.Context, keys ...string) ([][]byte, error) + GetScan(ctx context.Context, key string) Scanner + GetEx(ctx context.Context, key string, duration time.Duration) ([]byte, error) + GetExScan(ctx context.Context, key string, duration time.Duration) Scanner + // Set value 会被序列化, 优先使用 MarshalBinary 方法, 没有则执行 json.Marshal + Set(ctx context.Context, key string, value any) error + Sets(ctx context.Context, vm map[string]any) error + // SetEx value 会被序列化, 优先使用 MarshalBinary 方法, 没有则执行 json.Marshal + SetEx(ctx context.Context, key string, value any, duration time.Duration) error + Del(ctx context.Context, keys ...string) error +} + +type Scanner interface { + Scan(model any) error +} + +type encoded_value interface { + MarshalBinary() ([]byte, error) +} + +type decoded_value interface { + UnmarshalBinary(bs []byte) error +} + +const ( + Prefix = "upp:" +) + +var ErrorKeyNotFound = errors.New("key not found") + +func handleValue(value any) ([]byte, error) { + var ( + bs []byte + err error + ) + + switch value.(type) { + case []byte: + return value.([]byte), nil + } + + if imp, ok := value.(encoded_value); ok { + bs, err = imp.MarshalBinary() + } else { + bs, err = json.Marshal(value) + } + + return bs, err +} diff --git a/pkg/cache/cache.lru.go b/pkg/cache/cache.lru.go new file mode 100644 index 0000000..72d177b --- /dev/null +++ b/pkg/cache/cache.lru.go @@ -0,0 +1,141 @@ +package cache + +import ( + "context" + "time" + + "github.com/hashicorp/golang-lru/v2/expirable" + _ "github.com/hashicorp/golang-lru/v2/expirable" +) + +var _ Cache = (*_lru)(nil) + +type _lru struct { + client *expirable.LRU[string, *_lru_value] +} + +type _lru_value struct { + duration time.Duration + last time.Time + bs []byte +} + +func (l *_lru) Get(ctx context.Context, key string) ([]byte, error) { + v, ok := l.client.Get(key) + if !ok { + return nil, ErrorKeyNotFound + } + + if v.duration == 0 { + return v.bs, nil + } + + if time.Now().Sub(v.last) > v.duration { + l.client.Remove(key) + return nil, ErrorKeyNotFound + } + + return v.bs, nil +} + +func (l *_lru) Gets(ctx context.Context, keys ...string) ([][]byte, error) { + bss := make([][]byte, 0, len(keys)) + for _, key := range keys { + bs, err := l.Get(ctx, key) + if err != nil { + return nil, err + } + + bss = append(bss, bs) + } + + return bss, nil +} + +func (l *_lru) GetScan(ctx context.Context, key string) Scanner { + return newScanner(l.Get(ctx, key)) +} + +func (l *_lru) GetEx(ctx context.Context, key string, duration time.Duration) ([]byte, error) { + v, ok := l.client.Get(key) + if !ok { + return nil, ErrorKeyNotFound + } + + if v.duration == 0 { + return v.bs, nil + } + + now := time.Now() + + if now.Sub(v.last) > v.duration { + l.client.Remove(key) + return nil, ErrorKeyNotFound + } + + l.client.Add(key, &_lru_value{ + duration: duration, + last: now, + bs: v.bs, + }) + + return v.bs, nil +} + +func (l *_lru) GetExScan(ctx context.Context, key string, duration time.Duration) Scanner { + return newScanner(l.GetEx(ctx, key, duration)) +} + +func (l *_lru) Set(ctx context.Context, key string, value any) error { + bs, err := handleValue(value) + if err != nil { + return err + } + + l.client.Add(key, &_lru_value{ + duration: 0, + last: time.Now(), + bs: bs, + }) + + return nil +} + +func (l *_lru) SetEx(ctx context.Context, key string, value any, duration time.Duration) error { + bs, err := handleValue(value) + if err != nil { + return err + } + + l.client.Add(key, &_lru_value{ + duration: duration, + last: time.Now(), + bs: bs, + }) + + return nil +} + +func (l *_lru) Sets(ctx context.Context, m map[string]any) error { + for k, v := range m { + if err := l.Set(ctx, k, v); err != nil { + return err + } + } + + return nil +} + +func (l *_lru) Del(ctx context.Context, keys ...string) error { + for _, key := range keys { + l.client.Remove(key) + } + + return nil +} + +func newLRUCache() (Cache, error) { + client := expirable.NewLRU[string, *_lru_value](1024*1024, nil, 0) + + return &_lru{client: client}, nil +} diff --git a/pkg/cache/cache.memory.go b/pkg/cache/cache.memory.go new file mode 100644 index 0000000..f759a1b --- /dev/null +++ b/pkg/cache/cache.memory.go @@ -0,0 +1,105 @@ +package cache + +import ( + "context" + "errors" + "fmt" + "time" + + "gitea.com/loveuer/gredis" +) + +var _ Cache = (*_mem)(nil) + +type _mem struct { + client *gredis.Gredis +} + +func (m *_mem) GetScan(ctx context.Context, key string) Scanner { + return newScanner(m.Get(ctx, key)) +} + +func (m *_mem) GetExScan(ctx context.Context, key string, duration time.Duration) Scanner { + return newScanner(m.GetEx(ctx, key, duration)) +} + +func (m *_mem) Get(ctx context.Context, key string) ([]byte, error) { + v, err := m.client.Get(key) + if err != nil { + if errors.Is(err, gredis.ErrKeyNotFound) { + return nil, ErrorKeyNotFound + } + + return nil, err + } + + bs, ok := v.([]byte) + if !ok { + return nil, fmt.Errorf("invalid value type=%T", v) + } + + return bs, nil +} + +func (m *_mem) Gets(ctx context.Context, keys ...string) ([][]byte, error) { + bss := make([][]byte, 0, len(keys)) + for _, key := range keys { + bs, err := m.Get(ctx, key) + if err != nil { + return nil, err + } + + bss = append(bss, bs) + } + + return bss, nil +} + +func (m *_mem) GetEx(ctx context.Context, key string, duration time.Duration) ([]byte, error) { + v, err := m.client.GetEx(key, duration) + if err != nil { + if errors.Is(err, gredis.ErrKeyNotFound) { + return nil, ErrorKeyNotFound + } + + return nil, err + } + + bs, ok := v.([]byte) + if !ok { + return nil, fmt.Errorf("invalid value type=%T", v) + } + + return bs, nil +} + +func (m *_mem) Set(ctx context.Context, key string, value any) error { + bs, err := handleValue(value) + if err != nil { + return err + } + return m.client.Set(key, bs) +} + +func (m *_mem) Sets(ctx context.Context, vm map[string]any) error { + for k, v := range vm { + if err := m.Set(ctx, k, v); err != nil { + return err + } + } + + return nil +} + +func (m *_mem) SetEx(ctx context.Context, key string, value any, duration time.Duration) error { + bs, err := handleValue(value) + if err != nil { + return err + } + return m.client.SetEx(key, bs, duration) +} + +func (m *_mem) Del(ctx context.Context, keys ...string) error { + m.client.Delete(keys...) + return nil +} diff --git a/pkg/cache/cache.redis.go b/pkg/cache/cache.redis.go new file mode 100644 index 0000000..6d6d207 --- /dev/null +++ b/pkg/cache/cache.redis.go @@ -0,0 +1,106 @@ +package cache + +import ( + "context" + "errors" + "time" + + "github.com/go-redis/redis/v8" + "github.com/samber/lo" + "github.com/spf13/cast" +) + +type _redis struct { + client *redis.Client +} + +func (r *_redis) Get(ctx context.Context, key string) ([]byte, error) { + result, err := r.client.Get(ctx, key).Result() + if err != nil { + if errors.Is(err, redis.Nil) { + return nil, ErrorKeyNotFound + } + + return nil, err + } + + return []byte(result), nil +} + +func (r *_redis) Gets(ctx context.Context, keys ...string) ([][]byte, error) { + result, err := r.client.MGet(ctx, keys...).Result() + if err != nil { + if errors.Is(err, redis.Nil) { + return nil, ErrorKeyNotFound + } + + return nil, err + } + + return lo.Map( + result, + func(item any, index int) []byte { + return []byte(cast.ToString(item)) + }, + ), nil +} + +func (r *_redis) GetScan(ctx context.Context, key string) Scanner { + return newScanner(r.Get(ctx, key)) +} + +func (r *_redis) GetEx(ctx context.Context, key string, duration time.Duration) ([]byte, error) { + result, err := r.client.GetEx(ctx, key, duration).Result() + if err != nil { + if errors.Is(err, redis.Nil) { + return nil, ErrorKeyNotFound + } + + return nil, err + } + + return []byte(result), nil +} + +func (r *_redis) GetExScan(ctx context.Context, key string, duration time.Duration) Scanner { + return newScanner(r.GetEx(ctx, key, duration)) +} + +func (r *_redis) Set(ctx context.Context, key string, value any) error { + bs, err := handleValue(value) + if err != nil { + return err + } + + _, err = r.client.Set(ctx, key, bs, redis.KeepTTL).Result() + return err +} + +func (r *_redis) Sets(ctx context.Context, values map[string]any) error { + vm := make(map[string]any) + for k, v := range values { + bs, err := handleValue(v) + if err != nil { + return err + } + + vm[k] = bs + } + + return r.client.MSet(ctx, vm).Err() +} + +func (r *_redis) SetEx(ctx context.Context, key string, value any, duration time.Duration) error { + bs, err := handleValue(value) + if err != nil { + return err + } + + _, err = r.client.SetEX(ctx, key, bs, duration).Result() + + return err +} + +func (r *_redis) Del(ctx context.Context, keys ...string) error { + return r.client.Del(ctx, keys...).Err() +} diff --git a/pkg/cache/new.go b/pkg/cache/new.go new file mode 100644 index 0000000..f402d8c --- /dev/null +++ b/pkg/cache/new.go @@ -0,0 +1,70 @@ +package cache + +import ( + "fmt" + "net/url" + "strings" + + "gitea.com/loveuer/gredis" + "github.com/go-redis/redis/v8" + "github.com/loveuer/upp/pkg/tool" +) + +func New(uri string) (Cache, error) { + var ( + client Cache + err error + ) + + strs := strings.Split(uri, "::") + + switch strs[0] { + case "memory": + gc := gredis.NewGredis(1024 * 1024) + client = &_mem{client: gc} + case "lru": + if client, err = newLRUCache(); err != nil { + return nil, err + } + case "redis": + var ( + ins *url.URL + err error + ) + + if len(strs) != 2 { + return nil, fmt.Errorf("cache.Init: invalid cache uri: %s", uri) + } + + uri := strs[1] + + if !strings.Contains(uri, "://") { + uri = fmt.Sprintf("redis://%s", uri) + } + + if ins, err = url.Parse(uri); err != nil { + return nil, fmt.Errorf("cache.Init: url parse cache uri: %s, err: %s", uri, err.Error()) + } + + addr := ins.Host + username := ins.User.Username() + password, _ := ins.User.Password() + + var rc *redis.Client + rc = redis.NewClient(&redis.Options{ + Addr: addr, + Username: username, + Password: password, + }) + + if err = rc.Ping(tool.Timeout(5)).Err(); err != nil { + return nil, fmt.Errorf("cache.Init: redis ping err: %s", err.Error()) + } + + client = &_redis{client: rc} + default: + return nil, fmt.Errorf("cache type %s not support", strs[0]) + } + + return client, nil +} diff --git a/pkg/cache/scan.go b/pkg/cache/scan.go new file mode 100644 index 0000000..c65d267 --- /dev/null +++ b/pkg/cache/scan.go @@ -0,0 +1,20 @@ +package cache + +import "encoding/json" + +type scanner struct { + err error + bs []byte +} + +func (s *scanner) Scan(model any) error { + if s.err != nil { + return s.err + } + + return json.Unmarshal(s.bs, model) +} + +func newScanner(bs []byte, err error) *scanner { + return &scanner{bs: bs, err: err} +} diff --git a/pkg/db/new.go b/pkg/db/new.go new file mode 100644 index 0000000..603611f --- /dev/null +++ b/pkg/db/new.go @@ -0,0 +1,58 @@ +package db + +import ( + "fmt" + "net/url" + "strings" + + "github.com/glebarez/sqlite" + "github.com/loveuer/upp/pkg/log" + "gorm.io/driver/mysql" + "gorm.io/driver/postgres" + "gorm.io/gorm" +) + +func New(uri string) (*gorm.DB, error) { + ins, err := url.Parse(uri) + if err != nil { + return nil, err + } + + var ( + username = "" + password = "" + tx *gorm.DB + ) + + if ins.User != nil { + username = ins.User.Username() + password, _ = ins.User.Password() + } + + switch ins.Scheme { + case "sqlite": + path := strings.TrimPrefix(uri, ins.Scheme+"://") + log.Debug("db.New: type = %s, path = %s", ins.Scheme, path) + tx, err = gorm.Open(sqlite.Open(path)) + case "mysql", "mariadb": + dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s?%s", username, password, ins.Host, ins.Path, ins.RawQuery) + log.Debug("db.New: type = %s, dsn = %s", ins.Scheme, dsn) + tx, err = gorm.Open(mysql.Open(dsn)) + case "pg", "postgres", "postgresql": + opts := make([]string, 0) + for key, val := range ins.Query() { + opts = append(opts, fmt.Sprintf("%s=%s", key, val)) + } + dsn := fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%s %s", ins.Hostname(), username, password, ins.Path, ins.Port(), strings.Join(opts, " ")) + log.Debug("db.New: type = %s, dsn = %s", ins.Scheme, dsn) + tx, err = gorm.Open(postgres.Open(dsn)) + default: + return nil, fmt.Errorf("invalid database type(uri_scheme): %s", ins.Scheme) + } + + if err != nil { + return nil, err + } + + return tx, nil +} diff --git a/pkg/interfaces/api.engine.go b/pkg/interfaces/api.engine.go new file mode 100644 index 0000000..1844c35 --- /dev/null +++ b/pkg/interfaces/api.engine.go @@ -0,0 +1,64 @@ +package interfaces + +import ( + "context" + "mime/multipart" + "net" +) + +type ApiGroup interface { + Group(path string, handlers ...ApiHandler) ApiGroup + GET(path string, handlers ...ApiHandler) + POST(path string, handlers ...ApiHandler) + PUT(path string, handlers ...ApiHandler) + DELETE(path string, handlers ...ApiHandler) + HEAD(path string, handlers ...ApiHandler) + PATCH(path string, handlers ...ApiHandler) + OPTIONS(path string, handlers ...ApiHandler) + Handle(method, path string, handlers ...ApiHandler) + Use(handlers ...ApiHandler) +} + +type ApiEngine interface { + ApiGroup + Run(address string) error + RunListener(ln net.Listener) error +} + +type ApiContext interface { + App() Upp + // parse body, form, json + BodyParser(out any) error + Context() context.Context + Cookie(string) string + FormFile(string) (*multipart.FileHeader, error) + FormValue(string) string + GetHeader(string) string + SetHeader(string, string) + IP() string + Json(any) error + Locals(interface{}, ...any) any + + // get method or rewrite method + Method(string, ...string) string + MultipartForm() (*multipart.Form, error) + Param(string) string + + // get path or rewrite path + Path(string, ...string) string + Next() error + Query(string) string + QueryParse(any) error + Redirect(string, ...int) error + + // render html response + Render(name, layout string, data any) error + + // set response status + Status(int) + SendStatus(int) + + Write([]byte) (int, error) +} + +type ApiHandler func(c ApiContext) error diff --git a/pkg/interfaces/logger.go b/pkg/interfaces/logger.go new file mode 100644 index 0000000..16c40bb --- /dev/null +++ b/pkg/interfaces/logger.go @@ -0,0 +1,10 @@ +package interfaces + +type Logger interface { + Debug(string, ...any) + Info(string, ...any) + Warn(string, ...any) + Error(string, ...any) + Panic(string, ...any) + Fatal(string, ...any) +} diff --git a/pkg/interfaces/upp.go b/pkg/interfaces/upp.go new file mode 100644 index 0000000..88ffe2b --- /dev/null +++ b/pkg/interfaces/upp.go @@ -0,0 +1,13 @@ +package interfaces + +import ( + "github.com/elastic/go-elasticsearch/v7" + "github.com/loveuer/upp/pkg/cache" + "gorm.io/gorm" +) + +type Upp interface { + UseDB() *gorm.DB + UseCache() cache.Cache + UseES() *elasticsearch.Client +} diff --git a/pkg/log/default.go b/pkg/log/default.go new file mode 100644 index 0000000..7ecf5e7 --- /dev/null +++ b/pkg/log/default.go @@ -0,0 +1,67 @@ +package log + +import ( + "fmt" + "os" + "sync" +) + +var ( + nilLogger = func(prefix, timestamp, msg string, data ...any) {} + normalLogger = func(prefix, timestamp, msg string, data ...any) { + fmt.Printf(prefix+"| "+timestamp+" | "+msg+"\n", data...) + } + + panicLogger = func(prefix, timestamp, msg string, data ...any) { + panic(fmt.Sprintf(prefix+"| "+timestamp+" | "+msg+"\n", data...)) + } + + fatalLogger = func(prefix, timestamp, msg string, data ...any) { + fmt.Printf(prefix+"| "+timestamp+" | "+msg+"\n", data...) + os.Exit(1) + } + + DefaultLogger = &logger{ + Mutex: sync.Mutex{}, + timeFormat: "2006-01-02T15:04:05", + writer: os.Stdout, + level: LogLevelInfo, + debug: nilLogger, + info: normalLogger, + warn: normalLogger, + error: normalLogger, + panic: panicLogger, + fatal: fatalLogger, + } +) + +func SetTimeFormat(format string) { + DefaultLogger.SetTimeFormat(format) +} + +func SetLogLevel(level LogLevel) { + DefaultLogger.SetLogLevel(level) +} + +func Debug(msg string, data ...any) { + DefaultLogger.Debug(msg, data...) +} +func Info(msg string, data ...any) { + DefaultLogger.Info(msg, data...) +} + +func Warn(msg string, data ...any) { + DefaultLogger.Warn(msg, data...) +} + +func Error(msg string, data ...any) { + DefaultLogger.Error(msg, data...) +} + +func Panic(msg string, data ...any) { + DefaultLogger.Panic(msg, data...) +} + +func Fatal(msg string, data ...any) { + DefaultLogger.Fatal(msg, data...) +} diff --git a/pkg/log/log.go b/pkg/log/log.go new file mode 100644 index 0000000..9e55695 --- /dev/null +++ b/pkg/log/log.go @@ -0,0 +1,115 @@ +package log + +import ( + "github.com/fatih/color" + "io" + "sync" + "time" +) + +type LogLevel uint32 + +const ( + LogLevelDebug = iota + LogLevelInfo + LogLevelWarn + LogLevelError + LogLevelPanic + LogLevelFatal +) + +type logger struct { + sync.Mutex + timeFormat string + writer io.Writer + level LogLevel + debug func(prefix, timestamp, msg string, data ...any) + info func(prefix, timestamp, msg string, data ...any) + warn func(prefix, timestamp, msg string, data ...any) + error func(prefix, timestamp, msg string, data ...any) + panic func(prefix, timestamp, msg string, data ...any) + fatal func(prefix, timestamp, msg string, data ...any) +} + +var ( + red = color.New(color.FgRed) + hired = color.New(color.FgHiRed) + green = color.New(color.FgGreen) + yellow = color.New(color.FgYellow) + white = color.New(color.FgWhite) +) + +func (l *logger) SetTimeFormat(format string) { + l.Lock() + defer l.Unlock() + l.timeFormat = format +} + +func (l *logger) SetLogLevel(level LogLevel) { + l.Lock() + defer l.Unlock() + + if level > LogLevelDebug { + l.debug = nilLogger + } else { + l.debug = normalLogger + } + + if level > LogLevelInfo { + l.info = nilLogger + } else { + l.info = normalLogger + } + + if level > LogLevelWarn { + l.warn = nilLogger + } else { + l.warn = normalLogger + } + + if level > LogLevelError { + l.error = nilLogger + } else { + l.error = normalLogger + } + + if level > LogLevelPanic { + l.panic = nilLogger + } else { + l.panic = panicLogger + } + + if level > LogLevelFatal { + l.fatal = nilLogger + } else { + l.fatal = fatalLogger + } +} + +func (l *logger) Debug(msg string, data ...any) { + l.debug(white.Sprint("Debug "), time.Now().Format(l.timeFormat), msg, data...) +} + +func (l *logger) Info(msg string, data ...any) { + l.info(green.Sprint("Info "), time.Now().Format(l.timeFormat), msg, data...) +} + +func (l *logger) Warn(msg string, data ...any) { + l.warn(yellow.Sprint("Warn "), time.Now().Format(l.timeFormat), msg, data...) +} + +func (l *logger) Error(msg string, data ...any) { + l.error(red.Sprint("Error "), time.Now().Format(l.timeFormat), msg, data...) +} + +func (l *logger) Panic(msg string, data ...any) { + l.panic(hired.Sprint("Panic "), time.Now().Format(l.timeFormat), msg, data...) +} + +func (l *logger) Fatal(msg string, data ...any) { + l.fatal(hired.Sprint("Fatal "), time.Now().Format(l.timeFormat), msg, data...) +} + +type WroteLogger interface { + Info(msg string, data ...any) +} diff --git a/pkg/log/new.go b/pkg/log/new.go new file mode 100644 index 0000000..204fac1 --- /dev/null +++ b/pkg/log/new.go @@ -0,0 +1,21 @@ +package log + +import ( + "os" + "sync" +) + +func New() *logger { + return &logger{ + Mutex: sync.Mutex{}, + timeFormat: "2006-01-02T15:04:05", + writer: os.Stdout, + level: LogLevelInfo, + debug: nilLogger, + info: normalLogger, + warn: normalLogger, + error: normalLogger, + panic: panicLogger, + fatal: fatalLogger, + } +} diff --git a/pkg/tool/ctx.go b/pkg/tool/ctx.go new file mode 100644 index 0000000..501b18f --- /dev/null +++ b/pkg/tool/ctx.go @@ -0,0 +1,38 @@ +package tool + +import ( + "context" + "time" +) + +func Timeout(seconds ...int) (ctx context.Context) { + var ( + duration time.Duration + ) + + if len(seconds) > 0 && seconds[0] > 0 { + duration = time.Duration(seconds[0]) * time.Second + } else { + duration = time.Duration(30) * time.Second + } + + ctx, _ = context.WithTimeout(context.Background(), duration) + + return +} + +func TimeoutCtx(ctx context.Context, seconds ...int) context.Context { + var ( + duration time.Duration + ) + + if len(seconds) > 0 && seconds[0] > 0 { + duration = time.Duration(seconds[0]) * time.Second + } else { + duration = time.Duration(30) * time.Second + } + + nctx, _ := context.WithTimeout(ctx, duration) + + return nctx +} diff --git a/pkg/tool/human.go b/pkg/tool/human.go new file mode 100644 index 0000000..af9a188 --- /dev/null +++ b/pkg/tool/human.go @@ -0,0 +1,24 @@ +package tool + +import "fmt" + +func HumanDuration(nano int64) string { + duration := float64(nano) + unit := "ns" + if duration >= 1000 { + duration /= 1000 + unit = "us" + } + + if duration >= 1000 { + duration /= 1000 + unit = "ms" + } + + if duration >= 1000 { + duration /= 1000 + unit = " s" + } + + return fmt.Sprintf("%6.2f%s", duration, unit) +} diff --git a/pkg/tool/loading/loading.go b/pkg/tool/loading/loading.go new file mode 100644 index 0000000..3c986e1 --- /dev/null +++ b/pkg/tool/loading/loading.go @@ -0,0 +1,123 @@ +package loading + +import ( + "context" + "fmt" + "time" +) + +type Type int + +const ( + TypeProcessing Type = iota + TypeInfo + TypeSuccess + TypeWarning + TypeError +) + +func (t Type) Symbol() string { + switch t { + case TypeSuccess: + return "✔️ " + case TypeWarning: + return "❗ " + case TypeError: + return "❌ " + case TypeInfo: + return "❕ " + default: + return "" + } +} + +type _msg struct { + msg string + t Type +} + +var frames = []string{"|", "/", "-", "\\"} + +func Do(ctx context.Context, fn func(ctx context.Context, print func(msg string, types ...Type)) error) (err error) { + start := time.Now() + ch := make(chan *_msg) + + defer func() { + fmt.Printf("\r\033[K") + }() + + go func() { + var ( + m *_msg + ok bool + processing string + ) + + for { + for _, frame := range frames { + select { + case <-ctx.Done(): + return + case m, ok = <-ch: + if !ok || m == nil { + return + } + + switch m.t { + case TypeProcessing: + if m.msg != "" { + processing = m.msg + } + case TypeInfo, + TypeSuccess, + TypeWarning, + TypeError: + // Clear the loading animation + fmt.Printf("\r\033[K") + fmt.Printf("%s%s\n", m.t.Symbol(), m.msg) + } + default: + elapsed := time.Since(start).Seconds() + if processing != "" { + fmt.Printf("\r\033[K%s %s (%.2fs)", frame, processing, elapsed) + } + time.Sleep(100 * time.Millisecond) + } + } + } + }() + + printFn := func(msg string, types ...Type) { + if msg == "" { + return + } + + m := &_msg{ + msg: msg, + t: TypeProcessing, + } + + if len(types) > 0 { + m.t = types[0] + } + + ch <- m + } + + done := make(chan struct{}) + go func() { + if err = fn(ctx, printFn); err != nil { + ch <- &_msg{msg: err.Error(), t: TypeError} + } + + close(ch) + done <- struct{}{} + }() + + select { + case <-ctx.Done(): + case <-done: + } + + return err +} diff --git a/pkg/tool/must.go b/pkg/tool/must.go new file mode 100644 index 0000000..4cea125 --- /dev/null +++ b/pkg/tool/must.go @@ -0,0 +1,11 @@ +package tool + +import "github.com/loveuer/upp/pkg/log" + +func Must(errs ...error) { + for _, err := range errs { + if err != nil { + log.Panic(err.Error()) + } + } +} diff --git a/pkg/tool/password.go b/pkg/tool/password.go new file mode 100644 index 0000000..33bd690 --- /dev/null +++ b/pkg/tool/password.go @@ -0,0 +1,85 @@ +package tool + +import ( + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "regexp" + "strconv" + "strings" + + "github.com/loveuer/upp/pkg/log" + "golang.org/x/crypto/pbkdf2" +) + +const ( + EncryptHeader string = "pbkdf2:sha256" // 用户密码加密 +) + +func NewPassword(password string) string { + return EncryptPassword(password, RandomString(8), int(RandomInt(50000)+100000)) +} + +func ComparePassword(in, db string) bool { + strs := strings.Split(db, "$") + if len(strs) != 3 { + log.Error("password in db invalid: %s", db) + return false + } + + encs := strings.Split(strs[0], ":") + if len(encs) != 3 { + log.Error("password in db invalid: %s", db) + return false + } + + encIteration, err := strconv.Atoi(encs[2]) + if err != nil { + log.Error("password in db invalid: %s, convert iter err: %s", db, err) + return false + } + + return EncryptPassword(in, strs[1], encIteration) == db +} + +func EncryptPassword(password, salt string, iter int) string { + hash := pbkdf2.Key([]byte(password), []byte(salt), iter, 32, sha256.New) + encrypted := hex.EncodeToString(hash) + return fmt.Sprintf("%s:%d$%s$%s", EncryptHeader, iter, salt, encrypted) +} + +func CheckPassword(password string) error { + if len(password) < 8 || len(password) > 32 { + return errors.New("密码长度不符合") + } + + var ( + err error + match bool + patternList = []string{`[0-9]+`, `[a-z]+`, `[A-Z]+`, `[!@#%]+`} //, `[~!@#$%^&*?_-]+`} + matchAccount = 0 + tips = []string{"缺少数字", "缺少小写字母", "缺少大写字母", "缺少'!@#%'"} + locktips = make([]string, 0) + ) + + for idx, pattern := range patternList { + match, err = regexp.MatchString(pattern, password) + if err != nil { + log.Warn("regex match string err, reg_str: %s, err: %v", pattern, err) + return errors.New("密码强度不够") + } + + if match { + matchAccount++ + } else { + locktips = append(locktips, tips[idx]) + } + } + + if matchAccount < 3 { + return fmt.Errorf("密码强度不够, 可能 %s", strings.Join(locktips, ", ")) + } + + return nil +} diff --git a/pkg/tool/password_test.go b/pkg/tool/password_test.go new file mode 100644 index 0000000..ac4ec28 --- /dev/null +++ b/pkg/tool/password_test.go @@ -0,0 +1,11 @@ +package tool + +import "testing" + +func TestEncPassword(t *testing.T) { + password := "123456" + + result := EncryptPassword(password, RandomString(8), 50000) + + t.Logf("sum => %s", result) +} diff --git a/pkg/tool/random.go b/pkg/tool/random.go new file mode 100644 index 0000000..b960527 --- /dev/null +++ b/pkg/tool/random.go @@ -0,0 +1,54 @@ +package tool + +import ( + "crypto/rand" + "math/big" +) + +var ( + letters = []byte("0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") + letterNum = []byte("0123456789") + letterLow = []byte("abcdefghijklmnopqrstuvwxyz") + letterCap = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ") + letterSyb = []byte("!@#$%^&*()_+-=") +) + +func RandomInt(max int64) int64 { + num, _ := rand.Int(rand.Reader, big.NewInt(max)) + return num.Int64() +} + +func RandomString(length int) string { + result := make([]byte, length) + for i := 0; i < length; i++ { + num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letters)))) + result[i] = letters[num.Int64()] + } + return string(result) +} + +func RandomPassword(length int, withSymbol bool) string { + result := make([]byte, length) + kind := 3 + if withSymbol { + kind++ + } + + for i := 0; i < length; i++ { + switch i % kind { + case 0: + num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letterNum)))) + result[i] = letterNum[num.Int64()] + case 1: + num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letterLow)))) + result[i] = letterLow[num.Int64()] + case 2: + num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letterCap)))) + result[i] = letterCap[num.Int64()] + case 3: + num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letterSyb)))) + result[i] = letterSyb[num.Int64()] + } + } + return string(result) +} diff --git a/pkg/tool/table.go b/pkg/tool/table.go new file mode 100644 index 0000000..10acf8b --- /dev/null +++ b/pkg/tool/table.go @@ -0,0 +1,125 @@ +package tool + +import ( + "encoding/json" + "fmt" + "io" + "os" + "reflect" + "strings" + + "github.com/jedib0t/go-pretty/v6/table" + "github.com/loveuer/upp/pkg/log" +) + +func TablePrinter(data any, writers ...io.Writer) { + var w io.Writer = os.Stdout + if len(writers) > 0 && writers[0] != nil { + w = writers[0] + } + + t := table.NewWriter() + structPrinter(t, "", data) + _, _ = fmt.Fprintln(w, t.Render()) +} + +func structPrinter(w table.Writer, prefix string, item any) { +Start: + rv := reflect.ValueOf(item) + if rv.IsZero() { + return + } + + for rv.Type().Kind() == reflect.Pointer { + rv = rv.Elem() + } + + switch rv.Type().Kind() { + case reflect.Invalid, + reflect.Uintptr, + reflect.Chan, + reflect.Func, + reflect.UnsafePointer: + case reflect.Bool, + reflect.Int, + reflect.Int8, + reflect.Int16, + reflect.Int32, + reflect.Int64, + reflect.Uint, + reflect.Uint8, + reflect.Uint16, + reflect.Uint32, + reflect.Uint64, + reflect.Float32, + reflect.Float64, + reflect.Complex64, + reflect.Complex128, + reflect.Interface: + w.AppendRow(table.Row{strings.TrimPrefix(prefix, "."), rv.Interface()}) + case reflect.String: + val := rv.String() + if len(val) <= 160 { + w.AppendRow(table.Row{strings.TrimPrefix(prefix, "."), val}) + return + } + + w.AppendRow(table.Row{strings.TrimPrefix(prefix, "."), val[0:64] + "..." + val[len(val)-64:]}) + case reflect.Array, reflect.Slice: + for i := 0; i < rv.Len(); i++ { + p := strings.Join([]string{prefix, fmt.Sprintf("[%d]", i)}, ".") + structPrinter(w, p, rv.Index(i).Interface()) + } + case reflect.Map: + for _, k := range rv.MapKeys() { + structPrinter(w, fmt.Sprintf("%s.{%v}", prefix, k), rv.MapIndex(k).Interface()) + } + case reflect.Pointer: + goto Start + case reflect.Struct: + for i := 0; i < rv.NumField(); i++ { + p := fmt.Sprintf("%s.%s", prefix, rv.Type().Field(i).Name) + field := rv.Field(i) + + // log.Debug("TablePrinter: prefix: %s, field: %v", p, rv.Field(i)) + + if !field.CanInterface() { + return + } + + structPrinter(w, p, field.Interface()) + } + } +} + +func TableMapPrinter(data []byte) { + m := make(map[string]any) + if err := json.Unmarshal(data, &m); err != nil { + log.Warn(err.Error()) + return + } + + t := table.NewWriter() + addRow(t, "", m) + fmt.Println(t.Render()) +} + +func addRow(w table.Writer, prefix string, m any) { + rv := reflect.ValueOf(m) + switch rv.Type().Kind() { + case reflect.Map: + for _, k := range rv.MapKeys() { + key := k.String() + if prefix != "" { + key = strings.Join([]string{prefix, k.String()}, ".") + } + addRow(w, key, rv.MapIndex(k).Interface()) + } + case reflect.Slice, reflect.Array: + for i := 0; i < rv.Len(); i++ { + addRow(w, fmt.Sprintf("%s[%d]", prefix, i), rv.Index(i).Interface()) + } + default: + w.AppendRow(table.Row{prefix, m}) + } +} diff --git a/pkg/tool/tools.go b/pkg/tool/tools.go new file mode 100644 index 0000000..861d7cf --- /dev/null +++ b/pkg/tool/tools.go @@ -0,0 +1,19 @@ +package tool + +import "cmp" + +func Min[T cmp.Ordered](a, b T) T { + if a <= b { + return a + } + + return b +} + +func Max[T cmp.Ordered](a, b T) T { + if a >= b { + return a + } + + return b +} diff --git a/readme.md b/readme.md new file mode 100644 index 0000000..139c81b --- /dev/null +++ b/readme.md @@ -0,0 +1,11 @@ +# UPP - your app + +### Usage + +```go +app := upp.New() + +app.With(db, es, api) + +app.Run(ctx) +``` \ No newline at end of file diff --git a/upp/api.go b/upp/api.go new file mode 100644 index 0000000..3399788 --- /dev/null +++ b/upp/api.go @@ -0,0 +1,41 @@ +package upp + +import ( + "net/http" + + "github.com/loveuer/upp/pkg/api" +) + +func (u *upp) API() *api.App { return u.api.engine } + +func (u *upp) GET(path string, handlers ...api.HandlerFunc) { + u.HandleAPI(http.MethodGet, path, handlers...) +} + +func (u *upp) POST(path string, handlers ...api.HandlerFunc) { + u.HandleAPI(http.MethodPost, path, handlers...) +} + +func (u *upp) PUT(path string, handlers ...api.HandlerFunc) { + u.HandleAPI(http.MethodPut, path, handlers...) +} + +func (u *upp) DELETE(path string, handlers ...api.HandlerFunc) { + u.HandleAPI(http.MethodDelete, path, handlers...) +} + +func (u *upp) PATCH(path string, handlers ...api.HandlerFunc) { + u.HandleAPI(http.MethodPatch, path, handlers...) +} + +func (u *upp) HEAD(path string, handlers ...api.HandlerFunc) { + u.HandleAPI(http.MethodHead, path, handlers...) +} + +func (u *upp) OPTIONS(path string, handlers ...api.HandlerFunc) { + u.HandleAPI(http.MethodOptions, path, handlers...) +} + +func (u *upp) HandleAPI(method, path string, handlers ...api.HandlerFunc) { + u.api.engine.Handle(method, path, handlers...) +} diff --git a/upp/log.go b/upp/log.go new file mode 100644 index 0000000..e4e4f07 --- /dev/null +++ b/upp/log.go @@ -0,0 +1,25 @@ +package upp + +func (u *upp) Debug(msg string, data ...any) { + u.logger.Debug(msg, data...) +} + +func (u *upp) Info(msg string, data ...any) { + u.logger.Info(msg, data...) +} + +func (u *upp) Warn(msg string, data ...any) { + u.logger.Warn(msg, data...) +} + +func (u *upp) Error(msg string, data ...any) { + u.logger.Error(msg, data...) +} + +func (u *upp) Panic(msg string, data ...any) { + u.logger.Panic(msg, data...) +} + +func (u *upp) Fatal(msg string, data ...any) { + u.logger.Fatal(msg, data...) +} diff --git a/upp/module.go b/upp/module.go new file mode 100644 index 0000000..76a21a0 --- /dev/null +++ b/upp/module.go @@ -0,0 +1,80 @@ +package upp + +import ( + "crypto/tls" + "log" + + "github.com/elastic/go-elasticsearch/v7" + "github.com/loveuer/upp/pkg/api" + "github.com/loveuer/upp/pkg/cache" + "github.com/loveuer/upp/pkg/db" + "gorm.io/gorm" +) + +type module func(u *upp) + +func InitDB(uri string, models ...any) module { + db, err := db.New(uri) + if err != nil { + log.Panic(err.Error()) + } + + if err = db.AutoMigrate(models...); err != nil { + log.Panic(err.Error()) + } + + return func(u *upp) { + u.db = db + } +} + +func (u *upp) UseDB() *gorm.DB { + tx := u.db.Session(&gorm.Session{}) + if u.debug { + tx = tx.Debug() + } + return tx +} + +func InitCache(uri string) module { + cache, err := cache.New(uri) + if err != nil { + log.Panic(err.Error()) + } + + return func(u *upp) { + u.cache = cache + } +} + +func (u *upp) UseCache() cache.Cache { + return u.cache +} + +func (u *upp) UseES() *elasticsearch.Client { + return nil +} + +type ApiConfig struct { + Address string + TLSConfig *tls.Config +} + +func InitApi(api *api.App, cfgs ...ApiConfig) module { + cfg := ApiConfig{} + if len(cfgs) > 0 { + cfg = cfgs[0] + } + + if cfg.Address == "" { + cfg.Address = "localhost:8080" + } + + return func(u *upp) { + api.Upp = u + u.api = &uppApi{ + engine: api, + config: cfg, + } + } +} diff --git a/upp/upp.go b/upp/upp.go new file mode 100644 index 0000000..e9be995 --- /dev/null +++ b/upp/upp.go @@ -0,0 +1,71 @@ +package upp + +import ( + "context" + "os/signal" + "syscall" + + "github.com/loveuer/upp/pkg/api" + "github.com/loveuer/upp/pkg/cache" + "github.com/loveuer/upp/pkg/interfaces" + "github.com/loveuer/upp/pkg/log" + "github.com/loveuer/upp/pkg/tool" + "gorm.io/gorm" +) + +type uppApi struct { + engine *api.App + config ApiConfig +} + +type upp struct { + debug bool + logger interfaces.Logger + db *gorm.DB + cache cache.Cache + api *uppApi +} + +func (u *upp) With(modules ...module) { + for _, m := range modules { + m(u) + } +} + +func (u *upp) Run(ctx context.Context) { + u.StartAPI(ctx) + + <-ctx.Done() +} + +func (u *upp) RunSignal() { + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) + defer cancel() + + if u.api != nil { + u.StartAPI(ctx) + } + + <-ctx.Done() + + u.Warn(" UPP | quit by signal...") + + <-tool.Timeout(2).Done() +} + +func (u *upp) StartAPI(ctx context.Context) { + u.Info("UPP | run api at %s", u.api.config.Address) + go u.api.engine.Run(u.api.config.Address) + go func() { + <-ctx.Done() + u.api.engine.Shutdown(tool.Timeout(2)) + }() +} + +func New() *upp { + app := &upp{ + logger: log.New(), + } + + return app +}