diff --git a/internal/cmd/serve.go b/internal/cmd/serve.go index 5f415f6..98bff22 100644 --- a/internal/cmd/serve.go +++ b/internal/cmd/serve.go @@ -7,6 +7,7 @@ import ( "uauth/internal/store/cache" "uauth/internal/store/db" "uauth/internal/tool" + "uauth/model" ) func initServe() *cobra.Command { @@ -16,6 +17,7 @@ func initServe() *cobra.Command { tool.TablePrinter(opt.Cfg) tool.Must(cache.Init(opt.Cfg.Svc.Cache)) tool.Must(db.Init(cmd.Context(), opt.Cfg.Svc.DB)) + tool.Must(model.Init(db.Default.Session())) return serve.Run(cmd.Context()) }, } diff --git a/internal/middleware/auth/auth.go b/internal/middleware/auth/auth.go new file mode 100644 index 0000000..58b00d0 --- /dev/null +++ b/internal/middleware/auth/auth.go @@ -0,0 +1,98 @@ +package auth + +import ( + "errors" + "github.com/loveuer/nf" + "github.com/loveuer/nf/nft/resp" + "time" + "uauth/internal/store/cache" + "uauth/internal/tool" + "uauth/model" +) + +type Config struct { + IgnoreFn func(c *nf.Ctx) bool + TokenFn func(c *nf.Ctx) (string, bool) + GetUserFn func(c *nf.Ctx, token string) (*model.User, error) +} + +var ( + defaultIgnoreFn = func(c *nf.Ctx) bool { + return false + } + + defaultTokenFn = func(c *nf.Ctx) (string, bool) { + var token string + + if token = c.Request.Header.Get("Authorization"); token != "" { + return token, true + } + + if token = c.Query("access_token"); token != "" { + return token, true + } + + if token = c.Cookies("access_token"); token != "" { + return token, true + } + + return "", false + } + + defaultGetUserFn = func(c *nf.Ctx, token string) (*model.User, error) { + var ( + err error + op = new(model.User) + ) + + if err = cache.Client.GetExScan(tool.Timeout(3), cache.Prefix+"token:"+token, 24*time.Hour).Scan(op); err != nil { + if errors.Is(err, cache.ErrorKeyNotFound) { + return nil, resp.Resp401(c, nil, "用户认证信息不存在或已过期, 请重新登录") + } + + return nil, resp.Resp500(c, err) + } + + return op, nil + } +) + +func New(cfgs ...*Config) nf.HandlerFunc { + var cfg = &Config{} + + if len(cfgs) > 0 && cfgs[0] != nil { + cfg = cfgs[0] + } + + if cfg.IgnoreFn == nil { + cfg.IgnoreFn = defaultIgnoreFn + } + + if cfg.TokenFn == nil { + cfg.TokenFn = defaultTokenFn + } + + if cfg.GetUserFn == nil { + cfg.GetUserFn = defaultGetUserFn + } + + return func(c *nf.Ctx) error { + if cfg.IgnoreFn(c) { + return c.Next() + } + + token, ok := cfg.TokenFn(c) + if !ok { + return resp.Resp401(c, nil, "请登录") + } + + op, err := cfg.GetUserFn(c, token) + if err != nil { + return err + } + + c.Locals("user", op) + + return c.Next() + } +} diff --git a/internal/serve/handler/approve.go b/internal/serve/handler/approve.go new file mode 100644 index 0000000..6ad0e34 --- /dev/null +++ b/internal/serve/handler/approve.go @@ -0,0 +1,58 @@ +package handler + +import ( + "errors" + "github.com/google/uuid" + "github.com/loveuer/nf" + "github.com/loveuer/nf/nft/resp" + "net/http" + "time" + "uauth/internal/store/cache" + "uauth/internal/tool" + "uauth/model" +) + +func Approve(c *nf.Ctx) error { + // 获取表单数据 + type Req struct { + ClientId string `form:"client_id"` + ClientSecret string `form:"client_secret"` + RedirectURI string `form:"redirect_uri"` + Scope string `form:"scope"` + State string `form:"state"` + } + + var ( + ok bool + op *model.User + err error + req = new(Req) + ) + + if op, ok = c.Locals("user").(*model.User); !ok { + return resp.Resp401(c, nil) + } + + if err = c.BodyParser(req); err != nil { + return resp.Resp400(c, err) + } + + state := cache.Prefix + "state_code:" + req.State + if _, err = cache.Client.Get(c.Context(), state); err != nil { + if errors.Is(err, cache.ErrorKeyNotFound) { + return resp.Resp400(c, req, "Bad Approve Request") + } + + return resp.Resp500(c, err) + } + + _ = cache.Client.Del(tool.Timeout(3), state) + + authorizationCode := uuid.New().String()[:8] + if err = cache.Client.SetEx(c.Context(), cache.Prefix+"auth_code:"+authorizationCode, op.Id, 10*time.Minute); err != nil { + return resp.Resp500(c, err) + } + + // 重定向到回调 URL 并附带授权码 + return c.Redirect(req.RedirectURI+"?code="+authorizationCode, http.StatusFound) +} diff --git a/internal/serve/handler/authorize.go b/internal/serve/handler/authorize.go index b80a882..f565167 100644 --- a/internal/serve/handler/authorize.go +++ b/internal/serve/handler/authorize.go @@ -1,39 +1,97 @@ package handler import ( + _ "embed" + "errors" + "github.com/google/uuid" "github.com/loveuer/nf" + "github.com/loveuer/nf/nft/resp" + "gorm.io/gorm" "net/http" + "time" + "uauth/internal/store/cache" + "uauth/internal/store/db" + "uauth/model" +) + +var ( + //go:embed serve_authorize.html + pageAuthorize string ) func Authorize(c *nf.Ctx) error { - // 解析查询参数 - clientID := c.Query("client_id") - responseType := c.Query("response_type") - redirectURI := c.Query("redirect_uri") - scope := c.Query("scope") + type Req struct { + ClientId string `query:"client_id"` + ClientSecret string `query:"client_secret"` + ResponseType string `query:"response_type"` + RedirectURI string `query:"redirect_uri"` + Scope string `query:"scope"` + } - // 检查客户端 ID 和其他参数 - // 在实际应用中,你需要检查这些参数是否合法 - if clientID != "12345" || responseType != "code" || redirectURI != "http://localhost:8080/callback" { + var ( + ok bool + op *model.User + req = new(Req) + err error + client = &model.Client{} + authRecord = &model.AuthorizationRecord{} + ) + + if err = c.QueryParser(req); err != nil { + return resp.Resp400(c, err) + } + + if req.ResponseType != "code" { return c.Status(http.StatusBadRequest).SendString("Invalid request") } - // 显示授权页面给用户 - _, err := c.Write([]byte(` - -