wip: 24/10/28

This commit is contained in:
loveuer 2024-10-28 18:16:36 +08:00
parent ef6cca510a
commit d1597509bb
14 changed files with 496 additions and 56 deletions

View File

@ -7,6 +7,7 @@ import (
"uauth/internal/store/cache" "uauth/internal/store/cache"
"uauth/internal/store/db" "uauth/internal/store/db"
"uauth/internal/tool" "uauth/internal/tool"
"uauth/model"
) )
func initServe() *cobra.Command { func initServe() *cobra.Command {
@ -16,6 +17,7 @@ func initServe() *cobra.Command {
tool.TablePrinter(opt.Cfg) tool.TablePrinter(opt.Cfg)
tool.Must(cache.Init(opt.Cfg.Svc.Cache)) tool.Must(cache.Init(opt.Cfg.Svc.Cache))
tool.Must(db.Init(cmd.Context(), opt.Cfg.Svc.DB)) tool.Must(db.Init(cmd.Context(), opt.Cfg.Svc.DB))
tool.Must(model.Init(db.Default.Session()))
return serve.Run(cmd.Context()) return serve.Run(cmd.Context())
}, },
} }

View File

@ -0,0 +1,98 @@
package auth
import (
"errors"
"github.com/loveuer/nf"
"github.com/loveuer/nf/nft/resp"
"time"
"uauth/internal/store/cache"
"uauth/internal/tool"
"uauth/model"
)
type Config struct {
IgnoreFn func(c *nf.Ctx) bool
TokenFn func(c *nf.Ctx) (string, bool)
GetUserFn func(c *nf.Ctx, token string) (*model.User, error)
}
var (
defaultIgnoreFn = func(c *nf.Ctx) bool {
return false
}
defaultTokenFn = func(c *nf.Ctx) (string, bool) {
var token string
if token = c.Request.Header.Get("Authorization"); token != "" {
return token, true
}
if token = c.Query("access_token"); token != "" {
return token, true
}
if token = c.Cookies("access_token"); token != "" {
return token, true
}
return "", false
}
defaultGetUserFn = func(c *nf.Ctx, token string) (*model.User, error) {
var (
err error
op = new(model.User)
)
if err = cache.Client.GetExScan(tool.Timeout(3), cache.Prefix+"token:"+token, 24*time.Hour).Scan(op); err != nil {
if errors.Is(err, cache.ErrorKeyNotFound) {
return nil, resp.Resp401(c, nil, "用户认证信息不存在或已过期, 请重新登录")
}
return nil, resp.Resp500(c, err)
}
return op, nil
}
)
func New(cfgs ...*Config) nf.HandlerFunc {
var cfg = &Config{}
if len(cfgs) > 0 && cfgs[0] != nil {
cfg = cfgs[0]
}
if cfg.IgnoreFn == nil {
cfg.IgnoreFn = defaultIgnoreFn
}
if cfg.TokenFn == nil {
cfg.TokenFn = defaultTokenFn
}
if cfg.GetUserFn == nil {
cfg.GetUserFn = defaultGetUserFn
}
return func(c *nf.Ctx) error {
if cfg.IgnoreFn(c) {
return c.Next()
}
token, ok := cfg.TokenFn(c)
if !ok {
return resp.Resp401(c, nil, "请登录")
}
op, err := cfg.GetUserFn(c, token)
if err != nil {
return err
}
c.Locals("user", op)
return c.Next()
}
}

View File

@ -0,0 +1,58 @@
package handler
import (
"errors"
"github.com/google/uuid"
"github.com/loveuer/nf"
"github.com/loveuer/nf/nft/resp"
"net/http"
"time"
"uauth/internal/store/cache"
"uauth/internal/tool"
"uauth/model"
)
func Approve(c *nf.Ctx) error {
// 获取表单数据
type Req struct {
ClientId string `form:"client_id"`
ClientSecret string `form:"client_secret"`
RedirectURI string `form:"redirect_uri"`
Scope string `form:"scope"`
State string `form:"state"`
}
var (
ok bool
op *model.User
err error
req = new(Req)
)
if op, ok = c.Locals("user").(*model.User); !ok {
return resp.Resp401(c, nil)
}
if err = c.BodyParser(req); err != nil {
return resp.Resp400(c, err)
}
state := cache.Prefix + "state_code:" + req.State
if _, err = cache.Client.Get(c.Context(), state); err != nil {
if errors.Is(err, cache.ErrorKeyNotFound) {
return resp.Resp400(c, req, "Bad Approve Request")
}
return resp.Resp500(c, err)
}
_ = cache.Client.Del(tool.Timeout(3), state)
authorizationCode := uuid.New().String()[:8]
if err = cache.Client.SetEx(c.Context(), cache.Prefix+"auth_code:"+authorizationCode, op.Id, 10*time.Minute); err != nil {
return resp.Resp500(c, err)
}
// 重定向到回调 URL 并附带授权码
return c.Redirect(req.RedirectURI+"?code="+authorizationCode, http.StatusFound)
}

