diff --git a/internal/client/client.go b/internal/client/client.go index 683167c..e0a03bc 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -32,7 +32,8 @@ var ( func Run(ctx context.Context) error { app := nf.New() - app.Get("/login", handleLogin) + app.Get("/login", handleLoginPage) + app.Post("/login", handleLoginAction) app.Get("/oauth/v2/redirect", handleRedirect) go func() { @@ -43,9 +44,38 @@ func Run(ctx context.Context) error { return app.Run(":18080") } -func handleLogin(c *nf.Ctx) error { +func handleLoginAction(c *nf.Ctx) error { + type Req struct { + Username string `form:"username"` + Password string `form:"password"` + } + + var ( + err error + req = new(Req) + token *oauth2.Token + ) + + if err = c.BodyParser(req); err != nil { + return c.Status(http.StatusBadRequest).SendString(err.Error()) + } + + log.Info("[C] password mode: username = %s, password = %s", req.Username, req.Password) + + if token, err = config.PasswordCredentialsToken(c.Context(), req.Username, req.Password); err != nil { + log.Error("[C] config do password token err = %s", err) + return c.Status(http.StatusBadRequest).SendString(err.Error()) + } + + bs, _ := json.Marshal(token) + log.Info("[C] oauth finally token =\n%s", string(bs)) + + return resp.Resp200(c, token) +} + +func handleLoginPage(c *nf.Ctx) error { if c.Query("oauth") != "" { - uri := config.AuthCodeURL(state) + uri := config.AuthCodeURL(state, oauth2.ApprovalForce) log.Info("[C] oauth config client_secret = %s", config.ClientSecret) log.Info("[C] redirect to oauth2 server uri = %s", uri) return c.Redirect(uri, http.StatusFound) diff --git a/internal/client/login.html b/internal/client/login.html index bb57e1d..d46556e 100644 --- a/internal/client/login.html +++ b/internal/client/login.html @@ -20,7 +20,7 @@

这里是 xx 产品登录页面

-
+
- + + >登录 使用 OAuth V2 账号登录
diff --git a/internal/serve/handler/token.go b/internal/serve/handler/token.go index a5fd473..044f557 100644 --- a/internal/serve/handler/token.go +++ b/internal/serve/handler/token.go @@ -15,49 +15,34 @@ import ( "uauth/model" ) -func HandleToken(c *nf.Ctx) error { - type Req struct { - Code string `form:"code"` - GrantType string `form:"grant_type"` - RedirectURI string `form:"redirect_uri"` - } - +func verifyClient(c *nf.Ctx) (*model.Client, error) { var ( err error - req = new(Req) - opId uint64 - op = new(model.User) - token string basic string bs []byte strs []string client = new(model.Client) ) - if err = c.BodyParser(req); err != nil { - return c.Status(http.StatusBadRequest).SendString("Bad Request: invalid form") - } - - // client_secret if basic = c.Get("Authorization"); basic == "" { - return c.Status(http.StatusUnauthorized).SendString("Authorization header missing") + return nil, errors.New("authorization header missing") } switch { case strings.HasPrefix(basic, "Basic "): basic = strings.TrimPrefix(basic, "Basic ") default: - return c.Status(http.StatusBadRequest).SendString("Bad Request: authorization scheme not supported") + return nil, errors.New("authorization scheme not supported") } if bs, err = base64.StdEncoding.DecodeString(basic); err != nil { log.Warn("[Token] base64 decode failed, raw = %s, err = %s", basic, err.Error()) - return c.Status(http.StatusBadRequest).SendString("Bad Request: invalid basic authorization") + return nil, errors.New("invalid basic authorization") } if strs = strings.SplitN(string(bs), ":", 2); len(strs) != 2 { log.Warn("[Token] basic split err, decode = %s", string(bs)) - return c.Status(http.StatusBadRequest).SendString("Bad Request: invalid basic authorization") + return nil, errors.New("invalid basic authorization") } clientId, clientSecret := strs[0], strs[1] @@ -66,31 +51,79 @@ func HandleToken(c *nf.Ctx) error { Take(client). Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return c.Status(http.StatusBadRequest).SendString("Bad Request: client invalid") + return nil, errors.New("client invalid") } log.Error("[Token] db take client by id = %s, err = %s", clientId, err.Error()) - return c.Status(http.StatusInternalServerError).SendString("Internal Server Error") + return nil, errors.New("unknown server error") } if client.ClientSecret != clientSecret { log.Warn("[Token] client_secret invalid, want = %s, got = %s", client.ClientSecret, clientSecret) - return c.Status(http.StatusUnauthorized).SendString("Unauthorized: client secret invalid") + return nil, errors.New("client secret invalid") } - if err = cache.Client.GetScan(tool.Timeout(2), cache.Prefix+"auth_code:"+req.Code).Scan(&opId); err != nil { - if errors.Is(err, cache.ErrorKeyNotFound) { - return c.Status(http.StatusBadRequest).SendString("Bad Request: invalid code") + return client, nil +} + +func HandleToken(c *nf.Ctx) error { + var ( + err error + opId uint64 + op = new(model.User) + token string + client *model.Client + grantType = c.Form("grant_type") + ) + + switch grantType { + case "password": + if client, err = verifyClient(c); err != nil { + return c.Status(http.StatusBadRequest).SendString("Bad Request: " + err.Error()) } - log.Error("[S] handleToken: get code from cache err = %s", err.Error()) - return c.Status(http.StatusInternalServerError).SendString("Internal Server Error") - } + username := c.Form("username") + password := c.Form("password") + if err = db.Default.Session().Model(&model.User{}). + Where("username = ?", username). + Take(op).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return c.Status(http.StatusBadRequest).SendString("Bad Request: invalid username or password") + } - op.Id = opId - if err = db.Default.Session().Take(op).Error; err != nil { - log.Error("[S] handleToken: get op by id err, id = %d, err = %s", opId, err.Error()) - return c.Status(http.StatusInternalServerError).SendString("Internal Server Error") + log.Error("[Token] db take user by username = %s, err = %s", username, err.Error()) + return c.Status(http.StatusInternalServerError).SendString("Internal Server Error") + } + + if !tool.ComparePassword(password, op.Password) { + return c.Status(http.StatusBadRequest).SendString("Bad Request: invalid username or password") + } + case "authorization_code": + if client, err = verifyClient(c); err != nil { + return c.Status(http.StatusBadRequest).SendString("Bad Request: " + err.Error()) + } + + code := c.Form("code") + if code == "" { + return c.Status(http.StatusBadRequest).SendString("Bad Request: no code provided") + } + + if err = cache.Client.GetScan(tool.Timeout(2), cache.Prefix+"auth_code:"+code).Scan(&opId); err != nil { + if errors.Is(err, cache.ErrorKeyNotFound) { + return c.Status(http.StatusBadRequest).SendString("Bad Request: invalid code") + } + + log.Error("[S] handleToken: get code from cache err = %s", err.Error()) + return c.Status(http.StatusInternalServerError).SendString("Internal Server Error") + } + + op.Id = opId + if err = db.Default.Session().Take(op).Error; err != nil { + log.Error("[S] handleToken: get op by id err, id = %d, err = %s", opId, err.Error()) + return c.Status(http.StatusInternalServerError).SendString("Internal Server Error") + } + default: + return c.Status(http.StatusBadRequest).SendString("Bad Request: invalid grant type") } if token, err = op.JwtEncode(); err != nil { @@ -100,6 +133,8 @@ func HandleToken(c *nf.Ctx) error { refreshToken := uuid.New().String() + _ = client + return c.JSON(map[string]any{ "access_token": token, "refresh_token": refreshToken, diff --git a/readme.md b/readme.md index 6f20838..5dc2900 100644 --- a/readme.md +++ b/readme.md @@ -1,12 +1,20 @@ # uauth +## update: + +- 添加 password 模式示例 + ## run - `go run . svc` - `go run . client` - `浏览器打开`[http://localhost:18080/login](http://localhost:18080/login) -## oauth2 authorization flow +## oauth2 authorization flow(password mode) + +- 1. 客户端直接拿到用户的账号和密码来请求 oauth2 服务器获取 token + +## oauth2 authorization flow(authorization_code mode) - 1. 某某 系统/平台(比如: xx_platform) 的用户想要登录该 系统/平台, 并点击到登录页面 - 2. 用户发现该平台上有 `通过 {oauth2} 登录` 的按钮, 用户点击该按钮, 跳转到 `{oauth2}` 服务的登录页面如: `/oauth2/login`