diff --git a/go.mod b/go.mod
index d4965db..4cffa36 100644
--- a/go.mod
+++ b/go.mod
@@ -3,10 +3,13 @@ module uauth
go 1.20
require (
+ gitea.com/taozitaozi/gredis v0.0.0-20240131032054-b02845ce1e9d
github.com/google/uuid v1.6.0
- github.com/gorilla/mux v1.8.1
+ github.com/hashicorp/golang-lru/v2 v2.0.7
+ github.com/jedib0t/go-pretty/v6 v6.6.1
github.com/loveuer/nf v0.2.11
github.com/spf13/cobra v1.8.1
+ golang.org/x/crypto v0.23.0
)
require (
@@ -14,6 +17,9 @@ require (
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
+ github.com/mattn/go-runewidth v0.0.15 // indirect
+ github.com/rivo/uniseg v0.2.0 // indirect
+ github.com/sirupsen/logrus v1.9.2 // indirect
github.com/spf13/pflag v1.0.5 // indirect
golang.org/x/sys v0.20.0 // indirect
)
diff --git a/go.sum b/go.sum
index e289871..b6c4813 100644
--- a/go.sum
+++ b/go.sum
@@ -1,12 +1,19 @@
+gitea.com/taozitaozi/gredis v0.0.0-20240131032054-b02845ce1e9d h1:TpEOdRGqwzxx+DaN18nFE+g4EQYjneZOO1jcHtSon/g=
+gitea.com/taozitaozi/gredis v0.0.0-20240131032054-b02845ce1e9d/go.mod h1:QtcL846XUtSnhmW6TZAujUQ9V5jalY7frxzZOs00kFI=
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
+github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fatih/color v1.17.0 h1:GlRw1BRJxkpqUCBKzKOw098ed57fEsKeNjpTe3cSjK4=
github.com/fatih/color v1.17.0/go.mod h1:YZ7TlrGPkiz6ku9fK3TLD/pl3CpsiFyu8N92HLgmosI=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
-github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
-github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
+github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
+github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
+github.com/jedib0t/go-pretty/v6 v6.6.1 h1:iJ65Xjb680rHcikRj6DSIbzCex2huitmc7bDtxYVWyc=
+github.com/jedib0t/go-pretty/v6 v6.6.1/go.mod h1:zbn98qrYlh95FIhwwsbIip0LYpwSG8SUOScs+v9/t0E=
github.com/loveuer/nf v0.2.11 h1:W775exDO8eNAHT45WDhXekMYCuWahOW9t1aVmGh3u1o=
github.com/loveuer/nf v0.2.11/go.mod h1:M6reF17/kJBis30H4DxR5hrtgo/oJL4AV4cBe4HzJLw=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
@@ -14,14 +21,30 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
+github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
+github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
+github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
+github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
+github.com/sirupsen/logrus v1.9.2 h1:oxx1eChJGI6Uks2ZC4W1zpLlVgqB8ner4EuQwV4Ik1Y=
+github.com/sirupsen/logrus v1.9.2/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM=
github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
+github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
+golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
+golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
+golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
diff --git a/internal/cmd/serve.go b/internal/cmd/serve.go
index f15a012..38ce607 100644
--- a/internal/cmd/serve.go
+++ b/internal/cmd/serve.go
@@ -1,11 +1,25 @@
package cmd
-import "github.com/spf13/cobra"
+import (
+ "github.com/spf13/cobra"
+ "uauth/internal/serve"
+)
func initServe() *cobra.Command {
- serve := &cobra.Command{
- Use: "serve",
+ var (
+ address string
+ prefix string
+ )
+
+ svc := &cobra.Command{
+ Use: "svc",
+ RunE: func(cmd *cobra.Command, args []string) error {
+ return serve.Run(cmd.Context(), prefix, address)
+ },
}
- return serve
+ svc.Flags().StringVar(&address, "address", "localhost:8080", "listen address")
+ svc.Flags().StringVar(&prefix, "prefix", "/api/oauth/v2", "api prefix")
+
+ return svc
}
diff --git a/internal/interfaces/database.go b/internal/interfaces/database.go
new file mode 100644
index 0000000..af95046
--- /dev/null
+++ b/internal/interfaces/database.go
@@ -0,0 +1,22 @@
+package interfaces
+
+import (
+ "context"
+ "time"
+)
+
+type Cacher interface {
+ Get(ctx context.Context, key string) ([]byte, error)
+ GetScan(ctx context.Context, key string) Scanner
+ GetEx(ctx context.Context, key string, duration time.Duration) ([]byte, error)
+ GetExScan(ctx context.Context, key string, duration time.Duration) Scanner
+ // Set value 会被序列化, 优先使用 MarshalBinary 方法, 没有则执行 json.Marshal
+ Set(ctx context.Context, key string, value any) error
+ // SetEx value 会被序列化, 优先使用 MarshalBinary 方法, 没有则执行 json.Marshal
+ SetEx(ctx context.Context, key string, value any, duration time.Duration) error
+ Del(ctx context.Context, keys ...string) error
+}
+
+type Scanner interface {
+ Scan(model any) error
+}
diff --git a/internal/interfaces/enum.go b/internal/interfaces/enum.go
new file mode 100644
index 0000000..fd43fa6
--- /dev/null
+++ b/internal/interfaces/enum.go
@@ -0,0 +1,11 @@
+package interfaces
+
+type Enum interface {
+ Value() int64
+ Code() string
+ Label() string
+
+ MarshalJSON() ([]byte, error)
+
+ All() []Enum
+}
diff --git a/internal/interfaces/logger.go b/internal/interfaces/logger.go
new file mode 100644
index 0000000..8e75d58
--- /dev/null
+++ b/internal/interfaces/logger.go
@@ -0,0 +1,7 @@
+package interfaces
+
+type OpLogger interface {
+ Enum
+ Render(content map[string]any) (string, error)
+ Template() string
+}
diff --git a/internal/serve/serve.go b/internal/serve/serve.go
new file mode 100644
index 0000000..687f3cf
--- /dev/null
+++ b/internal/serve/serve.go
@@ -0,0 +1,144 @@
+package serve
+
+import (
+ "context"
+ "fmt"
+ "github.com/google/uuid"
+ "github.com/loveuer/nf"
+ "github.com/loveuer/nf/nft/log"
+ "net/http"
+ "uauth/internal/tool"
+)
+
+func authenticateUser(username, password string) (bool, error) {
+ // 这里你应该实现真实的用户认证逻辑
+ // 为了简化,我们这里直接硬编码一个用户名和密码
+ if username == "user" && password == "pass" {
+ return true, nil
+ }
+
+ return false, fmt.Errorf("invalid username or password")
+}
+
+// 处理登录请求
+func handleLogin(c *nf.Ctx) error {
+ username := c.FormValue("username")
+ password := c.FormValue("password")
+
+ // 认证用户
+ ok, err := authenticateUser(username, password)
+ if err != nil || !ok {
+ return c.Status(http.StatusUnauthorized).SendString("Unauthorized")
+ }
+
+ // 用户认证成功,重定向到授权页面
+ http.Redirect(c.Writer, c.Request, "/authorize?client_id=12345&response_type=code&redirect_uri=http%3A%2F%2Flocalhost%3A8080%2Fcallback&scope=read%20write", http.StatusFound)
+
+ return nil
+}
+
+// 处理授权请求
+func handleAuthorize(c *nf.Ctx) error {
+ // 解析查询参数
+ clientID := c.Query("client_id")
+ responseType := c.Query("response_type")
+ redirectURI := c.Query("redirect_uri")
+ scope := c.Query("scope")
+
+ // 检查客户端 ID 和其他参数
+ // 在实际应用中,你需要检查这些参数是否合法
+ if clientID != "12345" || responseType != "code" || redirectURI != "http://localhost:8080/callback" {
+ return c.Status(http.StatusBadRequest).SendString("Invalid request")
+ }
+
+ // 显示授权页面给用户
+ _, err := c.Write([]byte(`
+
+
Authorization
+
+ Do you want to authorize this application?
+
+
+
+ `))
+
+ return err
+}
+
+// 处理用户的授权批准
+func handleApprove(c *nf.Ctx) error {
+ // 获取表单数据
+ clientID := c.FormValue("client_id")
+ redirectURI := c.FormValue("redirect_uri")
+ scope := c.FormValue("scope")
+
+ // 生成授权码
+ authorizationCode := uuid.New().String()[:8]
+
+ log.Info("[D] client_id = %s, scope = %s, auth_code = %s", clientID, scope, authorizationCode)
+
+ // 重定向到回调 URL 并附带授权码
+ http.Redirect(c.Writer, c.Request, redirectURI+"?code="+authorizationCode, http.StatusFound)
+ return nil
+}
+
+// 令牌请求的处理
+func handleToken(c *nf.Ctx) error {
+ // 获取请求参数
+ grantType := c.FormValue("grant_type")
+ code := c.FormValue("code")
+ redirectURI := c.FormValue("redirect_uri")
+
+ // 简单验证
+ if grantType != "authorization_code" {
+ return c.Status(http.StatusBadRequest).SendString("Unsupported grant type")
+ }
+
+ mu.Lock()
+ defer mu.Unlock()
+
+ // 验证授权码是否有效
+ accessToken, ok := authCodes[code]
+ if !ok {
+ return c.Status(http.StatusBadRequest).SendString("Invalid authorization code")
+ }
+
+ // 生成访问令牌
+ token := generateAccessToken()
+
+ // 返回访问令牌
+ return c.JSON(map[string]string{
+ "access_token": token,
+ "token_type": "bearer",
+ "expires_in": "3600", // 访问令牌有效期(秒)
+ })
+
+ // 清除已使用的授权码
+ delete(authCodes, code)
+}
+
+func Run(ctx context.Context, prefix string, address string) error {
+
+ app := nf.New()
+
+ api := app.Group(prefix)
+ // 设置路由
+ api.Get("/login", handleLogin)
+ api.Get("/authorize", handleAuthorize)
+ api.Post("/approve", handleApprove)
+ api.Post("/token", handleToken)
+
+ // 启动 HTTP 服务器
+ log.Info("Starting server on: %s", address)
+ go func() {
+ <-ctx.Done()
+ _ = app.Shutdown(tool.Timeout(2))
+ }()
+
+ return app.Run(address)
+}
diff --git a/internal/store/cache/cache_lru.go b/internal/store/cache/cache_lru.go
new file mode 100644
index 0000000..99665bf
--- /dev/null
+++ b/internal/store/cache/cache_lru.go
@@ -0,0 +1,117 @@
+package cache
+
+import (
+ "context"
+ "github.com/hashicorp/golang-lru/v2/expirable"
+ _ "github.com/hashicorp/golang-lru/v2/expirable"
+ "time"
+ "uauth/internal/interfaces"
+)
+
+var _ interfaces.Cacher = (*_lru)(nil)
+
+type _lru struct {
+ client *expirable.LRU[string, *_lru_value]
+}
+
+type _lru_value struct {
+ duration time.Duration
+ last time.Time
+ bs []byte
+}
+
+func (l *_lru) Get(ctx context.Context, key string) ([]byte, error) {
+ v, ok := l.client.Get(key)
+ if !ok {
+ return nil, ErrorKeyNotFound
+ }
+
+ if v.duration == 0 {
+ return v.bs, nil
+ }
+
+ if time.Now().Sub(v.last) > v.duration {
+ l.client.Remove(key)
+ return nil, ErrorKeyNotFound
+ }
+
+ return v.bs, nil
+}
+
+func (l *_lru) GetScan(ctx context.Context, key string) interfaces.Scanner {
+ return newScanner(l.Get(ctx, key))
+}
+
+func (l *_lru) GetEx(ctx context.Context, key string, duration time.Duration) ([]byte, error) {
+ v, ok := l.client.Get(key)
+ if !ok {
+ return nil, ErrorKeyNotFound
+ }
+
+ if v.duration == 0 {
+ return v.bs, nil
+ }
+
+ now := time.Now()
+
+ if now.Sub(v.last) > v.duration {
+ l.client.Remove(key)
+ return nil, ErrorKeyNotFound
+ }
+
+ l.client.Add(key, &_lru_value{
+ duration: duration,
+ last: now,
+ bs: v.bs,
+ })
+
+ return v.bs, nil
+}
+
+func (l *_lru) GetExScan(ctx context.Context, key string, duration time.Duration) interfaces.Scanner {
+ return newScanner(l.GetEx(ctx, key, duration))
+}
+
+func (l *_lru) Set(ctx context.Context, key string, value any) error {
+ bs, err := handleValue(value)
+ if err != nil {
+ return err
+ }
+
+ l.client.Add(key, &_lru_value{
+ duration: 0,
+ last: time.Now(),
+ bs: bs,
+ })
+
+ return nil
+}
+
+func (l *_lru) SetEx(ctx context.Context, key string, value any, duration time.Duration) error {
+ bs, err := handleValue(value)
+ if err != nil {
+ return err
+ }
+
+ l.client.Add(key, &_lru_value{
+ duration: duration,
+ last: time.Now(),
+ bs: bs,
+ })
+
+ return nil
+}
+
+func (l *_lru) Del(ctx context.Context, keys ...string) error {
+ for _, key := range keys {
+ l.client.Remove(key)
+ }
+
+ return nil
+}
+
+func newLRUCache() (interfaces.Cacher, error) {
+ client := expirable.NewLRU[string, *_lru_value](1024*1024, nil, 0)
+
+ return &_lru{client: client}, nil
+}
diff --git a/internal/store/cache/cache_memory.go b/internal/store/cache/cache_memory.go
new file mode 100644
index 0000000..fb9e508
--- /dev/null
+++ b/internal/store/cache/cache_memory.go
@@ -0,0 +1,82 @@
+package cache
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+ "uauth/internal/interfaces"
+
+ "gitea.com/taozitaozi/gredis"
+)
+
+var _ interfaces.Cacher = (*_mem)(nil)
+
+type _mem struct {
+ client *gredis.Gredis
+}
+
+func (m *_mem) GetScan(ctx context.Context, key string) interfaces.Scanner {
+ return newScanner(m.Get(ctx, key))
+}
+
+func (m *_mem) GetExScan(ctx context.Context, key string, duration time.Duration) interfaces.Scanner {
+ return newScanner(m.GetEx(ctx, key, duration))
+}
+
+func (m *_mem) Get(ctx context.Context, key string) ([]byte, error) {
+ v, err := m.client.Get(key)
+ if err != nil {
+ if errors.Is(err, gredis.ErrKeyNotFound) {
+ return nil, ErrorKeyNotFound
+ }
+
+ return nil, err
+ }
+
+ bs, ok := v.([]byte)
+ if !ok {
+ return nil, fmt.Errorf("invalid value type=%T", v)
+ }
+
+ return bs, nil
+}
+
+func (m *_mem) GetEx(ctx context.Context, key string, duration time.Duration) ([]byte, error) {
+ v, err := m.client.GetEx(key, duration)
+ if err != nil {
+ if errors.Is(err, gredis.ErrKeyNotFound) {
+ return nil, ErrorKeyNotFound
+ }
+
+ return nil, err
+ }
+
+ bs, ok := v.([]byte)
+ if !ok {
+ return nil, fmt.Errorf("invalid value type=%T", v)
+ }
+
+ return bs, nil
+}
+
+func (m *_mem) Set(ctx context.Context, key string, value any) error {
+ bs, err := handleValue(value)
+ if err != nil {
+ return err
+ }
+ return m.client.Set(key, bs)
+}
+
+func (m *_mem) SetEx(ctx context.Context, key string, value any, duration time.Duration) error {
+ bs, err := handleValue(value)
+ if err != nil {
+ return err
+ }
+ return m.client.SetEx(key, bs, duration)
+}
+
+func (m *_mem) Del(ctx context.Context, keys ...string) error {
+ m.client.Delete(keys...)
+ return nil
+}
diff --git a/internal/store/cache/cache_redis.go b/internal/store/cache/cache_redis.go
new file mode 100644
index 0000000..5cd2343
--- /dev/null
+++ b/internal/store/cache/cache_redis.go
@@ -0,0 +1,71 @@
+package cache
+
+import (
+ "context"
+ "errors"
+ "time"
+ "uauth/internal/interfaces"
+)
+
+type _redis struct {
+ client *redis.Client
+}
+
+func (r *_redis) Get(ctx context.Context, key string) ([]byte, error) {
+ result, err := r.client.Get(ctx, key).Result()
+ if err != nil {
+ if errors.Is(err, redis.Nil) {
+ return nil, ErrorKeyNotFound
+ }
+
+ return nil, err
+ }
+
+ return []byte(result), nil
+}
+
+func (r *_redis) GetScan(ctx context.Context, key string) interfaces.Scanner {
+ return newScanner(r.Get(ctx, key))
+}
+
+func (r *_redis) GetEx(ctx context.Context, key string, duration time.Duration) ([]byte, error) {
+ result, err := r.client.GetEx(ctx, key, duration).Result()
+ if err != nil {
+ if errors.Is(err, redis.Nil) {
+ return nil, ErrorKeyNotFound
+ }
+
+ return nil, err
+ }
+
+ return []byte(result), nil
+}
+
+func (r *_redis) GetExScan(ctx context.Context, key string, duration time.Duration) interfaces.Scanner {
+ return newScanner(r.GetEx(ctx, key, duration))
+}
+
+func (r *_redis) Set(ctx context.Context, key string, value any) error {
+ bs, err := handleValue(value)
+ if err != nil {
+ return err
+ }
+
+ _, err = r.client.Set(ctx, key, bs, redis.KeepTTL).Result()
+ return err
+}
+
+func (r *_redis) SetEx(ctx context.Context, key string, value any, duration time.Duration) error {
+ bs, err := handleValue(value)
+ if err != nil {
+ return err
+ }
+
+ _, err = r.client.SetEX(ctx, key, bs, duration).Result()
+
+ return err
+}
+
+func (r *_redis) Del(ctx context.Context, keys ...string) error {
+ return r.client.Del(ctx, keys...).Err()
+}
diff --git a/internal/store/cache/client.go b/internal/store/cache/client.go
new file mode 100644
index 0000000..2060d13
--- /dev/null
+++ b/internal/store/cache/client.go
@@ -0,0 +1,38 @@
+package cache
+
+import (
+ "encoding/json"
+ "uauth/internal/interfaces"
+)
+
+var (
+ Client interfaces.Cacher
+)
+
+type encoded_value interface {
+ MarshalBinary() ([]byte, error)
+}
+
+type decoded_value interface {
+ UnmarshalBinary(bs []byte) error
+}
+
+func handleValue(value any) ([]byte, error) {
+ var (
+ bs []byte
+ err error
+ )
+
+ switch value.(type) {
+ case []byte:
+ return value.([]byte), nil
+ }
+
+ if imp, ok := value.(encoded_value); ok {
+ bs, err = imp.MarshalBinary()
+ } else {
+ bs, err = json.Marshal(value)
+ }
+
+ return bs, err
+}
diff --git a/internal/store/cache/error.go b/internal/store/cache/error.go
new file mode 100644
index 0000000..f0798b5
--- /dev/null
+++ b/internal/store/cache/error.go
@@ -0,0 +1,7 @@
+package cache
+
+import "errors"
+
+var (
+ ErrorKeyNotFound = errors.New("key not found")
+)
diff --git a/internal/store/cache/init.go b/internal/store/cache/init.go
new file mode 100644
index 0000000..b5f9845
--- /dev/null
+++ b/internal/store/cache/init.go
@@ -0,0 +1,69 @@
+package cache
+
+import (
+ "fmt"
+ "gitea.com/taozitaozi/gredis"
+ "net/url"
+ "strings"
+ "uauth/internal/opt"
+ "uauth/internal/tool"
+)
+
+func Init() error {
+
+ var (
+ err error
+ )
+
+ strs := strings.Split(opt.Cfg.Cache.Uri, "::")
+
+ switch strs[0] {
+ case "memory":
+ gc := gredis.NewGredis(1024 * 1024)
+ Client = &_mem{client: gc}
+ case "lru":
+ if Client, err = newLRUCache(); err != nil {
+ return err
+ }
+ case "redis":
+ var (
+ ins *url.URL
+ err error
+ )
+
+ if len(strs) != 2 {
+ return fmt.Errorf("cache.Init: invalid cache uri: %s", opt.Cfg.Cache.Uri)
+ }
+
+ uri := strs[1]
+
+ if !strings.Contains(uri, "://") {
+ uri = fmt.Sprintf("redis://%s", uri)
+ }
+
+ if ins, err = url.Parse(uri); err != nil {
+ return fmt.Errorf("cache.Init: url parse cache uri: %s, err: %s", opt.Cfg.Cache.Uri, err.Error())
+ }
+
+ addr := ins.Host
+ username := ins.User.Username()
+ password, _ := ins.User.Password()
+
+ var rc *redis.Client
+ rc = redis.NewClient(&redis.Options{
+ Addr: addr,
+ Username: username,
+ Password: password,
+ })
+
+ if err = rc.Ping(tool.Timeout(5)).Err(); err != nil {
+ return fmt.Errorf("cache.Init: redis ping err: %s", err.Error())
+ }
+
+ Client = &_redis{client: rc}
+ default:
+ return fmt.Errorf("cache type %s not support", strs[0])
+ }
+
+ return nil
+}
diff --git a/internal/store/cache/scan.go b/internal/store/cache/scan.go
new file mode 100644
index 0000000..c65d267
--- /dev/null
+++ b/internal/store/cache/scan.go
@@ -0,0 +1,20 @@
+package cache
+
+import "encoding/json"
+
+type scanner struct {
+ err error
+ bs []byte
+}
+
+func (s *scanner) Scan(model any) error {
+ if s.err != nil {
+ return s.err
+ }
+
+ return json.Unmarshal(s.bs, model)
+}
+
+func newScanner(bs []byte, err error) *scanner {
+ return &scanner{bs: bs, err: err}
+}
diff --git a/internal/tool/ctx.go b/internal/tool/ctx.go
new file mode 100644
index 0000000..82242a3
--- /dev/null
+++ b/internal/tool/ctx.go
@@ -0,0 +1,38 @@
+package tool
+
+import (
+ "context"
+ "time"
+)
+
+func Timeout(seconds ...int) (ctx context.Context) {
+ var (
+ duration time.Duration
+ )
+
+ if len(seconds) > 0 && seconds[0] > 0 {
+ duration = time.Duration(seconds[0]) * time.Second
+ } else {
+ duration = time.Duration(30) * time.Second
+ }
+
+ ctx, _ = context.WithTimeout(context.Background(), duration)
+
+ return
+}
+
+func TimeoutCtx(ctx context.Context, seconds ...int) context.Context {
+ var (
+ duration time.Duration
+ )
+
+ if len(seconds) > 0 && seconds[0] > 0 {
+ duration = time.Duration(seconds[0]) * time.Second
+ } else {
+ duration = time.Duration(30) * time.Second
+ }
+
+ nctx, _ := context.WithTimeout(ctx, duration)
+
+ return nctx
+}
diff --git a/internal/tool/file.go b/internal/tool/file.go
new file mode 100644
index 0000000..cefa36d
--- /dev/null
+++ b/internal/tool/file.go
@@ -0,0 +1,30 @@
+package tool
+
+import (
+ "io"
+ "os"
+)
+
+func CopyFile(src string, dst string) (err error) {
+ // Open the source file
+ sourceFile, err := os.Open(src)
+ if err != nil {
+ return err
+ }
+ defer sourceFile.Close()
+
+ // Create the destination file
+ destinationFile, err := os.Create(dst)
+ if err != nil {
+ return err
+ }
+ defer destinationFile.Close()
+
+ // Copy the contents from source to destination
+ _, err = io.Copy(destinationFile, sourceFile)
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
diff --git a/internal/tool/human.go b/internal/tool/human.go
new file mode 100644
index 0000000..2c7ce71
--- /dev/null
+++ b/internal/tool/human.go
@@ -0,0 +1,24 @@
+package tool
+
+import "fmt"
+
+func HumanDuration(nano int64) string {
+ duration := float64(nano)
+ unit := "ns"
+ if duration >= 1000 {
+ duration /= 1000
+ unit = "us"
+ }
+
+ if duration >= 1000 {
+ duration /= 1000
+ unit = "ms"
+ }
+
+ if duration >= 1000 {
+ duration /= 1000
+ unit = " s"
+ }
+
+ return fmt.Sprintf("%6.2f%s", duration, unit)
+}
diff --git a/internal/tool/must.go b/internal/tool/must.go
new file mode 100644
index 0000000..0615f8d
--- /dev/null
+++ b/internal/tool/must.go
@@ -0,0 +1,11 @@
+package tool
+
+import "github.com/loveuer/nf/nft/log"
+
+func Must(errs ...error) {
+ for _, err := range errs {
+ if err != nil {
+ log.Panic(err.Error())
+ }
+ }
+}
diff --git a/internal/tool/password.go b/internal/tool/password.go
new file mode 100644
index 0000000..c2d1a17
--- /dev/null
+++ b/internal/tool/password.go
@@ -0,0 +1,84 @@
+package tool
+
+import (
+ "crypto/sha256"
+ "encoding/hex"
+ "errors"
+ "fmt"
+ "github.com/loveuer/nf/nft/log"
+ "golang.org/x/crypto/pbkdf2"
+ "regexp"
+ "strconv"
+ "strings"
+)
+
+const (
+ EncryptHeader string = "pbkdf2:sha256" // 用户密码加密
+)
+
+func NewPassword(password string) string {
+ return EncryptPassword(password, RandomString(8), int(RandomInt(50000)+100000))
+}
+
+func ComparePassword(in, db string) bool {
+ strs := strings.Split(db, "$")
+ if len(strs) != 3 {
+ log.Error("password in db invalid: %s", db)
+ return false
+ }
+
+ encs := strings.Split(strs[0], ":")
+ if len(encs) != 3 {
+ log.Error("password in db invalid: %s", db)
+ return false
+ }
+
+ encIteration, err := strconv.Atoi(encs[2])
+ if err != nil {
+ log.Error("password in db invalid: %s, convert iter err: %s", db, err)
+ return false
+ }
+
+ return EncryptPassword(in, strs[1], encIteration) == db
+}
+
+func EncryptPassword(password, salt string, iter int) string {
+ hash := pbkdf2.Key([]byte(password), []byte(salt), iter, 32, sha256.New)
+ encrypted := hex.EncodeToString(hash)
+ return fmt.Sprintf("%s:%d$%s$%s", EncryptHeader, iter, salt, encrypted)
+}
+
+func CheckPassword(password string) error {
+ if len(password) < 8 || len(password) > 32 {
+ return errors.New("密码长度不符合")
+ }
+
+ var (
+ err error
+ match bool
+ patternList = []string{`[0-9]+`, `[a-z]+`, `[A-Z]+`, `[!@#%]+`} //, `[~!@#$%^&*?_-]+`}
+ matchAccount = 0
+ tips = []string{"缺少数字", "缺少小写字母", "缺少大写字母", "缺少'!@#%'"}
+ locktips = make([]string, 0)
+ )
+
+ for idx, pattern := range patternList {
+ match, err = regexp.MatchString(pattern, password)
+ if err != nil {
+ log.Warn("regex match string err, reg_str: %s, err: %v", pattern, err)
+ return errors.New("密码强度不够")
+ }
+
+ if match {
+ matchAccount++
+ } else {
+ locktips = append(locktips, tips[idx])
+ }
+ }
+
+ if matchAccount < 3 {
+ return fmt.Errorf("密码强度不够, 可能 %s", strings.Join(locktips, ", "))
+ }
+
+ return nil
+}
diff --git a/internal/tool/password_test.go b/internal/tool/password_test.go
new file mode 100644
index 0000000..aabd667
--- /dev/null
+++ b/internal/tool/password_test.go
@@ -0,0 +1,11 @@
+package tool
+
+import "testing"
+
+func TestEncPassword(t *testing.T) {
+ password := "123456"
+
+ result := EncryptPassword(password, RandomString(8), 50000)
+
+ t.Logf("sum => %s", result)
+}
diff --git a/internal/tool/random.go b/internal/tool/random.go
new file mode 100644
index 0000000..266cb4c
--- /dev/null
+++ b/internal/tool/random.go
@@ -0,0 +1,54 @@
+package tool
+
+import (
+ "crypto/rand"
+ "math/big"
+)
+
+var (
+ letters = []byte("0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
+ letterNum = []byte("0123456789")
+ letterLow = []byte("abcdefghijklmnopqrstuvwxyz")
+ letterCap = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
+ letterSyb = []byte("!@#$%^&*()_+-=")
+)
+
+func RandomInt(max int64) int64 {
+ num, _ := rand.Int(rand.Reader, big.NewInt(max))
+ return num.Int64()
+}
+
+func RandomString(length int) string {
+ result := make([]byte, length)
+ for i := 0; i < length; i++ {
+ num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
+ result[i] = letters[num.Int64()]
+ }
+ return string(result)
+}
+
+func RandomPassword(length int, withSymbol bool) string {
+ result := make([]byte, length)
+ kind := 3
+ if withSymbol {
+ kind++
+ }
+
+ for i := 0; i < length; i++ {
+ switch i % kind {
+ case 0:
+ num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letterNum))))
+ result[i] = letterNum[num.Int64()]
+ case 1:
+ num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letterLow))))
+ result[i] = letterLow[num.Int64()]
+ case 2:
+ num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letterCap))))
+ result[i] = letterCap[num.Int64()]
+ case 3:
+ num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letterSyb))))
+ result[i] = letterSyb[num.Int64()]
+ }
+ }
+ return string(result)
+}
diff --git a/internal/tool/slice.go b/internal/tool/slice.go
new file mode 100644
index 0000000..05a7dd5
--- /dev/null
+++ b/internal/tool/slice.go
@@ -0,0 +1,5 @@
+package tool
+
+func Bulk[T any](slice []T, size int) {
+ // todo
+}
diff --git a/internal/tool/slice_test.go b/internal/tool/slice_test.go
new file mode 100644
index 0000000..05b1676
--- /dev/null
+++ b/internal/tool/slice_test.go
@@ -0,0 +1 @@
+package tool
diff --git a/internal/tool/table.go b/internal/tool/table.go
new file mode 100644
index 0000000..ffaaf31
--- /dev/null
+++ b/internal/tool/table.go
@@ -0,0 +1,124 @@
+package tool
+
+import (
+ "encoding/json"
+ "fmt"
+ "github.com/jedib0t/go-pretty/v6/table"
+ "github.com/loveuer/nf/nft/log"
+ "io"
+ "os"
+ "reflect"
+ "strings"
+)
+
+func TablePrinter(data any, writers ...io.Writer) {
+ var w io.Writer = os.Stdout
+ if len(writers) > 0 && writers[0] != nil {
+ w = writers[0]
+ }
+
+ t := table.NewWriter()
+ structPrinter(t, "", data)
+ _, _ = fmt.Fprintln(w, t.Render())
+}
+
+func structPrinter(w table.Writer, prefix string, item any) {
+Start:
+ rv := reflect.ValueOf(item)
+ if rv.IsZero() {
+ return
+ }
+
+ for rv.Type().Kind() == reflect.Pointer {
+ rv = rv.Elem()
+ }
+
+ switch rv.Type().Kind() {
+ case reflect.Invalid,
+ reflect.Uintptr,
+ reflect.Chan,
+ reflect.Func,
+ reflect.UnsafePointer:
+ case reflect.Bool,
+ reflect.Int,
+ reflect.Int8,
+ reflect.Int16,
+ reflect.Int32,
+ reflect.Int64,
+ reflect.Uint,
+ reflect.Uint8,
+ reflect.Uint16,
+ reflect.Uint32,
+ reflect.Uint64,
+ reflect.Float32,
+ reflect.Float64,
+ reflect.Complex64,
+ reflect.Complex128,
+ reflect.Interface:
+ w.AppendRow(table.Row{strings.TrimPrefix(prefix, "."), rv.Interface()})
+ case reflect.String:
+ val := rv.String()
+ if len(val) <= 160 {
+ w.AppendRow(table.Row{strings.TrimPrefix(prefix, "."), val})
+ return
+ }
+
+ w.AppendRow(table.Row{strings.TrimPrefix(prefix, "."), val[0:64] + "..." + val[len(val)-64:]})
+ case reflect.Array, reflect.Slice:
+ for i := 0; i < rv.Len(); i++ {
+ p := strings.Join([]string{prefix, fmt.Sprintf("[%d]", i)}, ".")
+ structPrinter(w, p, rv.Index(i).Interface())
+ }
+ case reflect.Map:
+ for _, k := range rv.MapKeys() {
+ structPrinter(w, fmt.Sprintf("%s.{%v}", prefix, k), rv.MapIndex(k).Interface())
+ }
+ case reflect.Pointer:
+ goto Start
+ case reflect.Struct:
+ for i := 0; i < rv.NumField(); i++ {
+ p := fmt.Sprintf("%s.%s", prefix, rv.Type().Field(i).Name)
+ field := rv.Field(i)
+
+ //log.Debug("TablePrinter: prefix: %s, field: %v", p, rv.Field(i))
+
+ if !field.CanInterface() {
+ return
+ }
+
+ structPrinter(w, p, field.Interface())
+ }
+ }
+}
+
+func TableMapPrinter(data []byte) {
+ m := make(map[string]any)
+ if err := json.Unmarshal(data, &m); err != nil {
+ log.Warn(err.Error())
+ return
+ }
+
+ t := table.NewWriter()
+ addRow(t, "", m)
+ fmt.Println(t.Render())
+}
+
+func addRow(w table.Writer, prefix string, m any) {
+ rv := reflect.ValueOf(m)
+ switch rv.Type().Kind() {
+ case reflect.Map:
+ for _, k := range rv.MapKeys() {
+ key := k.String()
+ if prefix != "" {
+ key = strings.Join([]string{prefix, k.String()}, ".")
+ }
+ addRow(w, key, rv.MapIndex(k).Interface())
+ }
+ case reflect.Slice, reflect.Array:
+ for i := 0; i < rv.Len(); i++ {
+ addRow(w, fmt.Sprintf("%s[%d]", prefix, i), rv.Index(i).Interface())
+ }
+ default:
+ w.AppendRow(table.Row{prefix, m})
+ }
+}
diff --git a/internal/tool/time.go b/internal/tool/time.go
new file mode 100644
index 0000000..a193f20
--- /dev/null
+++ b/internal/tool/time.go
@@ -0,0 +1,13 @@
+package tool
+
+import "time"
+
+// TodayMidnight 返回今日凌晨
+func TodayMidnight() (midnight time.Time) {
+ now := time.Now()
+
+ year, month, day := now.Date()
+ midnight = time.Date(year, month, day, 0, 0, 0, 0, time.Local)
+
+ return
+}
diff --git a/main.go b/main.go
index 81bf409..78d5993 100644
--- a/main.go
+++ b/main.go
@@ -1,100 +1,21 @@
package main
import (
- "fmt"
- "github.com/google/uuid"
- "github.com/loveuer/nf"
- "log"
- "net/http"
+ "context"
+ "github.com/loveuer/nf/nft/log"
+ "os/signal"
+ "syscall"
+ "uauth/internal/cmd"
)
// 假设这是你的用户认证函数
-func authenticateUser(username, password string) (bool, error) {
- // 这里你应该实现真实的用户认证逻辑
- // 为了简化,我们这里直接硬编码一个用户名和密码
- if username == "user" && password == "pass" {
- return true, nil
- }
-
- return false, fmt.Errorf("invalid username or password")
-}
-
-// 处理登录请求
-func handleLogin(c *nf.Ctx) error {
- username := c.FormValue("username")
- password := c.FormValue("password")
-
- // 认证用户
- ok, err := authenticateUser(username, password)
- if err != nil || !ok {
- return c.Status(http.StatusUnauthorized).SendString("Unauthorized")
- }
-
- // 用户认证成功,重定向到授权页面
- http.Redirect(c.Writer, c.Request, "/authorize?client_id=12345&response_type=code&redirect_uri=http%3A%2F%2Flocalhost%3A8080%2Fcallback&scope=read%20write", http.StatusFound)
-
- return nil
-}
-
-// 处理授权请求
-func handleAuthorize(c *nf.Ctx) error {
- // 解析查询参数
- clientID := c.Query("client_id")
- responseType := c.Query("response_type")
- redirectURI := c.Query("redirect_uri")
- scope := c.Query("scope")
-
- // 检查客户端 ID 和其他参数
- // 在实际应用中,你需要检查这些参数是否合法
- if clientID != "12345" || responseType != "code" || redirectURI != "http://localhost:8080/callback" {
- return c.Status(http.StatusBadRequest).SendString("Invalid request")
- }
-
- // 显示授权页面给用户
- _, err := c.Write([]byte(`
-
- Authorization
-
- Do you want to authorize this application?
-
-
-
- `))
-
- return err
-}
-
-// 处理用户的授权批准
-func handleApprove(c *nf.Ctx) error {
- // 获取表单数据
- clientID := c.FormValue("client_id")
- redirectURI := c.FormValue("redirect_uri")
- scope := c.FormValue("scope")
-
- // 生成授权码
- authorizationCode := uuid.New().String()[:8]
-
- log.Printf("[D] client_id = %s, scope = %s, auth_code = %s", clientID, scope, authorizationCode)
-
- // 重定向到回调 URL 并附带授权码
- http.Redirect(c.Writer, c.Request, redirectURI+"?code="+authorizationCode, http.StatusFound)
- return nil
-}
func main() {
- app := nf.New()
+ ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
+ defer cancel()
- // 设置路由
- app.Get("/login", handleLogin)
- app.Get("/authorize", handleAuthorize)
- app.Post("/approve", handleApprove)
+ if err := cmd.Command.ExecuteContext(ctx); err != nil {
+ log.Error(err.Error())
+ }
- // 启动 HTTP 服务器
- log.Println("Starting server on :8080")
- log.Fatal(app.Run(":8080"))
}