View File

@ -1,39 +1,97 @@
package handler package handler
import ( import (
_ "embed"
"errors"
"github.com/google/uuid"
"github.com/loveuer/nf" "github.com/loveuer/nf"
"github.com/loveuer/nf/nft/resp"
"gorm.io/gorm"
"net/http" "net/http"
"time"
"uauth/internal/store/cache"
"uauth/internal/store/db"
"uauth/model"
)
var (
//go:embed serve_authorize.html
pageAuthorize string
) )
func Authorize(c *nf.Ctx) error { func Authorize(c *nf.Ctx) error {
// 解析查询参数 type Req struct {
clientID := c.Query("client_id") ClientId string `query:"client_id"`
responseType := c.Query("response_type") ClientSecret string `query:"client_secret"`
redirectURI := c.Query("redirect_uri") ResponseType string `query:"response_type"`
scope := c.Query("scope") RedirectURI string `query:"redirect_uri"`
Scope string `query:"scope"`
}
// 检查客户端 ID 和其他参数 var (
// 在实际应用中,你需要检查这些参数是否合法 ok bool
if clientID != "12345" || responseType != "code" || redirectURI != "http://localhost:8080/callback" { op *model.User
req = new(Req)
err error
client = &model.Client{}
authRecord = &model.AuthorizationRecord{}
)
if err = c.QueryParser(req); err != nil {
return resp.Resp400(c, err)
}
if req.ResponseType != "code" {
return c.Status(http.StatusBadRequest).SendString("Invalid request") return c.Status(http.StatusBadRequest).SendString("Invalid request")
} }
// 显示授权页面给用户 if op, ok = c.Locals("user").(*model.User); !ok {
_, err := c.Write([]byte(` return resp.Resp401(c, nil)
<html> }
<head><title>Authorization</title></head>
<body>
<h1>Do you want to authorize this application?</h1>
<form action="/approve" method="post">
<input type="hidden" name="client_id" value="` + clientID + `"/>
<input type="hidden" name="redirect_uri" value="` + redirectURI + `"/>
<input type="hidden" name="scope" value="` + scope + `"/>
<button type="submit">Yes, I authorize</button>
</form>
</body>
</html>
`))
return err if err = db.Default.Session().Where("client_id", req.ClientId).Take(client).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return resp.Resp400(c, err, "无效的 client id")
}
}
if err = db.Default.Session().Model(&model.AuthorizationRecord{}).
Where("user_id", op.Id).
Where("client_id", client.ClientId).
Take(authRecord).
Error; err != nil {
// 用户第一次对该 client 进行授权
if errors.Is(err, gorm.ErrRecordNotFound) {
state := uuid.New().String()[:8]
if err = cache.Client.SetEx(c.Context(), cache.Prefix+"state_code:"+state, nil, 10*time.Minute); err != nil {
return resp.Resp500(c, err.Error())
}
return c.RenderHTML("authorize", pageAuthorize, map[string]any{
"user": map[string]any{
"username": op.Username,
"avatar": "https://picsum.photos/200",
},
"client_id": req.ClientId,
"client_secret": req.ClientSecret,
"redirect_uri": req.RedirectURI,
"scope": req.Scope,
"state": state,
})
}
return resp.Resp500(c, err)
}
// 当用户已经授权过时
// 生成授权码并缓存授权码
authorizationCode := uuid.New().String()[:8]
if err = cache.Client.SetEx(c.Context(), cache.Prefix+"auth_code:"+authorizationCode, op.Id, 10*time.Minute); err != nil {
return resp.Resp500(c, err)
}
// 重定向到回调 URL 并附带授权码
return c.Redirect(req.RedirectURI+"?code="+authorizationCode, http.StatusFound)
} }

View File

