diff --git a/internal/cmd/serve.go b/internal/cmd/serve.go index 5f415f6..98bff22 100644 --- a/internal/cmd/serve.go +++ b/internal/cmd/serve.go @@ -7,6 +7,7 @@ import ( "uauth/internal/store/cache" "uauth/internal/store/db" "uauth/internal/tool" + "uauth/model" ) func initServe() *cobra.Command { @@ -16,6 +17,7 @@ func initServe() *cobra.Command { tool.TablePrinter(opt.Cfg) tool.Must(cache.Init(opt.Cfg.Svc.Cache)) tool.Must(db.Init(cmd.Context(), opt.Cfg.Svc.DB)) + tool.Must(model.Init(db.Default.Session())) return serve.Run(cmd.Context()) }, } diff --git a/internal/middleware/auth/auth.go b/internal/middleware/auth/auth.go new file mode 100644 index 0000000..58b00d0 --- /dev/null +++ b/internal/middleware/auth/auth.go @@ -0,0 +1,98 @@ +package auth + +import ( + "errors" + "github.com/loveuer/nf" + "github.com/loveuer/nf/nft/resp" + "time" + "uauth/internal/store/cache" + "uauth/internal/tool" + "uauth/model" +) + +type Config struct { + IgnoreFn func(c *nf.Ctx) bool + TokenFn func(c *nf.Ctx) (string, bool) + GetUserFn func(c *nf.Ctx, token string) (*model.User, error) +} + +var ( + defaultIgnoreFn = func(c *nf.Ctx) bool { + return false + } + + defaultTokenFn = func(c *nf.Ctx) (string, bool) { + var token string + + if token = c.Request.Header.Get("Authorization"); token != "" { + return token, true + } + + if token = c.Query("access_token"); token != "" { + return token, true + } + + if token = c.Cookies("access_token"); token != "" { + return token, true + } + + return "", false + } + + defaultGetUserFn = func(c *nf.Ctx, token string) (*model.User, error) { + var ( + err error + op = new(model.User) + ) + + if err = cache.Client.GetExScan(tool.Timeout(3), cache.Prefix+"token:"+token, 24*time.Hour).Scan(op); err != nil { + if errors.Is(err, cache.ErrorKeyNotFound) { + return nil, resp.Resp401(c, nil, "用户认证信息不存在或已过期, 请重新登录") + } + + return nil, resp.Resp500(c, err) + } + + return op, nil + } +) + +func New(cfgs ...*Config) nf.HandlerFunc { + var cfg = &Config{} + + if len(cfgs) > 0 && cfgs[0] != nil { + cfg = cfgs[0] + } + + if cfg.IgnoreFn == nil { + cfg.IgnoreFn = defaultIgnoreFn + } + + if cfg.TokenFn == nil { + cfg.TokenFn = defaultTokenFn + } + + if cfg.GetUserFn == nil { + cfg.GetUserFn = defaultGetUserFn + } + + return func(c *nf.Ctx) error { + if cfg.IgnoreFn(c) { + return c.Next() + } + + token, ok := cfg.TokenFn(c) + if !ok { + return resp.Resp401(c, nil, "请登录") + } + + op, err := cfg.GetUserFn(c, token) + if err != nil { + return err + } + + c.Locals("user", op) + + return c.Next() + } +} diff --git a/internal/serve/handler/approve.go b/internal/serve/handler/approve.go new file mode 100644 index 0000000..6ad0e34 --- /dev/null +++ b/internal/serve/handler/approve.go @@ -0,0 +1,58 @@ +package handler + +import ( + "errors" + "github.com/google/uuid" + "github.com/loveuer/nf" + "github.com/loveuer/nf/nft/resp" + "net/http" + "time" + "uauth/internal/store/cache" + "uauth/internal/tool" + "uauth/model" +) + +func Approve(c *nf.Ctx) error { + // 获取表单数据 + type Req struct { + ClientId string `form:"client_id"` + ClientSecret string `form:"client_secret"` + RedirectURI string `form:"redirect_uri"` + Scope string `form:"scope"` + State string `form:"state"` + } + + var ( + ok bool + op *model.User + err error + req = new(Req) + ) + + if op, ok = c.Locals("user").(*model.User); !ok { + return resp.Resp401(c, nil) + } + + if err = c.BodyParser(req); err != nil { + return resp.Resp400(c, err) + } + + state := cache.Prefix + "state_code:" + req.State + if _, err = cache.Client.Get(c.Context(), state); err != nil { + if errors.Is(err, cache.ErrorKeyNotFound) { + return resp.Resp400(c, req, "Bad Approve Request") + } + + return resp.Resp500(c, err) + } + + _ = cache.Client.Del(tool.Timeout(3), state) + + authorizationCode := uuid.New().String()[:8] + if err = cache.Client.SetEx(c.Context(), cache.Prefix+"auth_code:"+authorizationCode, op.Id, 10*time.Minute); err != nil { + return resp.Resp500(c, err) + } + + // 重定向到回调 URL 并附带授权码 + return c.Redirect(req.RedirectURI+"?code="+authorizationCode, http.StatusFound) +} diff --git a/internal/serve/handler/authorize.go b/internal/serve/handler/authorize.go index b80a882..f565167 100644 --- a/internal/serve/handler/authorize.go +++ b/internal/serve/handler/authorize.go @@ -1,39 +1,97 @@ package handler import ( + _ "embed" + "errors" + "github.com/google/uuid" "github.com/loveuer/nf" + "github.com/loveuer/nf/nft/resp" + "gorm.io/gorm" "net/http" + "time" + "uauth/internal/store/cache" + "uauth/internal/store/db" + "uauth/model" +) + +var ( + //go:embed serve_authorize.html + pageAuthorize string ) func Authorize(c *nf.Ctx) error { - // 解析查询参数 - clientID := c.Query("client_id") - responseType := c.Query("response_type") - redirectURI := c.Query("redirect_uri") - scope := c.Query("scope") + type Req struct { + ClientId string `query:"client_id"` + ClientSecret string `query:"client_secret"` + ResponseType string `query:"response_type"` + RedirectURI string `query:"redirect_uri"` + Scope string `query:"scope"` + } - // 检查客户端 ID 和其他参数 - // 在实际应用中,你需要检查这些参数是否合法 - if clientID != "12345" || responseType != "code" || redirectURI != "http://localhost:8080/callback" { + var ( + ok bool + op *model.User + req = new(Req) + err error + client = &model.Client{} + authRecord = &model.AuthorizationRecord{} + ) + + if err = c.QueryParser(req); err != nil { + return resp.Resp400(c, err) + } + + if req.ResponseType != "code" { return c.Status(http.StatusBadRequest).SendString("Invalid request") } - // 显示授权页面给用户 - _, err := c.Write([]byte(` - - Authorization - -

Do you want to authorize this application?

-
- - - - -
- - - `)) + if op, ok = c.Locals("user").(*model.User); !ok { + return resp.Resp401(c, nil) + } - return err + if err = db.Default.Session().Where("client_id", req.ClientId).Take(client).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return resp.Resp400(c, err, "无效的 client id") + } + } + + if err = db.Default.Session().Model(&model.AuthorizationRecord{}). + Where("user_id", op.Id). + Where("client_id", client.ClientId). + Take(authRecord). + Error; err != nil { + // 用户第一次对该 client 进行授权 + if errors.Is(err, gorm.ErrRecordNotFound) { + + state := uuid.New().String()[:8] + if err = cache.Client.SetEx(c.Context(), cache.Prefix+"state_code:"+state, nil, 10*time.Minute); err != nil { + return resp.Resp500(c, err.Error()) + } + + return c.RenderHTML("authorize", pageAuthorize, map[string]any{ + "user": map[string]any{ + "username": op.Username, + "avatar": "https://picsum.photos/200", + }, + "client_id": req.ClientId, + "client_secret": req.ClientSecret, + "redirect_uri": req.RedirectURI, + "scope": req.Scope, + "state": state, + }) + } + + return resp.Resp500(c, err) + } + + // 当用户已经授权过时 + // 生成授权码并缓存授权码 + authorizationCode := uuid.New().String()[:8] + if err = cache.Client.SetEx(c.Context(), cache.Prefix+"auth_code:"+authorizationCode, op.Id, 10*time.Minute); err != nil { + return resp.Resp500(c, err) + } + + // 重定向到回调 URL 并附带授权码 + return c.Redirect(req.RedirectURI+"?code="+authorizationCode, http.StatusFound) } diff --git a/internal/serve/handler/login.go b/internal/serve/handler/login.go index 03b2d08..d6dc629 100644 --- a/internal/serve/handler/login.go +++ b/internal/serve/handler/login.go @@ -2,23 +2,26 @@ package handler import ( _ "embed" + "github.com/google/uuid" "github.com/loveuer/nf" "github.com/loveuer/nf/nft/resp" "net/http" "net/url" + "time" + "uauth/internal/store/cache" ) var ( //go:embed serve_login.html - page string + pageLogin string ) func LoginPage(c *nf.Ctx) error { type Req struct { - ClientID string `query:"client_id"` - ClientSecret string `query:"client_secret"` - Scope string `query:"scope"` - RedirectURI string `query:"redirect_uri"` + ClientId string `query:"client_id" json:"client_id"` + ClientSecret string `query:"client_secret" json:"client_secret"` + Scope string `query:"scope" json:"scope"` + RedirectURI string `query:"redirect_uri" json:"redirect_uri"` } var ( @@ -30,7 +33,7 @@ func LoginPage(c *nf.Ctx) error { return resp.Resp400(c, err.Error()) } - if req.ClientID == "" || req.ClientSecret == "" || req.RedirectURI == "" { + if req.ClientId == "" || req.ClientSecret == "" || req.RedirectURI == "" { return resp.Resp400(c, req) } @@ -38,11 +41,17 @@ func LoginPage(c *nf.Ctx) error { // todo: 如果用户是已登录状态,则直接带上信息返回到 authorize 页面 - return c.RenderHTML("login", page, map[string]interface{}{ - "client_id": req.ClientID, + state := uuid.New().String()[:8] + if err = cache.Client.SetEx(c.Context(), cache.Prefix+"state_code:"+state, nil, 10*time.Minute); err != nil { + return resp.Resp500(c, err.Error()) + } + + return c.RenderHTML("login", pageLogin, map[string]interface{}{ + "client_id": req.ClientId, "client_secret": req.ClientSecret, "redirect_uri": req.RedirectURI, "scope": req.Scope, + "state": state, }) } diff --git a/internal/serve/handler/registry.go b/internal/serve/handler/registry.go index e77b34d..97cd406 100644 --- a/internal/serve/handler/registry.go +++ b/internal/serve/handler/registry.go @@ -1,6 +1,7 @@ package handler import ( + _ "embed" "errors" "github.com/loveuer/nf" "github.com/loveuer/nf/nft/resp" @@ -28,7 +29,7 @@ func ClientRegistry(c *nf.Ctx) error { Secret := tool.RandomString(32) - platform := &model.Platform{ + platform := &model.Client{ ClientId: req.ClientId, Icon: req.Icon, Name: req.Name, @@ -45,3 +46,66 @@ func ClientRegistry(c *nf.Ctx) error { return resp.Resp200(c, platform) } + +var ( + //go:embed serve_registry.html + registryLogin string +) + +func UserRegistryPage(c *nf.Ctx) error { + return c.HTML(registryLogin) +} + +func UserRegistryAction(c *nf.Ctx) error { + type Req struct { + Username string `form:"username"` + Nickname string `form:"nickname"` + Password string `form:"password"` + ConfirmPassword string `form:"confirm_password"` + } + + var ( + err error + req = new(Req) + ) + + if err = c.BodyParser(req); err != nil { + return resp.Resp400(c, err.Error()) + } + + if err = tool.CheckPassword(req.Password); err != nil { + return resp.Resp400(c, req, err.Error()) + } + + op := &model.User{ + Username: req.Username, + Nickname: req.Nickname, + Password: tool.NewPassword(req.Password), + } + + if err = db.Default.Session().Create(op).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return resp.Resp400(c, err, "用户名已存在") + } + + return resp.Resp500(c, err) + } + + return c.HTML(` + + + + + + 注册成功 + + +

注册成功

+

快去试试吧

+ + +`) +} diff --git a/internal/serve/handler/serve_authorize.html b/internal/serve/handler/serve_authorize.html new file mode 100644 index 0000000..e6bb163 --- /dev/null +++ b/internal/serve/handler/serve_authorize.html @@ -0,0 +1,49 @@ + + + + + + Server Login + + + +
+

授权登录到 {{ .client_name }} 平台

+
+
+
+ +
+
+ {{ .user.username }} +
+
+
+
+
+ + + + + +
+ +
+ + +
+
+
+ + diff --git a/internal/serve/handler/serve_login.html b/internal/serve/handler/serve_login.html index 1cc8cc2..cb225ce 100644 --- a/internal/serve/handler/serve_login.html +++ b/internal/serve/handler/serve_login.html @@ -19,7 +19,7 @@
-

欢迎来到 Pro

+

欢迎来到 UAuth