From f3861184b6e7da8d864c6f05cb8cc074b5f8b991 Mon Sep 17 00:00:00 2001 From: loveuer Date: Wed, 23 Oct 2024 22:42:13 +0800 Subject: [PATCH] =?UTF-8?q?wip:=20=E7=BB=A7=E7=BB=AD...?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- go.mod | 8 +- go.sum | 27 ++++- internal/cmd/serve.go | 22 +++- internal/interfaces/database.go | 22 ++++ internal/interfaces/enum.go | 11 ++ internal/interfaces/logger.go | 7 ++ internal/serve/serve.go | 144 +++++++++++++++++++++++++++ internal/store/cache/cache_lru.go | 117 ++++++++++++++++++++++ internal/store/cache/cache_memory.go | 82 +++++++++++++++ internal/store/cache/cache_redis.go | 71 +++++++++++++ internal/store/cache/client.go | 38 +++++++ internal/store/cache/error.go | 7 ++ internal/store/cache/init.go | 69 +++++++++++++ internal/store/cache/scan.go | 20 ++++ internal/tool/ctx.go | 38 +++++++ internal/tool/file.go | 30 ++++++ internal/tool/human.go | 24 +++++ internal/tool/must.go | 11 ++ internal/tool/password.go | 84 ++++++++++++++++ internal/tool/password_test.go | 11 ++ internal/tool/random.go | 54 ++++++++++ internal/tool/slice.go | 5 + internal/tool/slice_test.go | 1 + internal/tool/table.go | 124 +++++++++++++++++++++++ internal/tool/time.go | 13 +++ main.go | 99 ++---------------- 26 files changed, 1043 insertions(+), 96 deletions(-) create mode 100644 internal/interfaces/database.go create mode 100644 internal/interfaces/enum.go create mode 100644 internal/interfaces/logger.go create mode 100644 internal/serve/serve.go create mode 100644 internal/store/cache/cache_lru.go create mode 100644 internal/store/cache/cache_memory.go create mode 100644 internal/store/cache/cache_redis.go create mode 100644 internal/store/cache/client.go create mode 100644 internal/store/cache/error.go create mode 100644 internal/store/cache/init.go create mode 100644 internal/store/cache/scan.go create mode 100644 internal/tool/ctx.go create mode 100644 internal/tool/file.go create mode 100644 internal/tool/human.go create mode 100644 internal/tool/must.go create mode 100644 internal/tool/password.go create mode 100644 internal/tool/password_test.go create mode 100644 internal/tool/random.go create mode 100644 internal/tool/slice.go create mode 100644 internal/tool/slice_test.go create mode 100644 internal/tool/table.go create mode 100644 internal/tool/time.go diff --git a/go.mod b/go.mod index d4965db..4cffa36 100644 --- a/go.mod +++ b/go.mod @@ -3,10 +3,13 @@ module uauth go 1.20 require ( + gitea.com/taozitaozi/gredis v0.0.0-20240131032054-b02845ce1e9d github.com/google/uuid v1.6.0 - github.com/gorilla/mux v1.8.1 + github.com/hashicorp/golang-lru/v2 v2.0.7 + github.com/jedib0t/go-pretty/v6 v6.6.1 github.com/loveuer/nf v0.2.11 github.com/spf13/cobra v1.8.1 + golang.org/x/crypto v0.23.0 ) require ( @@ -14,6 +17,9 @@ require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-runewidth v0.0.15 // indirect + github.com/rivo/uniseg v0.2.0 // indirect + github.com/sirupsen/logrus v1.9.2 // indirect github.com/spf13/pflag v1.0.5 // indirect golang.org/x/sys v0.20.0 // indirect ) diff --git a/go.sum b/go.sum index e289871..b6c4813 100644 --- a/go.sum +++ b/go.sum @@ -1,12 +1,19 @@ +gitea.com/taozitaozi/gredis v0.0.0-20240131032054-b02845ce1e9d h1:TpEOdRGqwzxx+DaN18nFE+g4EQYjneZOO1jcHtSon/g= +gitea.com/taozitaozi/gredis v0.0.0-20240131032054-b02845ce1e9d/go.mod h1:QtcL846XUtSnhmW6TZAujUQ9V5jalY7frxzZOs00kFI= github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fatih/color v1.17.0 h1:GlRw1BRJxkpqUCBKzKOw098ed57fEsKeNjpTe3cSjK4= github.com/fatih/color v1.17.0/go.mod h1:YZ7TlrGPkiz6ku9fK3TLD/pl3CpsiFyu8N92HLgmosI= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= -github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/jedib0t/go-pretty/v6 v6.6.1 h1:iJ65Xjb680rHcikRj6DSIbzCex2huitmc7bDtxYVWyc= +github.com/jedib0t/go-pretty/v6 v6.6.1/go.mod h1:zbn98qrYlh95FIhwwsbIip0LYpwSG8SUOScs+v9/t0E= github.com/loveuer/nf v0.2.11 h1:W775exDO8eNAHT45WDhXekMYCuWahOW9t1aVmGh3u1o= github.com/loveuer/nf v0.2.11/go.mod h1:M6reF17/kJBis30H4DxR5hrtgo/oJL4AV4cBe4HzJLw= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= @@ -14,14 +21,30 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= +github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= +github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/sirupsen/logrus v1.9.2 h1:oxx1eChJGI6Uks2ZC4W1zpLlVgqB8ner4EuQwV4Ik1Y= +github.com/sirupsen/logrus v1.9.2/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= +golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/cmd/serve.go b/internal/cmd/serve.go index f15a012..38ce607 100644 --- a/internal/cmd/serve.go +++ b/internal/cmd/serve.go @@ -1,11 +1,25 @@ package cmd -import "github.com/spf13/cobra" +import ( + "github.com/spf13/cobra" + "uauth/internal/serve" +) func initServe() *cobra.Command { - serve := &cobra.Command{ - Use: "serve", + var ( + address string + prefix string + ) + + svc := &cobra.Command{ + Use: "svc", + RunE: func(cmd *cobra.Command, args []string) error { + return serve.Run(cmd.Context(), prefix, address) + }, } - return serve + svc.Flags().StringVar(&address, "address", "localhost:8080", "listen address") + svc.Flags().StringVar(&prefix, "prefix", "/api/oauth/v2", "api prefix") + + return svc } diff --git a/internal/interfaces/database.go b/internal/interfaces/database.go new file mode 100644 index 0000000..af95046 --- /dev/null +++ b/internal/interfaces/database.go @@ -0,0 +1,22 @@ +package interfaces + +import ( + "context" + "time" +) + +type Cacher interface { + Get(ctx context.Context, key string) ([]byte, error) + GetScan(ctx context.Context, key string) Scanner + GetEx(ctx context.Context, key string, duration time.Duration) ([]byte, error) + GetExScan(ctx context.Context, key string, duration time.Duration) Scanner + // Set value 会被序列化, 优先使用 MarshalBinary 方法, 没有则执行 json.Marshal + Set(ctx context.Context, key string, value any) error + // SetEx value 会被序列化, 优先使用 MarshalBinary 方法, 没有则执行 json.Marshal + SetEx(ctx context.Context, key string, value any, duration time.Duration) error + Del(ctx context.Context, keys ...string) error +} + +type Scanner interface { + Scan(model any) error +} diff --git a/internal/interfaces/enum.go b/internal/interfaces/enum.go new file mode 100644 index 0000000..fd43fa6 --- /dev/null +++ b/internal/interfaces/enum.go @@ -0,0 +1,11 @@ +package interfaces + +type Enum interface { + Value() int64 + Code() string + Label() string + + MarshalJSON() ([]byte, error) + + All() []Enum +} diff --git a/internal/interfaces/logger.go b/internal/interfaces/logger.go new file mode 100644 index 0000000..8e75d58 --- /dev/null +++ b/internal/interfaces/logger.go @@ -0,0 +1,7 @@ +package interfaces + +type OpLogger interface { + Enum + Render(content map[string]any) (string, error) + Template() string +} diff --git a/internal/serve/serve.go b/internal/serve/serve.go new file mode 100644 index 0000000..687f3cf --- /dev/null +++ b/internal/serve/serve.go @@ -0,0 +1,144 @@ +package serve + +import ( + "context" + "fmt" + "github.com/google/uuid" + "github.com/loveuer/nf" + "github.com/loveuer/nf/nft/log" + "net/http" + "uauth/internal/tool" +) + +func authenticateUser(username, password string) (bool, error) { + // 这里你应该实现真实的用户认证逻辑 + // 为了简化,我们这里直接硬编码一个用户名和密码 + if username == "user" && password == "pass" { + return true, nil + } + + return false, fmt.Errorf("invalid username or password") +} + +// 处理登录请求 +func handleLogin(c *nf.Ctx) error { + username := c.FormValue("username") + password := c.FormValue("password") + + // 认证用户 + ok, err := authenticateUser(username, password) + if err != nil || !ok { + return c.Status(http.StatusUnauthorized).SendString("Unauthorized") + } + + // 用户认证成功,重定向到授权页面 + http.Redirect(c.Writer, c.Request, "/authorize?client_id=12345&response_type=code&redirect_uri=http%3A%2F%2Flocalhost%3A8080%2Fcallback&scope=read%20write", http.StatusFound) + + return nil +} + +// 处理授权请求 +func handleAuthorize(c *nf.Ctx) error { + // 解析查询参数 + clientID := c.Query("client_id") + responseType := c.Query("response_type") + redirectURI := c.Query("redirect_uri") + scope := c.Query("scope") + + // 检查客户端 ID 和其他参数 + // 在实际应用中,你需要检查这些参数是否合法 + if clientID != "12345" || responseType != "code" || redirectURI != "http://localhost:8080/callback" { + return c.Status(http.StatusBadRequest).SendString("Invalid request") + } + + // 显示授权页面给用户 + _, err := c.Write([]byte(` + + Authorization + +

Do you want to authorize this application?

+
+ + + + +
+ + + `)) + + return err +} + +// 处理用户的授权批准 +func handleApprove(c *nf.Ctx) error { + // 获取表单数据 + clientID := c.FormValue("client_id") + redirectURI := c.FormValue("redirect_uri") + scope := c.FormValue("scope") + + // 生成授权码 + authorizationCode := uuid.New().String()[:8] + + log.Info("[D] client_id = %s, scope = %s, auth_code = %s", clientID, scope, authorizationCode) + + // 重定向到回调 URL 并附带授权码 + http.Redirect(c.Writer, c.Request, redirectURI+"?code="+authorizationCode, http.StatusFound) + return nil +} + +// 令牌请求的处理 +func handleToken(c *nf.Ctx) error { + // 获取请求参数 + grantType := c.FormValue("grant_type") + code := c.FormValue("code") + redirectURI := c.FormValue("redirect_uri") + + // 简单验证 + if grantType != "authorization_code" { + return c.Status(http.StatusBadRequest).SendString("Unsupported grant type") + } + + mu.Lock() + defer mu.Unlock() + + // 验证授权码是否有效 + accessToken, ok := authCodes[code] + if !ok { + return c.Status(http.StatusBadRequest).SendString("Invalid authorization code") + } + + // 生成访问令牌 + token := generateAccessToken() + + // 返回访问令牌 + return c.JSON(map[string]string{ + "access_token": token, + "token_type": "bearer", + "expires_in": "3600", // 访问令牌有效期(秒) + }) + + // 清除已使用的授权码 + delete(authCodes, code) +} + +func Run(ctx context.Context, prefix string, address string) error { + + app := nf.New() + + api := app.Group(prefix) + // 设置路由 + api.Get("/login", handleLogin) + api.Get("/authorize", handleAuthorize) + api.Post("/approve", handleApprove) + api.Post("/token", handleToken) + + // 启动 HTTP 服务器 + log.Info("Starting server on: %s", address) + go func() { + <-ctx.Done() + _ = app.Shutdown(tool.Timeout(2)) + }() + + return app.Run(address) +} diff --git a/internal/store/cache/cache_lru.go b/internal/store/cache/cache_lru.go new file mode 100644 index 0000000..99665bf --- /dev/null +++ b/internal/store/cache/cache_lru.go @@ -0,0 +1,117 @@ +package cache + +import ( + "context" + "github.com/hashicorp/golang-lru/v2/expirable" + _ "github.com/hashicorp/golang-lru/v2/expirable" + "time" + "uauth/internal/interfaces" +) + +var _ interfaces.Cacher = (*_lru)(nil) + +type _lru struct { + client *expirable.LRU[string, *_lru_value] +} + +type _lru_value struct { + duration time.Duration + last time.Time + bs []byte +} + +func (l *_lru) Get(ctx context.Context, key string) ([]byte, error) { + v, ok := l.client.Get(key) + if !ok { + return nil, ErrorKeyNotFound + } + + if v.duration == 0 { + return v.bs, nil + } + + if time.Now().Sub(v.last) > v.duration { + l.client.Remove(key) + return nil, ErrorKeyNotFound + } + + return v.bs, nil +} + +func (l *_lru) GetScan(ctx context.Context, key string) interfaces.Scanner { + return newScanner(l.Get(ctx, key)) +} + +func (l *_lru) GetEx(ctx context.Context, key string, duration time.Duration) ([]byte, error) { + v, ok := l.client.Get(key) + if !ok { + return nil, ErrorKeyNotFound + } + + if v.duration == 0 { + return v.bs, nil + } + + now := time.Now() + + if now.Sub(v.last) > v.duration { + l.client.Remove(key) + return nil, ErrorKeyNotFound + } + + l.client.Add(key, &_lru_value{ + duration: duration, + last: now, + bs: v.bs, + }) + + return v.bs, nil +} + +func (l *_lru) GetExScan(ctx context.Context, key string, duration time.Duration) interfaces.Scanner { + return newScanner(l.GetEx(ctx, key, duration)) +} + +func (l *_lru) Set(ctx context.Context, key string, value any) error { + bs, err := handleValue(value) + if err != nil { + return err + } + + l.client.Add(key, &_lru_value{ + duration: 0, + last: time.Now(), + bs: bs, + }) + + return nil +} + +func (l *_lru) SetEx(ctx context.Context, key string, value any, duration time.Duration) error { + bs, err := handleValue(value) + if err != nil { + return err + } + + l.client.Add(key, &_lru_value{ + duration: duration, + last: time.Now(), + bs: bs, + }) + + return nil +} + +func (l *_lru) Del(ctx context.Context, keys ...string) error { + for _, key := range keys { + l.client.Remove(key) + } + + return nil +} + +func newLRUCache() (interfaces.Cacher, error) { + client := expirable.NewLRU[string, *_lru_value](1024*1024, nil, 0) + + return &_lru{client: client}, nil +} diff --git a/internal/store/cache/cache_memory.go b/internal/store/cache/cache_memory.go new file mode 100644 index 0000000..fb9e508 --- /dev/null +++ b/internal/store/cache/cache_memory.go @@ -0,0 +1,82 @@ +package cache + +import ( + "context" + "errors" + "fmt" + "time" + "uauth/internal/interfaces" + + "gitea.com/taozitaozi/gredis" +) + +var _ interfaces.Cacher = (*_mem)(nil) + +type _mem struct { + client *gredis.Gredis +} + +func (m *_mem) GetScan(ctx context.Context, key string) interfaces.Scanner { + return newScanner(m.Get(ctx, key)) +} + +func (m *_mem) GetExScan(ctx context.Context, key string, duration time.Duration) interfaces.Scanner { + return newScanner(m.GetEx(ctx, key, duration)) +} + +func (m *_mem) Get(ctx context.Context, key string) ([]byte, error) { + v, err := m.client.Get(key) + if err != nil { + if errors.Is(err, gredis.ErrKeyNotFound) { + return nil, ErrorKeyNotFound + } + + return nil, err + } + + bs, ok := v.([]byte) + if !ok { + return nil, fmt.Errorf("invalid value type=%T", v) + } + + return bs, nil +} + +func (m *_mem) GetEx(ctx context.Context, key string, duration time.Duration) ([]byte, error) { + v, err := m.client.GetEx(key, duration) + if err != nil { + if errors.Is(err, gredis.ErrKeyNotFound) { + return nil, ErrorKeyNotFound + } + + return nil, err + } + + bs, ok := v.([]byte) + if !ok { + return nil, fmt.Errorf("invalid value type=%T", v) + } + + return bs, nil +} + +func (m *_mem) Set(ctx context.Context, key string, value any) error { + bs, err := handleValue(value) + if err != nil { + return err + } + return m.client.Set(key, bs) +} + +func (m *_mem) SetEx(ctx context.Context, key string, value any, duration time.Duration) error { + bs, err := handleValue(value) + if err != nil { + return err + } + return m.client.SetEx(key, bs, duration) +} + +func (m *_mem) Del(ctx context.Context, keys ...string) error { + m.client.Delete(keys...) + return nil +} diff --git a/internal/store/cache/cache_redis.go b/internal/store/cache/cache_redis.go new file mode 100644 index 0000000..5cd2343 --- /dev/null +++ b/internal/store/cache/cache_redis.go @@ -0,0 +1,71 @@ +package cache + +import ( + "context" + "errors" + "time" + "uauth/internal/interfaces" +) + +type _redis struct { + client *redis.Client +} + +func (r *_redis) Get(ctx context.Context, key string) ([]byte, error) { + result, err := r.client.Get(ctx, key).Result() + if err != nil { + if errors.Is(err, redis.Nil) { + return nil, ErrorKeyNotFound + } + + return nil, err + } + + return []byte(result), nil +} + +func (r *_redis) GetScan(ctx context.Context, key string) interfaces.Scanner { + return newScanner(r.Get(ctx, key)) +} + +func (r *_redis) GetEx(ctx context.Context, key string, duration time.Duration) ([]byte, error) { + result, err := r.client.GetEx(ctx, key, duration).Result() + if err != nil { + if errors.Is(err, redis.Nil) { + return nil, ErrorKeyNotFound + } + + return nil, err + } + + return []byte(result), nil +} + +func (r *_redis) GetExScan(ctx context.Context, key string, duration time.Duration) interfaces.Scanner { + return newScanner(r.GetEx(ctx, key, duration)) +} + +func (r *_redis) Set(ctx context.Context, key string, value any) error { + bs, err := handleValue(value) + if err != nil { + return err + } + + _, err = r.client.Set(ctx, key, bs, redis.KeepTTL).Result() + return err +} + +func (r *_redis) SetEx(ctx context.Context, key string, value any, duration time.Duration) error { + bs, err := handleValue(value) + if err != nil { + return err + } + + _, err = r.client.SetEX(ctx, key, bs, duration).Result() + + return err +} + +func (r *_redis) Del(ctx context.Context, keys ...string) error { + return r.client.Del(ctx, keys...).Err() +} diff --git a/internal/store/cache/client.go b/internal/store/cache/client.go new file mode 100644 index 0000000..2060d13 --- /dev/null +++ b/internal/store/cache/client.go @@ -0,0 +1,38 @@ +package cache + +import ( + "encoding/json" + "uauth/internal/interfaces" +) + +var ( + Client interfaces.Cacher +) + +type encoded_value interface { + MarshalBinary() ([]byte, error) +} + +type decoded_value interface { + UnmarshalBinary(bs []byte) error +} + +func handleValue(value any) ([]byte, error) { + var ( + bs []byte + err error + ) + + switch value.(type) { + case []byte: + return value.([]byte), nil + } + + if imp, ok := value.(encoded_value); ok { + bs, err = imp.MarshalBinary() + } else { + bs, err = json.Marshal(value) + } + + return bs, err +} diff --git a/internal/store/cache/error.go b/internal/store/cache/error.go new file mode 100644 index 0000000..f0798b5 --- /dev/null +++ b/internal/store/cache/error.go @@ -0,0 +1,7 @@ +package cache + +import "errors" + +var ( + ErrorKeyNotFound = errors.New("key not found") +) diff --git a/internal/store/cache/init.go b/internal/store/cache/init.go new file mode 100644 index 0000000..b5f9845 --- /dev/null +++ b/internal/store/cache/init.go @@ -0,0 +1,69 @@ +package cache + +import ( + "fmt" + "gitea.com/taozitaozi/gredis" + "net/url" + "strings" + "uauth/internal/opt" + "uauth/internal/tool" +) + +func Init() error { + + var ( + err error + ) + + strs := strings.Split(opt.Cfg.Cache.Uri, "::") + + switch strs[0] { + case "memory": + gc := gredis.NewGredis(1024 * 1024) + Client = &_mem{client: gc} + case "lru": + if Client, err = newLRUCache(); err != nil { + return err + } + case "redis": + var ( + ins *url.URL + err error + ) + + if len(strs) != 2 { + return fmt.Errorf("cache.Init: invalid cache uri: %s", opt.Cfg.Cache.Uri) + } + + uri := strs[1] + + if !strings.Contains(uri, "://") { + uri = fmt.Sprintf("redis://%s", uri) + } + + if ins, err = url.Parse(uri); err != nil { + return fmt.Errorf("cache.Init: url parse cache uri: %s, err: %s", opt.Cfg.Cache.Uri, err.Error()) + } + + addr := ins.Host + username := ins.User.Username() + password, _ := ins.User.Password() + + var rc *redis.Client + rc = redis.NewClient(&redis.Options{ + Addr: addr, + Username: username, + Password: password, + }) + + if err = rc.Ping(tool.Timeout(5)).Err(); err != nil { + return fmt.Errorf("cache.Init: redis ping err: %s", err.Error()) + } + + Client = &_redis{client: rc} + default: + return fmt.Errorf("cache type %s not support", strs[0]) + } + + return nil +} diff --git a/internal/store/cache/scan.go b/internal/store/cache/scan.go new file mode 100644 index 0000000..c65d267 --- /dev/null +++ b/internal/store/cache/scan.go @@ -0,0 +1,20 @@ +package cache + +import "encoding/json" + +type scanner struct { + err error + bs []byte +} + +func (s *scanner) Scan(model any) error { + if s.err != nil { + return s.err + } + + return json.Unmarshal(s.bs, model) +} + +func newScanner(bs []byte, err error) *scanner { + return &scanner{bs: bs, err: err} +} diff --git a/internal/tool/ctx.go b/internal/tool/ctx.go new file mode 100644 index 0000000..82242a3 --- /dev/null +++ b/internal/tool/ctx.go @@ -0,0 +1,38 @@ +package tool + +import ( + "context" + "time" +) + +func Timeout(seconds ...int) (ctx context.Context) { + var ( + duration time.Duration + ) + + if len(seconds) > 0 && seconds[0] > 0 { + duration = time.Duration(seconds[0]) * time.Second + } else { + duration = time.Duration(30) * time.Second + } + + ctx, _ = context.WithTimeout(context.Background(), duration) + + return +} + +func TimeoutCtx(ctx context.Context, seconds ...int) context.Context { + var ( + duration time.Duration + ) + + if len(seconds) > 0 && seconds[0] > 0 { + duration = time.Duration(seconds[0]) * time.Second + } else { + duration = time.Duration(30) * time.Second + } + + nctx, _ := context.WithTimeout(ctx, duration) + + return nctx +} diff --git a/internal/tool/file.go b/internal/tool/file.go new file mode 100644 index 0000000..cefa36d --- /dev/null +++ b/internal/tool/file.go @@ -0,0 +1,30 @@ +package tool + +import ( + "io" + "os" +) + +func CopyFile(src string, dst string) (err error) { + // Open the source file + sourceFile, err := os.Open(src) + if err != nil { + return err + } + defer sourceFile.Close() + + // Create the destination file + destinationFile, err := os.Create(dst) + if err != nil { + return err + } + defer destinationFile.Close() + + // Copy the contents from source to destination + _, err = io.Copy(destinationFile, sourceFile) + if err != nil { + return err + } + + return nil +} diff --git a/internal/tool/human.go b/internal/tool/human.go new file mode 100644 index 0000000..2c7ce71 --- /dev/null +++ b/internal/tool/human.go @@ -0,0 +1,24 @@ +package tool + +import "fmt" + +func HumanDuration(nano int64) string { + duration := float64(nano) + unit := "ns" + if duration >= 1000 { + duration /= 1000 + unit = "us" + } + + if duration >= 1000 { + duration /= 1000 + unit = "ms" + } + + if duration >= 1000 { + duration /= 1000 + unit = " s" + } + + return fmt.Sprintf("%6.2f%s", duration, unit) +} diff --git a/internal/tool/must.go b/internal/tool/must.go new file mode 100644 index 0000000..0615f8d --- /dev/null +++ b/internal/tool/must.go @@ -0,0 +1,11 @@ +package tool + +import "github.com/loveuer/nf/nft/log" + +func Must(errs ...error) { + for _, err := range errs { + if err != nil { + log.Panic(err.Error()) + } + } +} diff --git a/internal/tool/password.go b/internal/tool/password.go new file mode 100644 index 0000000..c2d1a17 --- /dev/null +++ b/internal/tool/password.go @@ -0,0 +1,84 @@ +package tool + +import ( + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "github.com/loveuer/nf/nft/log" + "golang.org/x/crypto/pbkdf2" + "regexp" + "strconv" + "strings" +) + +const ( + EncryptHeader string = "pbkdf2:sha256" // 用户密码加密 +) + +func NewPassword(password string) string { + return EncryptPassword(password, RandomString(8), int(RandomInt(50000)+100000)) +} + +func ComparePassword(in, db string) bool { + strs := strings.Split(db, "$") + if len(strs) != 3 { + log.Error("password in db invalid: %s", db) + return false + } + + encs := strings.Split(strs[0], ":") + if len(encs) != 3 { + log.Error("password in db invalid: %s", db) + return false + } + + encIteration, err := strconv.Atoi(encs[2]) + if err != nil { + log.Error("password in db invalid: %s, convert iter err: %s", db, err) + return false + } + + return EncryptPassword(in, strs[1], encIteration) == db +} + +func EncryptPassword(password, salt string, iter int) string { + hash := pbkdf2.Key([]byte(password), []byte(salt), iter, 32, sha256.New) + encrypted := hex.EncodeToString(hash) + return fmt.Sprintf("%s:%d$%s$%s", EncryptHeader, iter, salt, encrypted) +} + +func CheckPassword(password string) error { + if len(password) < 8 || len(password) > 32 { + return errors.New("密码长度不符合") + } + + var ( + err error + match bool + patternList = []string{`[0-9]+`, `[a-z]+`, `[A-Z]+`, `[!@#%]+`} //, `[~!@#$%^&*?_-]+`} + matchAccount = 0 + tips = []string{"缺少数字", "缺少小写字母", "缺少大写字母", "缺少'!@#%'"} + locktips = make([]string, 0) + ) + + for idx, pattern := range patternList { + match, err = regexp.MatchString(pattern, password) + if err != nil { + log.Warn("regex match string err, reg_str: %s, err: %v", pattern, err) + return errors.New("密码强度不够") + } + + if match { + matchAccount++ + } else { + locktips = append(locktips, tips[idx]) + } + } + + if matchAccount < 3 { + return fmt.Errorf("密码强度不够, 可能 %s", strings.Join(locktips, ", ")) + } + + return nil +} diff --git a/internal/tool/password_test.go b/internal/tool/password_test.go new file mode 100644 index 0000000..aabd667 --- /dev/null +++ b/internal/tool/password_test.go @@ -0,0 +1,11 @@ +package tool + +import "testing" + +func TestEncPassword(t *testing.T) { + password := "123456" + + result := EncryptPassword(password, RandomString(8), 50000) + + t.Logf("sum => %s", result) +} diff --git a/internal/tool/random.go b/internal/tool/random.go new file mode 100644 index 0000000..266cb4c --- /dev/null +++ b/internal/tool/random.go @@ -0,0 +1,54 @@ +package tool + +import ( + "crypto/rand" + "math/big" +) + +var ( + letters = []byte("0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") + letterNum = []byte("0123456789") + letterLow = []byte("abcdefghijklmnopqrstuvwxyz") + letterCap = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ") + letterSyb = []byte("!@#$%^&*()_+-=") +) + +func RandomInt(max int64) int64 { + num, _ := rand.Int(rand.Reader, big.NewInt(max)) + return num.Int64() +} + +func RandomString(length int) string { + result := make([]byte, length) + for i := 0; i < length; i++ { + num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letters)))) + result[i] = letters[num.Int64()] + } + return string(result) +} + +func RandomPassword(length int, withSymbol bool) string { + result := make([]byte, length) + kind := 3 + if withSymbol { + kind++ + } + + for i := 0; i < length; i++ { + switch i % kind { + case 0: + num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letterNum)))) + result[i] = letterNum[num.Int64()] + case 1: + num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letterLow)))) + result[i] = letterLow[num.Int64()] + case 2: + num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letterCap)))) + result[i] = letterCap[num.Int64()] + case 3: + num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letterSyb)))) + result[i] = letterSyb[num.Int64()] + } + } + return string(result) +} diff --git a/internal/tool/slice.go b/internal/tool/slice.go new file mode 100644 index 0000000..05a7dd5 --- /dev/null +++ b/internal/tool/slice.go @@ -0,0 +1,5 @@ +package tool + +func Bulk[T any](slice []T, size int) { + // todo +} diff --git a/internal/tool/slice_test.go b/internal/tool/slice_test.go new file mode 100644 index 0000000..05b1676 --- /dev/null +++ b/internal/tool/slice_test.go @@ -0,0 +1 @@ +package tool diff --git a/internal/tool/table.go b/internal/tool/table.go new file mode 100644 index 0000000..ffaaf31 --- /dev/null +++ b/internal/tool/table.go @@ -0,0 +1,124 @@ +package tool + +import ( + "encoding/json" + "fmt" + "github.com/jedib0t/go-pretty/v6/table" + "github.com/loveuer/nf/nft/log" + "io" + "os" + "reflect" + "strings" +) + +func TablePrinter(data any, writers ...io.Writer) { + var w io.Writer = os.Stdout + if len(writers) > 0 && writers[0] != nil { + w = writers[0] + } + + t := table.NewWriter() + structPrinter(t, "", data) + _, _ = fmt.Fprintln(w, t.Render()) +} + +func structPrinter(w table.Writer, prefix string, item any) { +Start: + rv := reflect.ValueOf(item) + if rv.IsZero() { + return + } + + for rv.Type().Kind() == reflect.Pointer { + rv = rv.Elem() + } + + switch rv.Type().Kind() { + case reflect.Invalid, + reflect.Uintptr, + reflect.Chan, + reflect.Func, + reflect.UnsafePointer: + case reflect.Bool, + reflect.Int, + reflect.Int8, + reflect.Int16, + reflect.Int32, + reflect.Int64, + reflect.Uint, + reflect.Uint8, + reflect.Uint16, + reflect.Uint32, + reflect.Uint64, + reflect.Float32, + reflect.Float64, + reflect.Complex64, + reflect.Complex128, + reflect.Interface: + w.AppendRow(table.Row{strings.TrimPrefix(prefix, "."), rv.Interface()}) + case reflect.String: + val := rv.String() + if len(val) <= 160 { + w.AppendRow(table.Row{strings.TrimPrefix(prefix, "."), val}) + return + } + + w.AppendRow(table.Row{strings.TrimPrefix(prefix, "."), val[0:64] + "..." + val[len(val)-64:]}) + case reflect.Array, reflect.Slice: + for i := 0; i < rv.Len(); i++ { + p := strings.Join([]string{prefix, fmt.Sprintf("[%d]", i)}, ".") + structPrinter(w, p, rv.Index(i).Interface()) + } + case reflect.Map: + for _, k := range rv.MapKeys() { + structPrinter(w, fmt.Sprintf("%s.{%v}", prefix, k), rv.MapIndex(k).Interface()) + } + case reflect.Pointer: + goto Start + case reflect.Struct: + for i := 0; i < rv.NumField(); i++ { + p := fmt.Sprintf("%s.%s", prefix, rv.Type().Field(i).Name) + field := rv.Field(i) + + //log.Debug("TablePrinter: prefix: %s, field: %v", p, rv.Field(i)) + + if !field.CanInterface() { + return + } + + structPrinter(w, p, field.Interface()) + } + } +} + +func TableMapPrinter(data []byte) { + m := make(map[string]any) + if err := json.Unmarshal(data, &m); err != nil { + log.Warn(err.Error()) + return + } + + t := table.NewWriter() + addRow(t, "", m) + fmt.Println(t.Render()) +} + +func addRow(w table.Writer, prefix string, m any) { + rv := reflect.ValueOf(m) + switch rv.Type().Kind() { + case reflect.Map: + for _, k := range rv.MapKeys() { + key := k.String() + if prefix != "" { + key = strings.Join([]string{prefix, k.String()}, ".") + } + addRow(w, key, rv.MapIndex(k).Interface()) + } + case reflect.Slice, reflect.Array: + for i := 0; i < rv.Len(); i++ { + addRow(w, fmt.Sprintf("%s[%d]", prefix, i), rv.Index(i).Interface()) + } + default: + w.AppendRow(table.Row{prefix, m}) + } +} diff --git a/internal/tool/time.go b/internal/tool/time.go new file mode 100644 index 0000000..a193f20 --- /dev/null +++ b/internal/tool/time.go @@ -0,0 +1,13 @@ +package tool + +import "time" + +// TodayMidnight 返回今日凌晨 +func TodayMidnight() (midnight time.Time) { + now := time.Now() + + year, month, day := now.Date() + midnight = time.Date(year, month, day, 0, 0, 0, 0, time.Local) + + return +} diff --git a/main.go b/main.go index 81bf409..78d5993 100644 --- a/main.go +++ b/main.go @@ -1,100 +1,21 @@ package main import ( - "fmt" - "github.com/google/uuid" - "github.com/loveuer/nf" - "log" - "net/http" + "context" + "github.com/loveuer/nf/nft/log" + "os/signal" + "syscall" + "uauth/internal/cmd" ) // 假设这是你的用户认证函数 -func authenticateUser(username, password string) (bool, error) { - // 这里你应该实现真实的用户认证逻辑 - // 为了简化,我们这里直接硬编码一个用户名和密码 - if username == "user" && password == "pass" { - return true, nil - } - - return false, fmt.Errorf("invalid username or password") -} - -// 处理登录请求 -func handleLogin(c *nf.Ctx) error { - username := c.FormValue("username") - password := c.FormValue("password") - - // 认证用户 - ok, err := authenticateUser(username, password) - if err != nil || !ok { - return c.Status(http.StatusUnauthorized).SendString("Unauthorized") - } - - // 用户认证成功,重定向到授权页面 - http.Redirect(c.Writer, c.Request, "/authorize?client_id=12345&response_type=code&redirect_uri=http%3A%2F%2Flocalhost%3A8080%2Fcallback&scope=read%20write", http.StatusFound) - - return nil -} - -// 处理授权请求 -func handleAuthorize(c *nf.Ctx) error { - // 解析查询参数 - clientID := c.Query("client_id") - responseType := c.Query("response_type") - redirectURI := c.Query("redirect_uri") - scope := c.Query("scope") - - // 检查客户端 ID 和其他参数 - // 在实际应用中,你需要检查这些参数是否合法 - if clientID != "12345" || responseType != "code" || redirectURI != "http://localhost:8080/callback" { - return c.Status(http.StatusBadRequest).SendString("Invalid request") - } - - // 显示授权页面给用户 - _, err := c.Write([]byte(` - - Authorization - -

Do you want to authorize this application?

-
- - - - -
- - - `)) - - return err -} - -// 处理用户的授权批准 -func handleApprove(c *nf.Ctx) error { - // 获取表单数据 - clientID := c.FormValue("client_id") - redirectURI := c.FormValue("redirect_uri") - scope := c.FormValue("scope") - - // 生成授权码 - authorizationCode := uuid.New().String()[:8] - - log.Printf("[D] client_id = %s, scope = %s, auth_code = %s", clientID, scope, authorizationCode) - - // 重定向到回调 URL 并附带授权码 - http.Redirect(c.Writer, c.Request, redirectURI+"?code="+authorizationCode, http.StatusFound) - return nil -} func main() { - app := nf.New() + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) + defer cancel() - // 设置路由 - app.Get("/login", handleLogin) - app.Get("/authorize", handleAuthorize) - app.Post("/approve", handleApprove) + if err := cmd.Command.ExecuteContext(ctx); err != nil { + log.Error(err.Error()) + } - // 启动 HTTP 服务器 - log.Println("Starting server on :8080") - log.Fatal(app.Run(":8080")) }