@ -2,23 +2,26 @@ package handler
import ( import (
_ "embed" _ "embed"
"github.com/google/uuid"
"github.com/loveuer/nf" "github.com/loveuer/nf"
"github.com/loveuer/nf/nft/resp" "github.com/loveuer/nf/nft/resp"
"net/http" "net/http"
"net/url" "net/url"
"time"
"uauth/internal/store/cache"
) )
var ( var (
//go:embed serve_login.html //go:embed serve_login.html
page string pageLogin string
) )
func LoginPage(c *nf.Ctx) error { func LoginPage(c *nf.Ctx) error {
type Req struct { type Req struct {
ClientID string `query:"client_id"` ClientId string `query:"client_id" json:"client_id"`
ClientSecret string `query:"client_secret"` ClientSecret string `query:"client_secret" json:"client_secret"`
Scope string `query:"scope"` Scope string `query:"scope" json:"scope"`
RedirectURI string `query:"redirect_uri"` RedirectURI string `query:"redirect_uri" json:"redirect_uri"`
} }
var ( var (
@ -30,7 +33,7 @@ func LoginPage(c *nf.Ctx) error {
return resp.Resp400(c, err.Error()) return resp.Resp400(c, err.Error())
} }
if req.ClientID == "" || req.ClientSecret == "" || req.RedirectURI == "" { if req.ClientId == "" || req.ClientSecret == "" || req.RedirectURI == "" {
return resp.Resp400(c, req) return resp.Resp400(c, req)
} }
@ -38,11 +41,17 @@ func LoginPage(c *nf.Ctx) error {
// todo: 如果用户是已登录状态,则直接带上信息返回到 authorize 页面 // todo: 如果用户是已登录状态,则直接带上信息返回到 authorize 页面
return c.RenderHTML("login", page, map[string]interface{}{ state := uuid.New().String()[:8]
"client_id": req.ClientID, if err = cache.Client.SetEx(c.Context(), cache.Prefix+"state_code:"+state, nil, 10*time.Minute); err != nil {
return resp.Resp500(c, err.Error())
}
return c.RenderHTML("login", pageLogin, map[string]interface{}{
"client_id": req.ClientId,
"client_secret": req.ClientSecret, "client_secret": req.ClientSecret,
"redirect_uri": req.RedirectURI, "redirect_uri": req.RedirectURI,
"scope": req.Scope, "scope": req.Scope,
"state": state,
}) })
} }

View File

@ -1,6 +1,7 @@
package handler package handler
import ( import (
_ "embed"
"errors" "errors"
"github.com/loveuer/nf" "github.com/loveuer/nf"
"github.com/loveuer/nf/nft/resp" "github.com/loveuer/nf/nft/resp"
@ -28,7 +29,7 @@ func ClientRegistry(c *nf.Ctx) error {
Secret := tool.RandomString(32) Secret := tool.RandomString(32)
platform := &model.Platform{ platform := &model.Client{
ClientId: req.ClientId, ClientId: req.ClientId,
Icon: req.Icon, Icon: req.Icon,
Name: req.Name, Name: req.Name,
@ -45,3 +46,66 @@ func ClientRegistry(c *nf.Ctx) error {
return resp.Resp200(c, platform) return resp.Resp200(c, platform)
} }
var (
//go:embed serve_registry.html
registryLogin string
)
func UserRegistryPage(c *nf.Ctx) error {
return c.HTML(registryLogin)
}
func UserRegistryAction(c *nf.Ctx) error {
type Req struct {
Username string `form:"username"`
Nickname string `form:"nickname"`
Password string `form:"password"`
ConfirmPassword string `form:"confirm_password"`
}
var (
err error
req = new(Req)
)
if err = c.BodyParser(req); err != nil {
return resp.Resp400(c, err.Error())
}
if err = tool.CheckPassword(req.Password); err != nil {
return resp.Resp400(c, req, err.Error())
}
op := &model.User{
Username: req.Username,
Nickname: req.Nickname,
Password: tool.NewPassword(req.Password),
}
if err = db.Default.Session().Create(op).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return resp.Resp400(c, err, "用户名已存在")
}
return resp.Resp500(c, err)
}
return c.HTML(`
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<link
rel="stylesheet"
href="https://cdn.jsdelivr.net/npm/@picocss/pico@2/css/pico.jade.min.css"
>
<title>注册成功</title>
</head>
<body>
<h1>注册成功</h1>
<h3>快去试试吧</h3>
</body>
</html>
`)
}

View File

@ -0,0 +1,49 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<link
rel="stylesheet"
href="https://cdn.jsdelivr.net/npm/@picocss/pico@2/css/pico.jade.min.css"
>
<title>Server Login</title>
<style>
body {
height: 100vh;
width: 100vw;
display: flex;
justify-content: center;
align-items: center;
}
</style>
</head>
<body>
<div>
<h3>授权登录到 {{ .client_name }} 平台</h3>
<div class="userinfo">
<article style="display: flex; align-items: center;">
<div style="height:50px; width:50px;">
<img src="{{ .user.avatar }}"/>
</div>
<div style="margin-left:20px; ">
{{ .user.username }}
</div>
</article>
</div>
<form action="/api/oauth/v2/authorize" method="POST">
<fieldset>
<input type="hidden" name="client_id" value="{{ .client_id }}"/>
<input type="hidden" name="client_secret" value="{{ .client_secret }}"/>
<input type="hidden" name="redirect_uri" value="{{ .redirect_uri }}"/>
<input type="hidden" name="scope" value="{{ .scope }}"/>
<input type="hidden" name="state" value="{{ .state }}"/>
</fieldset>
<div style="display: flex;">
<button type="button" class="contrast" style="flex:1; margin-right: 10px">取消</button>
<button type="submit" style="flex: 1;">授权</button>
</div>
</form>
</div>
</body>
</html>

View File

@ -19,7 +19,7 @@
</head> </head>
<body> <body>
<div> <div>
<h3>欢迎来到 Pro</h3> <h3>欢迎来到 UAuth</h3>
<form action="/api/oauth/v2/login" method="POST"> <form action="/api/oauth/v2/login" method="POST">
<fieldset> <fieldset>
<label> <label>

View File

@ -0,0 +1,70 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<link
rel="stylesheet"
href="https://cdn.jsdelivr.net/npm/@picocss/pico@2/css/pico.jade.min.css"
>
<title>Server Login</title>
<style>
body {
height: 100vh;
width: 100vw;
display: flex;
justify-content: center;
align-items: center;
}
</style>
</head>
<body>
<div>
<h3>欢迎注册 UAuth</h3>
<form action="/api/oauth/v2/registry/user" method="POST" id="form">
<fieldset>
<label>
用户名
<input id="username" name="username" autocomplete="given-name"/>
</label>
<label>
昵称
<input id="nickname" name="nickname" autocomplete="given-name"/>
</label>
<label>
密码
<input id="password" type="password" name="password" autocomplete="password"/>
</label>
<label>
重复密码
<input id="confirm_password" type="password" name="confirm_password" autocomplete="password"/>
</label>
<button type="button" style="flex: 1;width: 100%;" onclick="registry()">注册</button>
</fieldset>
</form>
</div>
<script type="text/javascript">
function registry() {
let user = {
username: document.querySelector("#username").value,
nickname: document.querySelector("#nickname").value,
password: document.querySelector("#password").value,
confirm_password: document.querySelector("#confirm_password").value,
}
console.log('[D] user = ', user)
if (!user.username || !user.password || !user.nickname) {
window.alert("参数均不能为空")
return
}
if (user.password !== user.confirm_password) {
window.alert("两次密码不一致")
return
}
document.querySelector("#form").submit()
}
</script>
</body>
</html>

View File

@ -8,6 +8,7 @@ import (
"github.com/loveuer/nf/nft/log" "github.com/loveuer/nf/nft/log"
"net/http" "net/http"
"time" "time"
"uauth/internal/middleware/auth"
"uauth/internal/opt" "uauth/internal/opt"
"uauth/internal/serve/handler" "uauth/internal/serve/handler"
"uauth/internal/store/cache" "uauth/internal/store/cache"
@ -42,23 +43,6 @@ func handleLogin(c *nf.Ctx) error {
return nil return nil
} }
// 处理用户的授权批准
func handleApprove(c *nf.Ctx) error {
// 获取表单数据
clientID := c.FormValue("client_id")
redirectURI := c.FormValue("redirect_uri")
scope := c.FormValue("scope")
// 生成授权码
authorizationCode := uuid.New().String()[:8]
log.Info("[D] client_id = %s, scope = %s, auth_code = %s", clientID, scope, authorizationCode)
// 重定向到回调 URL 并附带授权码
http.Redirect(c.Writer, c.Request, redirectURI+"?code="+authorizationCode, http.StatusFound)
return nil
}
// 令牌请求的处理 // 令牌请求的处理
func handleToken(c *nf.Ctx) error { func handleToken(c *nf.Ctx) error {
var ( var (
@ -131,15 +115,20 @@ func Run(ctx context.Context) error {
app := nf.New() app := nf.New()
api := app.Group(opt.Cfg.Svc.Prefix) api := app.Group(opt.Cfg.Svc.Prefix)
// 设置路由
api.Post("/client/registry", handler.ClientRegistry) api.Get("/registry/user", handler.UserRegistryPage)
api.Post("/registry/user", handler.UserRegistryAction)
api.Post("/registry/client", handler.ClientRegistry)
api.Get("/login", handler.LoginPage) api.Get("/login", handler.LoginPage)
api.Post("/login", handler.LoginAction) api.Post("/login", handler.LoginAction)
api.Get("/authorize", handler.Authorize) api.Get("/authorize", handler.Authorize)
api.Post("/approve", handleApprove) api.Post("/approve", handler.Approve)
api.Post("/token", handleToken) api.Post("/token", handleToken)
// 启动 HTTP 服务器 api.Get("/after", auth.New(), func(c *nf.Ctx) error {
return c.SendString("hello world")
})
log.Info("Starting server on: %s", opt.Cfg.Svc.Address) log.Info("Starting server on: %s", opt.Cfg.Svc.Address)
go func() { go func() {
<-ctx.Done() <-ctx.Done()

11
model/authorization.go Normal file
View File

@ -0,0 +1,11 @@
package model
type AuthorizationRecord struct {
Id uint64 `json:"id" gorm:"primaryKey;column:id"`
CreatedAt int64 `json:"created_at" gorm:"column:created_at;autoCreateTime:milli"`
UpdatedAt int64 `json:"updated_at" gorm:"column:updated_at;autoUpdateTime:milli"`
DeletedAt int64 `json:"deleted_at" gorm:"index;column:deleted_at;default:0"`
UserId uint64 `json:"user_id" gorm:"column:user_id"`
ClientId uint64 `json:"client_id" gorm:"column:client_id"`
}

View File

@ -1,6 +1,6 @@
package model package model
type Platform struct { type Client struct {
Id uint64 `json:"id" gorm:"primaryKey;column:id"` Id uint64 `json:"id" gorm:"primaryKey;column:id"`
CreatedAt int64 `json:"created_at" gorm:"column:created_at;autoCreateTime:milli"` CreatedAt int64 `json:"created_at" gorm:"column:created_at;autoCreateTime:milli"`
UpdatedAt int64 `json:"updated_at" gorm:"column:updated_at;autoUpdateTime:milli"` UpdatedAt int64 `json:"updated_at" gorm:"column:updated_at;autoUpdateTime:milli"`

30
model/init.go Normal file
View File

@ -0,0 +1,30 @@
package model
import (
"gorm.io/gorm"
"gorm.io/gorm/clause"
"uauth/internal/tool"
)
func Init(tx *gorm.DB) error {
var err error
if err = tx.AutoMigrate(
&User{},
&Client{},
&AuthorizationRecord{},
); err != nil {
return err
}
if err = tx.Clauses(clause.OnConflict{DoNothing: true}).
Create(&User{Username: "admin", Nickname: "admin", Password: tool.NewPassword("Foobar123")}).Error; err != nil {
return err
}
if err = tx.Clauses(clause.OnConflict{DoNothing: true}).
Create(&Client{ClientId: "test", ClientSecret: "test", Name: "测试", Icon: "https://picsum.photos/200"}).Error; err != nil {
return err
}
return nil
}

View File

@ -22,6 +22,7 @@ type User struct {
Status Status `json:"status" gorm:"column:status;default:0"` Status Status `json:"status" gorm:"column:status;default:0"`
Nickname string `json:"nickname" gorm:"column:nickname;type:varchar(64)"` Nickname string `json:"nickname" gorm:"column:nickname;type:varchar(64)"`
Avatar string `json:"avatar" gorm:"column:avatar;type:varchar(256)"`
Comment string `json:"comment" gorm:"column:comment"` Comment string `json:"comment" gorm:"column:comment"`
CreatedById uint64 `json:"created_by_id" gorm:"column:created_by_id"` CreatedById uint64 `json:"created_by_id" gorm:"column:created_by_id"`
@ -37,6 +38,7 @@ func (u *User) JwtEncode() (token string, err error) {
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS512, jwt.MapClaims{ jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS512, jwt.MapClaims{
"id": u.Id, "id": u.Id,
"username": u.Username, "username": u.Username,
"nickname": u.Nickname,
"status": u.Status, "status": u.Status,
"login_at": now.UnixMilli(), "login_at": now.UnixMilli(),
}) })