🎉 完成基本的演示和样例

This commit is contained in:
loveuer
2024-10-23 17:46:15 +08:00
commit aefc004e33
56 changed files with 2648 additions and 0 deletions

89
internal/client/client.go Normal file
View File

@ -0,0 +1,89 @@
package client
import (
"context"
_ "embed"
"encoding/json"
"github.com/google/uuid"
"github.com/loveuer/nf"
"github.com/loveuer/nf/nft/log"
"github.com/loveuer/nf/nft/resp"
"golang.org/x/oauth2"
"net/http"
"uauth/internal/tool"
)
//go:embed login.html
var page string
var (
config = oauth2.Config{
ClientID: "test",
ClientSecret: "Foobar123",
Endpoint: oauth2.Endpoint{
AuthURL: "http://localhost:8080/oauth/v2/authorize",
TokenURL: "http://localhost:8080/oauth/v2/token",
},
RedirectURL: "http://localhost:18080/oauth/v2/redirect",
Scopes: []string{"test"},
}
state = uuid.New().String()[:8]
)
func Run(ctx context.Context) error {
app := nf.New()
app.Get("/login", handleLogin)
app.Get("/oauth/v2/redirect", handleRedirect)
go func() {
<-ctx.Done()
_ = app.Shutdown(tool.Timeout(2))
}()
return app.Run(":18080")
}
func handleLogin(c *nf.Ctx) error {
if c.Query("oauth") != "" {
uri := config.AuthCodeURL(state)
log.Info("[C] oauth config client_secret = %s", config.ClientSecret)
log.Info("[C] redirect to oauth2 server uri = %s", uri)
return c.Redirect(uri, http.StatusFound)
}
return c.HTML(page)
}
func handleRedirect(c *nf.Ctx) error {
type Req struct {
State string `query:"state"`
Code string `query:"code"`
Scope string `query:"scope"`
ClientId string `query:"client_id"`
}
var (
err error
req = new(Req)
token *oauth2.Token
)
if err = c.QueryParser(req); err != nil {
return resp.Resp400(c, err.Error())
}
if req.State != state {
log.Error("[C] state mismatch, want = %s, got = %s", state, req.State)
return c.Status(http.StatusBadRequest).SendString("Bad Request: state mismatch")
}
if token, err = config.Exchange(c.Context(), req.Code); err != nil {
log.Error("[C] oauth config exchange err: %s", err.Error())
return resp.Resp500(c, err.Error())
}
bs, _ := json.Marshal(token)
log.Info("[C] oauth finally token =\n%s", string(bs))
return resp.Resp200(c, token)
}

View File

@ -0,0 +1,52 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<link
rel="stylesheet"
href="https://cdn.jsdelivr.net/npm/@picocss/pico@2/css/pico.min.css"
>
<title>Client Login</title>
<style>
body {
height: 100vh;
width: 100vw;
display: flex;
justify-content: center;
align-items: center;
}
</style>
</head>
<body>
<div>
<h3>这里是 xx 产品登录页面</h3>
<form>
<fieldset>
<label>
Username
<input
name="username"
placeholder="username"
autocomplete="given-name"
/>
</label>
<label>
Password
<input
type="password"
name="password"
placeholder="password"
autocomplete="password"
/>
</label>
</fieldset>
<input
type="submit"
value="登录"
/>
<a href="/login?oauth=true">使用 OAuth V2 账号登录</a>
</form>
</div>
</body>
</html>

16
internal/cmd/client.go Normal file
View File

@ -0,0 +1,16 @@
package cmd
import (
"github.com/spf13/cobra"
"uauth/internal/client"
)
func initClient() *cobra.Command {
return &cobra.Command{
Use: "client",
Short: "Run the client",
RunE: func(cmd *cobra.Command, args []string) error {
return client.Run(cmd.Context())
},
}
}

31
internal/cmd/cmd.go Normal file
View File

@ -0,0 +1,31 @@
package cmd
import (
"github.com/loveuer/nf/nft/log"
"github.com/spf13/cobra"
"uauth/internal/opt"
)
var (
Command = &cobra.Command{
Use: "uauth",
Short: "uauth: oauth v2 server",
Example: "",
PersistentPreRun: func(cmd *cobra.Command, args []string) {
if opt.Cfg.Debug {
log.SetLogLevel(log.LogLevelDebug)
}
},
}
)
func init() {
Command.PersistentFlags().BoolVar(&opt.Cfg.Debug, "debug", false, "debug mode")
initServe()
Command.AddCommand(
initServe(),
initClient(),
)
}

31
internal/cmd/serve.go Normal file
View File

@ -0,0 +1,31 @@
package cmd
import (
"github.com/spf13/cobra"
"uauth/internal/opt"
"uauth/internal/serve"
"uauth/internal/store/cache"
"uauth/internal/store/db"
"uauth/internal/tool"
"uauth/model"
)
func initServe() *cobra.Command {
svc := &cobra.Command{
Use: "svc",
RunE: func(cmd *cobra.Command, args []string) error {
tool.TablePrinter(opt.Cfg)
tool.Must(cache.Init(opt.Cfg.Svc.Cache))
tool.Must(db.Init(cmd.Context(), opt.Cfg.Svc.DB))
tool.Must(model.Init(db.Default.Session()))
return serve.Run(cmd.Context())
},
}
svc.Flags().StringVar(&opt.Cfg.Svc.Address, "address", "localhost:8080", "listen address")
svc.Flags().StringVar(&opt.Cfg.Svc.Prefix, "prefix", "/oauth/v2", "api prefix")
svc.Flags().StringVar(&opt.Cfg.Svc.Cache, "cache", "lru::", "cache uri")
svc.Flags().StringVar(&opt.Cfg.Svc.DB, "db", "sqlite::data.sqlite", "database uri")
return svc
}

View File

@ -0,0 +1,22 @@
package interfaces
import (
"context"
"time"
)
type Cacher interface {
Get(ctx context.Context, key string) ([]byte, error)
GetScan(ctx context.Context, key string) Scanner
GetEx(ctx context.Context, key string, duration time.Duration) ([]byte, error)
GetExScan(ctx context.Context, key string, duration time.Duration) Scanner
// Set value 会被序列化, 优先使用 MarshalBinary 方法, 没有则执行 json.Marshal
Set(ctx context.Context, key string, value any) error
// SetEx value 会被序列化, 优先使用 MarshalBinary 方法, 没有则执行 json.Marshal
SetEx(ctx context.Context, key string, value any, duration time.Duration) error
Del(ctx context.Context, keys ...string) error
}
type Scanner interface {
Scan(model any) error
}

View File

@ -0,0 +1,11 @@
package interfaces
type Enum interface {
Value() int64
Code() string
Label() string
MarshalJSON() ([]byte, error)
All() []Enum
}

View File

@ -0,0 +1,7 @@
package interfaces
type OpLogger interface {
Enum
Render(content map[string]any) (string, error)
Template() string
}

View File

@ -0,0 +1,118 @@
package auth
import (
"errors"
"github.com/loveuer/nf"
"github.com/loveuer/nf/nft/log"
"github.com/loveuer/nf/nft/resp"
"net/http"
"time"
"uauth/internal/store/cache"
"uauth/internal/tool"
"uauth/model"
)
type Config struct {
IgnoreFn func(c *nf.Ctx) bool
TokenFn func(c *nf.Ctx) (string, bool)
GetUserFn func(c *nf.Ctx, token string) (*model.User, error)
NextOnError bool
}
var (
defaultIgnoreFn = func(c *nf.Ctx) bool {
return false
}
defaultTokenFn = func(c *nf.Ctx) (string, bool) {
var token string
if token = c.Request.Header.Get("Authorization"); token != "" {
return token, true
}
if token = c.Query("access_token"); token != "" {
return token, true
}
if token = c.Cookies("access_token"); token != "" {
return token, true
}
return "", false
}
defaultGetUserFn = func(c *nf.Ctx, token string) (*model.User, error) {
var (
err error
op = new(model.User)
key = cache.Prefix + "token:" + token
)
if err = cache.Client.GetExScan(tool.Timeout(3), key, 24*time.Hour).Scan(op); err != nil {
if errors.Is(err, cache.ErrorKeyNotFound) {
return nil, err
}
log.Error("[M] cache client get user by token key = %s, err = %s", key, err.Error())
return nil, errors.New("Internal Server Error")
}
return op, nil
}
)
func New(cfgs ...*Config) nf.HandlerFunc {
var cfg = &Config{}
if len(cfgs) > 0 && cfgs[0] != nil {
cfg = cfgs[0]
}
if cfg.IgnoreFn == nil {
cfg.IgnoreFn = defaultIgnoreFn
}
if cfg.TokenFn == nil {
cfg.TokenFn = defaultTokenFn
}
if cfg.GetUserFn == nil {
cfg.GetUserFn = defaultGetUserFn
}
return func(c *nf.Ctx) error {
if cfg.IgnoreFn(c) {
return c.Next()
}
token, ok := cfg.TokenFn(c)
if !ok {
if cfg.NextOnError {
return c.Next()
}
return resp.Resp401(c, nil, "请登录")
}
op, err := cfg.GetUserFn(c, token)
if err != nil {
if cfg.NextOnError {
return c.Next()
}
if errors.Is(err, cache.ErrorKeyNotFound) {
return c.Status(http.StatusUnauthorized).JSON(map[string]any{
"status": 500,
"msg": "用户认证信息不存在或已过期, 请重新登录",
})
}
return c.Status(http.StatusInternalServerError).SendString("Internal Server Error")
}
c.Locals("user", op)
return c.Next()
}
}

17
internal/opt/config.go Normal file
View File

@ -0,0 +1,17 @@
package opt
type svc struct {
Address string `json:"address"`
Prefix string `json:"prefix"`
Cache string `json:"cache"`
DB string `json:"db"`
}
type config struct {
Debug bool `json:"debug"`
Svc svc `json:"svc"`
}
var (
Cfg = config{}
)

6
internal/opt/var.go Normal file
View File

@ -0,0 +1,6 @@
package opt
const (
// 记得替换这个
JwtTokenSecret = "2(v6UW3pBf1Miz^bY9u4rAUyv&dj8Kdz"
)

View File

@ -0,0 +1,86 @@
package handler
import (
_ "embed"
"errors"
"github.com/google/uuid"
"github.com/loveuer/nf"
"github.com/loveuer/nf/nft/log"
"github.com/loveuer/nf/nft/resp"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"net/http"
"net/url"
"time"
"uauth/internal/store/cache"
"uauth/internal/store/db"
"uauth/model"
)
//go:embed serve_approve.html
var pageApprove string
func Approve(c *nf.Ctx) error {
// 获取表单数据
type Req struct {
ClientId string `form:"client_id"`
RedirectURI string `form:"redirect_uri"`
Scope string `form:"scope"`
State string `form:"state"`
}
var (
ok bool
op *model.User
err error
req = new(Req)
uri *url.URL
client = new(model.Client)
)
if op, ok = c.Locals("user").(*model.User); !ok {
return resp.Resp401(c, nil)
}
if err = c.BodyParser(req); err != nil {
return resp.Resp400(c, err)
}
if uri, err = url.Parse(req.RedirectURI); err != nil {
log.Warn("[S] parse redirect uri = %s, err = %s", req.RedirectURI, err.Error())
return c.Status(http.StatusBadRequest).SendString("Bad Request: invalid redirect uri")
}
if err = db.Default.Session().Where("client_id", req.ClientId).Take(client).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return c.Status(http.StatusBadRequest).SendString("Bad Request: invalid client_id")
}
log.Error("[S] get client by id fail, client_id = %s, err = %s", req.ClientId, err.Error())
return c.Status(http.StatusInternalServerError).SendString("Internal Server Error")
}
db.Default.Session().Clauses(clause.OnConflict{DoNothing: true}).
Create(&model.AuthorizationRecord{
UserId: op.Id,
ClientId: client.Id,
})
authorizationCode := uuid.New().String()[:8]
if err = cache.Client.SetEx(c.Context(), cache.Prefix+"auth_code:"+authorizationCode, op.Id, 10*time.Minute); err != nil {
return resp.Resp500(c, err)
}
qs := uri.Query()
qs.Add("code", authorizationCode)
qs.Add("client_id", req.ClientId)
qs.Add("scope", req.Scope)
qs.Add("state", req.State)
uri.ForceQuery = true
value := uri.String() + qs.Encode()
return c.RenderHTML("approve", pageApprove, map[string]interface{}{
"redirect_uri": value,
})
}

View File

@ -0,0 +1,121 @@
package handler
import (
_ "embed"
"errors"
"github.com/google/uuid"
"github.com/loveuer/nf"
"github.com/loveuer/nf/nft/log"
"github.com/loveuer/nf/nft/resp"
"gorm.io/gorm"
"net/http"
"net/url"
"time"
"uauth/internal/store/cache"
"uauth/internal/store/db"
"uauth/model"
)
var (
//go:embed serve_authorize.html
pageAuthorize string
)
func Authorize(c *nf.Ctx) error {
type Req struct {
ClientId string `query:"client_id"`
ResponseType string `query:"response_type"`
RedirectURI string `query:"redirect_uri"`
Scope string `query:"scope"`
State string `query:"state"`
}
var (
ok bool
op *model.User
req = new(Req)
err error
client = &model.Client{}
authRecord = &model.AuthorizationRecord{}
uri *url.URL
)
if err = c.QueryParser(req); err != nil {
log.Error("[S] query parser err = %s", err.Error())
return c.Status(http.StatusBadRequest).SendString("Invalid request")
}
if req.ResponseType != "code" {
log.Warn("[S] response type = %s", req.ResponseType)
return c.Status(http.StatusBadRequest).SendString("Invalid request")
}
// 如果未登录,则跳转到登录界面
if op, ok = c.Locals("user").(*model.User); !ok {
log.Info("[S] op not logined, redirect to login page")
return c.Redirect("/oauth/v2/login?"+c.Request.URL.Query().Encode(), http.StatusFound)
}
log.Info("[S] Authorize: username = %s, client_id = %s", op.Username, req.ClientId)
if err = db.Default.Session().Where("client_id", req.ClientId).Take(client).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return c.Status(http.StatusBadRequest).SendString("Bad Request: invalid client_id")
}
log.Error("[Authorize]: db take clients err = %s", err.Error())
return c.Status(http.StatusInternalServerError).SendString("Internal Server Error")
}
if err = db.Default.Session().Model(&model.AuthorizationRecord{}).
Where("user_id", op.Id).
Where("client_id", client.Id).
Take(authRecord).
Error; err != nil {
// 用户第一次对该 client 进行授权
if errors.Is(err, gorm.ErrRecordNotFound) {
return c.RenderHTML("authorize", pageAuthorize, map[string]any{
"user": map[string]any{
"username": op.Username,
"avatar": "https://picsum.photos/200",
},
"client_id": req.ClientId,
"redirect_uri": req.RedirectURI,
"scope": req.Scope,
"state": req.State,
})
}
log.Error("[Authorize]: db take authorization_records err = %s", err.Error())
return resp.Resp500(c, err)
}
// 当用户已经授权过时
// 生成授权码并缓存授权码
log.Debug("[Authorize]: username = %s already approved %s", op.Username, client.Name)
authorizationCode := uuid.New().String()[:8]
if err = cache.Client.SetEx(c.Context(), cache.Prefix+"auth_code:"+authorizationCode, op.Id, 10*time.Minute); err != nil {
return resp.Resp500(c, err)
}
if uri, err = url.Parse(req.RedirectURI); err != nil {
log.Warn("[S] parse redirect uri = %s, err = %s", req.RedirectURI, err.Error())
return c.Status(http.StatusBadRequest).SendString("Bad Request: invalid redirect uri")
}
qs := uri.Query()
qs.Add("code", authorizationCode)
qs.Add("client_id", req.ClientId)
qs.Add("scope", req.Scope)
qs.Add("state", req.State)
uri.ForceQuery = true
value := uri.String() + qs.Encode()
return c.RenderHTML("approve", pageApprove, map[string]interface{}{
"redirect_uri": value,
})
}

View File

@ -0,0 +1,137 @@
package handler
import (
_ "embed"
"errors"
"github.com/loveuer/nf"
"github.com/loveuer/nf/nft/log"
"github.com/loveuer/nf/nft/resp"
"gorm.io/gorm"
"net/http"
"time"
"uauth/internal/store/cache"
"uauth/internal/store/db"
"uauth/internal/tool"
"uauth/model"
)
var (
//go:embed serve_login.html
pageLogin string
)
func LoginPage(c *nf.Ctx) error {
type Req struct {
ClientId string `query:"client_id" json:"client_id"`
Scope string `query:"scope" json:"scope"`
RedirectURI string `query:"redirect_uri" json:"redirect_uri"`
State string `query:"state" json:"state"`
ResponseType string `query:"response_type" json:"response_type"`
}
var (
err error
req = new(Req)
client = new(model.Client)
)
if err = c.QueryParser(req); err != nil {
return resp.Resp400(c, err.Error())
}
if req.ClientId == "" || req.RedirectURI == "" {
return resp.Resp400(c, req)
}
if err = db.Default.Session().Model(&model.Client{}).
Where("client_id = ?", req.ClientId).
Take(client).
Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return c.Status(http.StatusForbidden).SendString("Client Not Registry")
}
log.Error("[S] model take client id = %s, err = %s", req.ClientId, err.Error())
return c.Status(http.StatusInternalServerError).SendString("Internal Server Error")
}
return c.RenderHTML("login", pageLogin, map[string]interface{}{
"client_id": req.ClientId,
"redirect_uri": req.RedirectURI,
"scope": req.Scope,
"state": req.State,
"response_type": req.ResponseType,
"client_name": client.Name,
"client_icon": client.Icon,
})
}
//go:embed serve_login_success.html
var pageLoginSuccess string
func LoginAction(c *nf.Ctx) error {
type Req struct {
Username string `form:"username"`
Password string `form:"password"`
ClientId string `form:"client_id"`
RedirectURI string `form:"redirect_uri"`
Scope string `form:"scope"`
State string `form:"state"`
ResponseType string `form:"response_type"`
}
var (
err error
req = new(Req)
op = new(model.User)
token string
)
if err = c.BodyParser(req); err != nil {
log.Warn("[S] LoginAction: body parser err = %s", err.Error())
return c.Status(http.StatusBadRequest).SendString("Bad Request")
}
if req.Username == "" || req.Password == "" {
return c.Status(http.StatusBadRequest).SendString("Bad Request: username, password is required")
}
if err = db.Default.Session().Model(&model.User{}).
Where("username = ?", req.Username).
Take(op).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Warn("[S] LoginAction: username = %s not found", req.Username)
return c.Status(http.StatusBadRequest).SendString("Bad Request")
}
log.Error("[S] LoginAction: model take username = %s, err = %s", req.Username, err.Error())
return c.Status(http.StatusInternalServerError).SendString("Internal Server Error")
}
// todo: 验证用户登录是否成功,等等
if !tool.ComparePassword(req.Password, op.Password) {
log.Warn("[S] LoginAction: model take username = %s, password is invalid", req.Username)
return c.Status(http.StatusBadRequest).SendString("Bad Request")
}
if token, err = op.JwtEncode(); err != nil {
log.Error("[S] LoginAction: jwtEncode err = %s", err.Error())
return c.Status(http.StatusInternalServerError).SendString("Internal Server Error")
}
key := cache.Prefix + "token:" + token
if err = cache.Client.SetEx(c.Context(), key, op, 24*time.Hour); err != nil {
log.Error("[S] LoginAction: cache SetEx err = %s", err.Error())
return c.Status(http.StatusInternalServerError).SendString("Internal Server Error")
}
c.Writer.Header().Add("Set-Cookie", "access_token="+token)
return c.RenderHTML("login_success", pageLoginSuccess, map[string]interface{}{
"client_id": req.ClientId,
"redirect_uri": req.RedirectURI,
"scope": req.Scope,
"state": req.State,
"response_type": req.ResponseType,
})
}

View File

@ -0,0 +1,111 @@
package handler
import (
_ "embed"
"errors"
"github.com/loveuer/nf"
"github.com/loveuer/nf/nft/resp"
"gorm.io/gorm"
"uauth/internal/store/db"
"uauth/internal/tool"
"uauth/model"
)
func ClientRegistry(c *nf.Ctx) error {
type Req struct {
ClientId string `json:"client_id"`
Icon string `json:"icon"` // url
Name string `json:"name"`
}
var (
err error
req = new(Req)
)
if err = c.BodyParser(req); err != nil {
return resp.Resp400(c, err.Error())
}
Secret := tool.RandomString(32)
platform := &model.Client{
ClientId: req.ClientId,
Icon: req.Icon,
Name: req.Name,
ClientSecret: Secret,
}
if err = db.Default.Session().Create(platform).Error; err != nil {
if errors.Is(err, gorm.ErrDuplicatedKey) {
return resp.Resp400(c, err, "当前平台已经存在")
}
return resp.Resp500(c, err)
}
return resp.Resp200(c, platform)
}
var (
//go:embed serve_registry.html
registryLogin string
)
func UserRegistryPage(c *nf.Ctx) error {
return c.HTML(registryLogin)
}
func UserRegistryAction(c *nf.Ctx) error {
type Req struct {
Username string `form:"username"`
Nickname string `form:"nickname"`
Password string `form:"password"`
ConfirmPassword string `form:"confirm_password"`
}
var (
err error
req = new(Req)
)
if err = c.BodyParser(req); err != nil {
return resp.Resp400(c, err.Error())
}
if err = tool.CheckPassword(req.Password); err != nil {
return resp.Resp400(c, req, err.Error())
}
op := &model.User{
Username: req.Username,
Nickname: req.Nickname,
Password: tool.NewPassword(req.Password),
}
if err = db.Default.Session().Create(op).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return resp.Resp400(c, err, "用户名已存在")
}
return resp.Resp500(c, err)
}
return c.HTML(`
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<link
rel="stylesheet"
href="https://cdn.jsdelivr.net/npm/@picocss/pico@2/css/pico.jade.min.css"
>
<title>注册成功</title>
</head>
<body>
<h1>注册成功</h1>
<h3>快去试试吧</h3>
</body>
</html>
`)
}

View File

@ -0,0 +1,33 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>授权成功</title>
<link
rel="stylesheet"
href="https://cdn.jsdelivr.net/npm/@picocss/pico@2/css/pico.jade.min.css"
>
<style>
body {
height: 100vh;
width: 100vw;
display: flex;
align-items: center;
justify-content: center;
}
</style>
</head>
<body>
<span aria-busy="true">授权成功, 正在跳转回原网页...</span>
<div style="display: none">
<input type="hidden" id="redirect_uri" value="{{ .redirect_uri }}"/>
</div>
<script type="text/javascript">
setTimeout(() => {
console.log('[D] after 1s console')
let redirect_uri = document.getElementById('redirect_uri').value
window.location.href = redirect_uri
}, 1000)
</script>
</body>
</html>

View File

@ -0,0 +1,48 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<link
rel="stylesheet"
href="https://cdn.jsdelivr.net/npm/@picocss/pico@2/css/pico.jade.min.css"
>
<title>Server Login</title>
<style>
body {
height: 100vh;
width: 100vw;
display: flex;
justify-content: center;
align-items: center;
}
</style>
</head>
<body>
<div>
<h3>授权登录到 {{ .client_name }} 平台</h3>
<div class="userinfo">
<article style="display: flex; align-items: center;">
<div style="height:50px; width:50px;">
<img src="{{ .user.avatar }}"/>
</div>
<div style="margin-left:20px; ">
{{ .user.username }}
</div>
</article>
</div>
<form action="/oauth/v2/approve" method="POST">
<fieldset>
<input type="hidden" name="client_id" value="{{ .client_id }}"/>
<input type="hidden" name="redirect_uri" value="{{ .redirect_uri }}"/>
<input type="hidden" name="scope" value="{{ .scope }}"/>
<input type="hidden" name="state" value="{{ .state }}"/>
</fieldset>
<div style="display: flex;">
<button type="button" class="contrast" style="flex:1; margin-right: 10px">取消</button>
<button type="submit" style="flex: 1;">授权</button>
</div>
</form>
</div>
</body>
</html>

View File

@ -0,0 +1,89 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<link
rel="stylesheet"
href="https://cdn.jsdelivr.net/npm/@picocss/pico@2/css/pico.jade.min.css"
>
<title>Server Login</title>
<style>
body {
height: 100vh;
width: 100vw;
display: flex;
justify-content: center;
align-items: center;
}
div.row {
display: flex;
align-items: center;
max-width: 33%;
height: 50px;
overflow: hidden;
flex: 1;
}
div.row:nth-child(2) {
justify-content: center;
}
div.row:last-child {
margin-left: auto;
}
</style>
</head>
<body>
<div>
<h3>欢迎来到 UAuth</h3>
<article style="display: flex; align-items: center;">
<div class="row">
<div style="height:50px; width:50px;">
<img src="https://picsum.photos/seed/drealism/200"/>
</div>
<div style="margin-left:10px; ">UAuth</div>
</div>
<div style="transform: rotate(90deg);" class="row">
<svg t="1730168415342" class="icon" viewBox="0 0 1024 1024" version="1.1" xmlns="http://www.w3.org/2000/svg" p-id="1981" xmlns:xlink="http://www.w3.org/1999/xlink" width="40" height="40"><path d="M428.3 66.4c-12-5-25.7-2.2-34.9 6.9l-320 319.6c-12.5 12.5-12.5 32.7 0 45.3 12.5 12.5 32.7 12.5 45.3 0l265.4-265V928c0 17.7 14.3 32 32 32s32-14.3 32-32V96c-0.1-12.9-7.9-24.6-19.8-29.6zM950.6 585.8c-12.5-12.5-32.8-12.5-45.3 0L640 850.8V96c0-17.7-14.3-32-32-32s-32 14.3-32 32v832c0 12.9 7.8 24.6 19.7 29.6 4 1.6 8.1 2.4 12.2 2.4 8.3 0 16.5-3.2 22.6-9.4l320-319.6c12.6-12.4 12.6-32.7 0.1-45.2z" p-id="1982"></path></svg>
</div>
<div class="row">
<div style="height:50px; width:50px;">
<img src="{{ .client_icon }}"/>
</div>
<div style="margin-left:10px; ">
{{ .client_name }}
</div>
</div>
</article>
<form action="/oauth/v2/login" method="POST">
<fieldset>
<label>
Username
<input
name="username"
placeholder="username"
autocomplete="given-name"
/>
</label>
<label>
Password
<input
type="password"
name="password"
placeholder="password"
autocomplete="password"
/>
</label>
<input type="hidden" name="client_id" value="{{ .client_id }}"/>
<input type="hidden" name="redirect_uri" value="{{ .redirect_uri }}"/>
<input type="hidden" name="scope" value="{{ .scope }}"/>
<input type="hidden" name="state" value="{{ .state }}" />
<input type="hidden" name="response_type" value="{{ .response_type }}" />
</fieldset>
<input
type="submit"
value="登录"
/>
</form>
</div>
</body>
</html>

View File

@ -0,0 +1,41 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>登录成功</title>
<link
rel="stylesheet"
href="https://cdn.jsdelivr.net/npm/@picocss/pico@2/css/pico.jade.min.css"
>
<style>
body {
height: 100vh;
width: 100vw;
display: flex;
align-items: center;
justify-content: center;
}
</style>
</head>
<body>
<span aria-busy="true">登录成功, 正在跳转...</span>
<div style="display: none">
<input type="hidden" id="client_id" value="{{ .client_id }}"/>
<input type="hidden" id="scope" value="{{ .scope }}"/>
<input type="hidden" id="state" value="{{ .state }}"/>
<input type="hidden" id="redirect_uri" value="{{ .redirect_uri }}"/>
<input type="hidden" id="response_type" value="{{ .response_type }}"/>
</div>
<script type="text/javascript">
setTimeout(() => {
console.log('[D] after 1s console')
let client_id = document.querySelector('#client_id').value;
let scope = document.querySelector('#scope').value;
let state = document.querySelector('#state').value;
let redirect_uri = document.querySelector('#redirect_uri').value;
let response_type = document.querySelector('#response_type').value;
window.location.href = `/oauth/v2/authorize?client_id=${client_id}&scope=${scope}&redirect_uri=${redirect_uri}&state=${state}&response_type=${response_type}`;
}, 1000)
</script>
</body>
</html>

View File

@ -0,0 +1,70 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<link
rel="stylesheet"
href="https://cdn.jsdelivr.net/npm/@picocss/pico@2/css/pico.jade.min.css"
>
<title>Server Login</title>
<style>
body {
height: 100vh;
width: 100vw;
display: flex;
justify-content: center;
align-items: center;
}
</style>
</head>
<body>
<div>
<h3>欢迎注册 UAuth</h3>
<form action="/api/oauth/v2/registry/user" method="POST" id="form">
<fieldset>
<label>
用户名
<input id="username" name="username" autocomplete="given-name"/>
</label>
<label>
昵称
<input id="nickname" name="nickname" autocomplete="given-name"/>
</label>
<label>
密码
<input id="password" type="password" name="password" autocomplete="password"/>
</label>
<label>
重复密码
<input id="confirm_password" type="password" name="confirm_password" autocomplete="password"/>
</label>
<button type="button" style="flex: 1;width: 100%;" onclick="registry()">注册</button>
</fieldset>
</form>
</div>
<script type="text/javascript">
function registry() {
let user = {
username: document.querySelector("#username").value,
nickname: document.querySelector("#nickname").value,
password: document.querySelector("#password").value,
confirm_password: document.querySelector("#confirm_password").value,
}
console.log('[D] user = ', user)
if (!user.username || !user.password || !user.nickname) {
window.alert("参数均不能为空")
return
}
if (user.password !== user.confirm_password) {
window.alert("两次密码不一致")
return
}
document.querySelector("#form").submit()
}
</script>
</body>
</html>

View File

@ -0,0 +1,109 @@
package handler
import (
"encoding/base64"
"errors"
"github.com/google/uuid"
"github.com/loveuer/nf"
"github.com/loveuer/nf/nft/log"
"gorm.io/gorm"
"net/http"
"strings"
"uauth/internal/store/cache"
"uauth/internal/store/db"
"uauth/internal/tool"
"uauth/model"
)
func HandleToken(c *nf.Ctx) error {
type Req struct {
Code string `form:"code"`
GrantType string `form:"grant_type"`
RedirectURI string `form:"redirect_uri"`
}
var (
err error
req = new(Req)
opId uint64
op = new(model.User)
token string
basic string
bs []byte
strs []string
client = new(model.Client)
)
if err = c.BodyParser(req); err != nil {
return c.Status(http.StatusBadRequest).SendString("Bad Request: invalid form")
}
// client_secret
if basic = c.Get("Authorization"); basic == "" {
return c.Status(http.StatusUnauthorized).SendString("Authorization header missing")
}
switch {
case strings.HasPrefix(basic, "Basic "):
basic = strings.TrimPrefix(basic, "Basic ")
default:
return c.Status(http.StatusBadRequest).SendString("Bad Request: authorization scheme not supported")
}
if bs, err = base64.StdEncoding.DecodeString(basic); err != nil {
log.Warn("[Token] base64 decode failed, raw = %s, err = %s", basic, err.Error())
return c.Status(http.StatusBadRequest).SendString("Bad Request: invalid basic authorization")
}
if strs = strings.SplitN(string(bs), ":", 2); len(strs) != 2 {
log.Warn("[Token] basic split err, decode = %s", string(bs))
return c.Status(http.StatusBadRequest).SendString("Bad Request: invalid basic authorization")
}
clientId, clientSecret := strs[0], strs[1]
if err = db.Default.Session().Model(&model.Client{}).
Where("client_id", clientId).
Take(client).
Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return c.Status(http.StatusBadRequest).SendString("Bad Request: client invalid")
}
log.Error("[Token] db take client by id = %s, err = %s", clientId, err.Error())
return c.Status(http.StatusInternalServerError).SendString("Internal Server Error")
}
if client.ClientSecret != clientSecret {
log.Warn("[Token] client_secret invalid, want = %s, got = %s", client.ClientSecret, clientSecret)
return c.Status(http.StatusUnauthorized).SendString("Unauthorized: client secret invalid")
}
if err = cache.Client.GetScan(tool.Timeout(2), cache.Prefix+"auth_code:"+req.Code).Scan(&opId); err != nil {
if errors.Is(err, cache.ErrorKeyNotFound) {
return c.Status(http.StatusBadRequest).SendString("Bad Request: invalid code")
}
log.Error("[S] handleToken: get code from cache err = %s", err.Error())
return c.Status(http.StatusInternalServerError).SendString("Internal Server Error")
}
op.Id = opId
if err = db.Default.Session().Take(op).Error; err != nil {
log.Error("[S] handleToken: get op by id err, id = %d, err = %s", opId, err.Error())
return c.Status(http.StatusInternalServerError).SendString("Internal Server Error")
}
if token, err = op.JwtEncode(); err != nil {
log.Error("[S] handleToken: encode token err, id = %d, err = %s", opId, err.Error())
return c.Status(http.StatusInternalServerError).SendString("Internal Server Error")
}
refreshToken := uuid.New().String()
return c.JSON(map[string]any{
"access_token": token,
"refresh_token": refreshToken,
"token_type": "Bearer",
"expires_in": 24 * 3600,
})
}

39
internal/serve/serve.go Normal file
View File

@ -0,0 +1,39 @@
package serve
import (
"context"
"github.com/loveuer/nf"
"github.com/loveuer/nf/nft/log"
"uauth/internal/middleware/auth"
"uauth/internal/opt"
"uauth/internal/serve/handler"
"uauth/internal/tool"
)
func Run(ctx context.Context) error {
app := nf.New()
api := app.Group(opt.Cfg.Svc.Prefix)
api.Get("/registry/user", handler.UserRegistryPage)
api.Post("/registry/user", handler.UserRegistryAction)
api.Post("/registry/client", handler.ClientRegistry)
api.Get("/login", handler.LoginPage)
api.Post("/login", handler.LoginAction)
api.Get("/authorize", auth.New(&auth.Config{NextOnError: true}), handler.Authorize)
api.Post("/approve", auth.New(), handler.Approve)
api.Post("/token", handler.HandleToken)
api.Get("/after", auth.New(), func(c *nf.Ctx) error {
return c.SendString("hello world")
})
log.Info("Starting server on: %s", opt.Cfg.Svc.Address)
go func() {
<-ctx.Done()
_ = app.Shutdown(tool.Timeout(2))
}()
return app.Run(opt.Cfg.Svc.Address)
}

117
internal/store/cache/cache_lru.go vendored Normal file
View File

@ -0,0 +1,117 @@
package cache
import (
"context"
"github.com/hashicorp/golang-lru/v2/expirable"
_ "github.com/hashicorp/golang-lru/v2/expirable"
"time"
"uauth/internal/interfaces"
)
var _ interfaces.Cacher = (*_lru)(nil)
type _lru struct {
client *expirable.LRU[string, *_lru_value]
}
type _lru_value struct {
duration time.Duration
last time.Time
bs []byte
}
func (l *_lru) Get(ctx context.Context, key string) ([]byte, error) {
v, ok := l.client.Get(key)
if !ok {
return nil, ErrorKeyNotFound
}
if v.duration == 0 {
return v.bs, nil
}
if time.Now().Sub(v.last) > v.duration {
l.client.Remove(key)
return nil, ErrorKeyNotFound
}
return v.bs, nil
}
func (l *_lru) GetScan(ctx context.Context, key string) interfaces.Scanner {
return newScanner(l.Get(ctx, key))
}
func (l *_lru) GetEx(ctx context.Context, key string, duration time.Duration) ([]byte, error) {
v, ok := l.client.Get(key)
if !ok {
return nil, ErrorKeyNotFound
}
if v.duration == 0 {
return v.bs, nil
}
now := time.Now()
if now.Sub(v.last) > v.duration {
l.client.Remove(key)
return nil, ErrorKeyNotFound
}
l.client.Add(key, &_lru_value{
duration: duration,
last: now,
bs: v.bs,
})
return v.bs, nil
}
func (l *_lru) GetExScan(ctx context.Context, key string, duration time.Duration) interfaces.Scanner {
return newScanner(l.GetEx(ctx, key, duration))
}
func (l *_lru) Set(ctx context.Context, key string, value any) error {
bs, err := handleValue(value)
if err != nil {
return err
}
l.client.Add(key, &_lru_value{
duration: 0,
last: time.Now(),
bs: bs,
})
return nil
}
func (l *_lru) SetEx(ctx context.Context, key string, value any, duration time.Duration) error {
bs, err := handleValue(value)
if err != nil {
return err
}
l.client.Add(key, &_lru_value{
duration: duration,
last: time.Now(),
bs: bs,
})
return nil
}
func (l *_lru) Del(ctx context.Context, keys ...string) error {
for _, key := range keys {
l.client.Remove(key)
}
return nil
}
func newLRUCache() (interfaces.Cacher, error) {
client := expirable.NewLRU[string, *_lru_value](1024*1024, nil, 0)
return &_lru{client: client}, nil
}

82
internal/store/cache/cache_memory.go vendored Normal file
View File

@ -0,0 +1,82 @@
package cache
import (
"context"
"errors"
"fmt"
"time"
"uauth/internal/interfaces"
"gitea.com/taozitaozi/gredis"
)
var _ interfaces.Cacher = (*_mem)(nil)
type _mem struct {
client *gredis.Gredis
}
func (m *_mem) GetScan(ctx context.Context, key string) interfaces.Scanner {
return newScanner(m.Get(ctx, key))
}
func (m *_mem) GetExScan(ctx context.Context, key string, duration time.Duration) interfaces.Scanner {
return newScanner(m.GetEx(ctx, key, duration))
}
func (m *_mem) Get(ctx context.Context, key string) ([]byte, error) {
v, err := m.client.Get(key)
if err != nil {
if errors.Is(err, gredis.ErrKeyNotFound) {
return nil, ErrorKeyNotFound
}
return nil, err
}
bs, ok := v.([]byte)
if !ok {
return nil, fmt.Errorf("invalid value type=%T", v)
}
return bs, nil
}
func (m *_mem) GetEx(ctx context.Context, key string, duration time.Duration) ([]byte, error) {
v, err := m.client.GetEx(key, duration)
if err != nil {
if errors.Is(err, gredis.ErrKeyNotFound) {
return nil, ErrorKeyNotFound
}
return nil, err
}
bs, ok := v.([]byte)
if !ok {
return nil, fmt.Errorf("invalid value type=%T", v)
}
return bs, nil
}
func (m *_mem) Set(ctx context.Context, key string, value any) error {
bs, err := handleValue(value)
if err != nil {
return err
}
return m.client.Set(key, bs)
}
func (m *_mem) SetEx(ctx context.Context, key string, value any, duration time.Duration) error {
bs, err := handleValue(value)
if err != nil {
return err
}
return m.client.SetEx(key, bs, duration)
}
func (m *_mem) Del(ctx context.Context, keys ...string) error {
m.client.Delete(keys...)
return nil
}

72
internal/store/cache/cache_redis.go vendored Normal file
View File

@ -0,0 +1,72 @@
package cache
import (
"context"
"errors"
"github.com/go-redis/redis/v8"
"time"
"uauth/internal/interfaces"
)
type _redis struct {
client *redis.Client
}
func (r *_redis) Get(ctx context.Context, key string) ([]byte, error) {
result, err := r.client.Get(ctx, key).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return nil, ErrorKeyNotFound
}
return nil, err
}
return []byte(result), nil
}
func (r *_redis) GetScan(ctx context.Context, key string) interfaces.Scanner {
return newScanner(r.Get(ctx, key))
}
func (r *_redis) GetEx(ctx context.Context, key string, duration time.Duration) ([]byte, error) {
result, err := r.client.GetEx(ctx, key, duration).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return nil, ErrorKeyNotFound
}
return nil, err
}
return []byte(result), nil
}
func (r *_redis) GetExScan(ctx context.Context, key string, duration time.Duration) interfaces.Scanner {
return newScanner(r.GetEx(ctx, key, duration))
}
func (r *_redis) Set(ctx context.Context, key string, value any) error {
bs, err := handleValue(value)
if err != nil {
return err
}
_, err = r.client.Set(ctx, key, bs, redis.KeepTTL).Result()
return err
}
func (r *_redis) SetEx(ctx context.Context, key string, value any, duration time.Duration) error {
bs, err := handleValue(value)
if err != nil {
return err
}
_, err = r.client.SetEX(ctx, key, bs, duration).Result()
return err
}
func (r *_redis) Del(ctx context.Context, keys ...string) error {
return r.client.Del(ctx, keys...).Err()
}

42
internal/store/cache/client.go vendored Normal file
View File

@ -0,0 +1,42 @@
package cache
import (
"encoding/json"
"uauth/internal/interfaces"
)
const (
Prefix = "sys:uauth:"
)
var (
Client interfaces.Cacher
)
type encoded_value interface {
MarshalBinary() ([]byte, error)
}
type decoded_value interface {
UnmarshalBinary(bs []byte) error
}
func handleValue(value any) ([]byte, error) {
var (
bs []byte
err error
)
switch value.(type) {
case []byte:
return value.([]byte), nil
}
if imp, ok := value.(encoded_value); ok {
bs, err = imp.MarshalBinary()
} else {
bs, err = json.Marshal(value)
}
return bs, err
}

7
internal/store/cache/error.go vendored Normal file
View File

@ -0,0 +1,7 @@
package cache
import "errors"
var (
ErrorKeyNotFound = errors.New("key not found")
)

69
internal/store/cache/init.go vendored Normal file
View File

@ -0,0 +1,69 @@
package cache
import (
"fmt"
"gitea.com/taozitaozi/gredis"
"github.com/go-redis/redis/v8"
"net/url"
"strings"
"uauth/internal/tool"
)
func Init(uri string) error {
var (
err error
)
strs := strings.Split(uri, "::")
switch strs[0] {
case "memory":
gc := gredis.NewGredis(1024 * 1024)
Client = &_mem{client: gc}
case "lru":
if Client, err = newLRUCache(); err != nil {
return err
}
case "redis":
var (
ins *url.URL
err error
)
if len(strs) != 2 {
return fmt.Errorf("cache.Init: invalid cache uri: %s", uri)
}
uri := strs[1]
if !strings.Contains(uri, "://") {
uri = fmt.Sprintf("redis://%s", uri)
}
if ins, err = url.Parse(uri); err != nil {
return fmt.Errorf("cache.Init: url parse cache uri: %s, err: %s", uri, err.Error())
}
addr := ins.Host
username := ins.User.Username()
password, _ := ins.User.Password()
var rc *redis.Client
rc = redis.NewClient(&redis.Options{
Addr: addr,
Username: username,
Password: password,
})
if err = rc.Ping(tool.Timeout(5)).Err(); err != nil {
return fmt.Errorf("cache.Init: redis ping err: %s", err.Error())
}
Client = &_redis{client: rc}
default:
return fmt.Errorf("cache type %s not support", strs[0])
}
return nil
}

20
internal/store/cache/scan.go vendored Normal file
View File

@ -0,0 +1,20 @@
package cache
import "encoding/json"
type scanner struct {
err error
bs []byte
}
func (s *scanner) Scan(model any) error {
if s.err != nil {
return s.err
}
return json.Unmarshal(s.bs, model)
}
func newScanner(bs []byte, err error) *scanner {
return &scanner{bs: bs, err: err}
}

View File

@ -0,0 +1,45 @@
package db
import (
"context"
"uauth/internal/opt"
"uauth/internal/tool"
"gorm.io/gorm"
)
var (
Default *Client
)
type Client struct {
ctx context.Context
cli *gorm.DB
ttype string
}
func (c *Client) Type() string {
return c.ttype
}
func (c *Client) Session(ctxs ...context.Context) *gorm.DB {
var ctx context.Context
if len(ctxs) > 0 && ctxs[0] != nil {
ctx = ctxs[0]
} else {
ctx = tool.Timeout(30)
}
session := c.cli.Session(&gorm.Session{Context: ctx})
if opt.Cfg.Debug {
session = session.Debug()
}
return session
}
func (c *Client) Close() {
d, _ := c.cli.DB()
d.Close()
}

View File

@ -0,0 +1,9 @@
package db
import (
"testing"
)
func TestOpen(t *testing.T) {
}

52
internal/store/db/init.go Normal file
View File

@ -0,0 +1,52 @@
package db
import (
"context"
"fmt"
"strings"
"github.com/glebarez/sqlite"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)
func New(ctx context.Context, uri string) (*Client, error) {
strs := strings.Split(uri, "::")
if len(strs) != 2 {
return nil, fmt.Errorf("db.Init: opt db uri invalid: %s", uri)
}
c := &Client{ttype: strs[0]}
var (
err error
dsn = strs[1]
)
switch strs[0] {
case "sqlite":
c.cli, err = gorm.Open(sqlite.Open(dsn))
case "mysql":
c.cli, err = gorm.Open(mysql.Open(dsn))
case "postgres":
c.cli, err = gorm.Open(postgres.Open(dsn))
default:
return nil, fmt.Errorf("db type only support: [sqlite, mysql, postgres], unsupported db type: %s", strs[0])
}
if err != nil {
return nil, fmt.Errorf("db.Init: open %s with dsn:%s, err: %w", strs[0], dsn, err)
}
return c, nil
}
func Init(ctx context.Context, uri string) (err error) {
if Default, err = New(ctx, uri); err != nil {
return err
}
return nil
}

104
internal/tool/cert.go Normal file
View File

@ -0,0 +1,104 @@
package tool
import (
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"net"
"time"
)
func GenerateTlsConfig() (serverTLSConf *tls.Config, clientTLSConf *tls.Config, err error) {
ca := &x509.Certificate{
SerialNumber: big.NewInt(2019),
Subject: pkix.Name{
Organization: []string{"Company, INC."},
Country: []string{"US"},
Province: []string{"California"},
Locality: []string{"San Francisco"},
StreetAddress: []string{"Golden Gate Bridge"},
PostalCode: []string{"94016"},
},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(99, 0, 0),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
// create our private and public key
caPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
return nil, nil, err
}
// create the CA
caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPrivKey.PublicKey, caPrivKey)
if err != nil {
return nil, nil, err
}
// pem encode
caPEM := new(bytes.Buffer)
pem.Encode(caPEM, &pem.Block{
Type: "CERTIFICATE",
Bytes: caBytes,
})
caPrivKeyPEM := new(bytes.Buffer)
pem.Encode(caPrivKeyPEM, &pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(caPrivKey),
})
// set up our server certificate
cert := &x509.Certificate{
SerialNumber: big.NewInt(2019),
Subject: pkix.Name{
Organization: []string{"Company, INC."},
Country: []string{"US"},
Province: []string{"California"},
Locality: []string{"San Francisco"},
StreetAddress: []string{"Golden Gate Bridge"},
PostalCode: []string{"94016"},
},
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(1, 0, 0),
SubjectKeyId: []byte{1, 2, 3, 4, 6},
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature,
}
certPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
return nil, nil, err
}
certBytes, err := x509.CreateCertificate(rand.Reader, cert, ca, &certPrivKey.PublicKey, caPrivKey)
if err != nil {
return nil, nil, err
}
certPEM := new(bytes.Buffer)
pem.Encode(certPEM, &pem.Block{
Type: "CERTIFICATE",
Bytes: certBytes,
})
certPrivKeyPEM := new(bytes.Buffer)
pem.Encode(certPrivKeyPEM, &pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey),
})
serverCert, err := tls.X509KeyPair(certPEM.Bytes(), certPrivKeyPEM.Bytes())
if err != nil {
return nil, nil, err
}
serverTLSConf = &tls.Config{
Certificates: []tls.Certificate{serverCert},
}
certpool := x509.NewCertPool()
certpool.AppendCertsFromPEM(caPEM.Bytes())
clientTLSConf = &tls.Config{
RootCAs: certpool,
}
return
}

38
internal/tool/ctx.go Normal file
View File

@ -0,0 +1,38 @@
package tool
import (
"context"
"time"
)
func Timeout(seconds ...int) (ctx context.Context) {
var (
duration time.Duration
)
if len(seconds) > 0 && seconds[0] > 0 {
duration = time.Duration(seconds[0]) * time.Second
} else {
duration = time.Duration(30) * time.Second
}
ctx, _ = context.WithTimeout(context.Background(), duration)
return
}
func TimeoutCtx(ctx context.Context, seconds ...int) context.Context {
var (
duration time.Duration
)
if len(seconds) > 0 && seconds[0] > 0 {
duration = time.Duration(seconds[0]) * time.Second
} else {
duration = time.Duration(30) * time.Second
}
nctx, _ := context.WithTimeout(ctx, duration)
return nctx
}

30
internal/tool/file.go Normal file
View File

@ -0,0 +1,30 @@
package tool
import (
"io"
"os"
)
func CopyFile(src string, dst string) (err error) {
// Open the source file
sourceFile, err := os.Open(src)
if err != nil {
return err
}
defer sourceFile.Close()
// Create the destination file
destinationFile, err := os.Create(dst)
if err != nil {
return err
}
defer destinationFile.Close()
// Copy the contents from source to destination
_, err = io.Copy(destinationFile, sourceFile)
if err != nil {
return err
}
return nil
}

24
internal/tool/human.go Normal file
View File

@ -0,0 +1,24 @@
package tool
import "fmt"
func HumanDuration(nano int64) string {
duration := float64(nano)
unit := "ns"
if duration >= 1000 {
duration /= 1000
unit = "us"
}
if duration >= 1000 {
duration /= 1000
unit = "ms"
}
if duration >= 1000 {
duration /= 1000
unit = " s"
}
return fmt.Sprintf("%6.2f%s", duration, unit)
}

11
internal/tool/must.go Normal file
View File

@ -0,0 +1,11 @@
package tool
import "github.com/loveuer/nf/nft/log"
func Must(errs ...error) {
for _, err := range errs {
if err != nil {
log.Panic(err.Error())
}
}
}

84
internal/tool/password.go Normal file
View File

@ -0,0 +1,84 @@
package tool
import (
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"github.com/loveuer/nf/nft/log"
"golang.org/x/crypto/pbkdf2"
"regexp"
"strconv"
"strings"
)
const (
EncryptHeader string = "pbkdf2:sha256" // 用户密码加密
)
func NewPassword(password string) string {
return EncryptPassword(password, RandomString(8), int(RandomInt(50000)+100000))
}
func ComparePassword(in, db string) bool {
strs := strings.Split(db, "$")
if len(strs) != 3 {
log.Error("password in db invalid: %s", db)
return false
}
encs := strings.Split(strs[0], ":")
if len(encs) != 3 {
log.Error("password in db invalid: %s", db)
return false
}
encIteration, err := strconv.Atoi(encs[2])
if err != nil {
log.Error("password in db invalid: %s, convert iter err: %s", db, err)
return false
}
return EncryptPassword(in, strs[1], encIteration) == db
}
func EncryptPassword(password, salt string, iter int) string {
hash := pbkdf2.Key([]byte(password), []byte(salt), iter, 32, sha256.New)
encrypted := hex.EncodeToString(hash)
return fmt.Sprintf("%s:%d$%s$%s", EncryptHeader, iter, salt, encrypted)
}
func CheckPassword(password string) error {
if len(password) < 8 || len(password) > 32 {
return errors.New("密码长度不符合")
}
var (
err error
match bool
patternList = []string{`[0-9]+`, `[a-z]+`, `[A-Z]+`, `[!@#%]+`} //, `[~!@#$%^&*?_-]+`}
matchAccount = 0
tips = []string{"缺少数字", "缺少小写字母", "缺少大写字母", "缺少'!@#%'"}
locktips = make([]string, 0)
)
for idx, pattern := range patternList {
match, err = regexp.MatchString(pattern, password)
if err != nil {
log.Warn("regex match string err, reg_str: %s, err: %v", pattern, err)
return errors.New("密码强度不够")
}
if match {
matchAccount++
} else {
locktips = append(locktips, tips[idx])
}
}
if matchAccount < 3 {
return fmt.Errorf("密码强度不够, 可能 %s", strings.Join(locktips, ", "))
}
return nil
}

View File

@ -0,0 +1,11 @@
package tool
import "testing"
func TestEncPassword(t *testing.T) {
password := "123456"
result := EncryptPassword(password, RandomString(8), 50000)
t.Logf("sum => %s", result)
}

54
internal/tool/random.go Normal file
View File

@ -0,0 +1,54 @@
package tool
import (
"crypto/rand"
"math/big"
)
var (
letters = []byte("0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
letterNum = []byte("0123456789")
letterLow = []byte("abcdefghijklmnopqrstuvwxyz")
letterCap = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
letterSyb = []byte("!@#$%^&*()_+-=")
)
func RandomInt(max int64) int64 {
num, _ := rand.Int(rand.Reader, big.NewInt(max))
return num.Int64()
}
func RandomString(length int) string {
result := make([]byte, length)
for i := 0; i < length; i++ {
num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
result[i] = letters[num.Int64()]
}
return string(result)
}
func RandomPassword(length int, withSymbol bool) string {
result := make([]byte, length)
kind := 3
if withSymbol {
kind++
}
for i := 0; i < length; i++ {
switch i % kind {
case 0:
num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letterNum))))
result[i] = letterNum[num.Int64()]
case 1:
num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letterLow))))
result[i] = letterLow[num.Int64()]
case 2:
num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letterCap))))
result[i] = letterCap[num.Int64()]
case 3:
num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letterSyb))))
result[i] = letterSyb[num.Int64()]
}
}
return string(result)
}

5
internal/tool/slice.go Normal file
View File

@ -0,0 +1,5 @@
package tool
func Bulk[T any](slice []T, size int) {
// todo
}

View File

@ -0,0 +1 @@
package tool

124
internal/tool/table.go Normal file
View File

@ -0,0 +1,124 @@
package tool
import (
"encoding/json"
"fmt"
"github.com/jedib0t/go-pretty/v6/table"
"github.com/loveuer/nf/nft/log"
"io"
"os"
"reflect"
"strings"
)
func TablePrinter(data any, writers ...io.Writer) {
var w io.Writer = os.Stdout
if len(writers) > 0 && writers[0] != nil {
w = writers[0]
}
t := table.NewWriter()
structPrinter(t, "", data)
_, _ = fmt.Fprintln(w, t.Render())
}
func structPrinter(w table.Writer, prefix string, item any) {
Start:
rv := reflect.ValueOf(item)
if rv.IsZero() {
return
}
for rv.Type().Kind() == reflect.Pointer {
rv = rv.Elem()
}
switch rv.Type().Kind() {
case reflect.Invalid,
reflect.Uintptr,
reflect.Chan,
reflect.Func,
reflect.UnsafePointer:
case reflect.Bool,
reflect.Int,
reflect.Int8,
reflect.Int16,
reflect.Int32,
reflect.Int64,
reflect.Uint,
reflect.Uint8,
reflect.Uint16,
reflect.Uint32,
reflect.Uint64,
reflect.Float32,
reflect.Float64,
reflect.Complex64,
reflect.Complex128,
reflect.Interface:
w.AppendRow(table.Row{strings.TrimPrefix(prefix, "."), rv.Interface()})
case reflect.String:
val := rv.String()
if len(val) <= 160 {
w.AppendRow(table.Row{strings.TrimPrefix(prefix, "."), val})
return
}
w.AppendRow(table.Row{strings.TrimPrefix(prefix, "."), val[0:64] + "..." + val[len(val)-64:]})
case reflect.Array, reflect.Slice:
for i := 0; i < rv.Len(); i++ {
p := strings.Join([]string{prefix, fmt.Sprintf("[%d]", i)}, ".")
structPrinter(w, p, rv.Index(i).Interface())
}
case reflect.Map:
for _, k := range rv.MapKeys() {
structPrinter(w, fmt.Sprintf("%s.{%v}", prefix, k), rv.MapIndex(k).Interface())
}
case reflect.Pointer:
goto Start
case reflect.Struct:
for i := 0; i < rv.NumField(); i++ {
p := fmt.Sprintf("%s.%s", prefix, rv.Type().Field(i).Name)
field := rv.Field(i)
//log.Debug("TablePrinter: prefix: %s, field: %v", p, rv.Field(i))
if !field.CanInterface() {
return
}
structPrinter(w, p, field.Interface())
}
}
}
func TableMapPrinter(data []byte) {
m := make(map[string]any)
if err := json.Unmarshal(data, &m); err != nil {
log.Warn(err.Error())
return
}
t := table.NewWriter()
addRow(t, "", m)
fmt.Println(t.Render())
}
func addRow(w table.Writer, prefix string, m any) {
rv := reflect.ValueOf(m)
switch rv.Type().Kind() {
case reflect.Map:
for _, k := range rv.MapKeys() {
key := k.String()
if prefix != "" {
key = strings.Join([]string{prefix, k.String()}, ".")
}
addRow(w, key, rv.MapIndex(k).Interface())
}
case reflect.Slice, reflect.Array:
for i := 0; i < rv.Len(); i++ {
addRow(w, fmt.Sprintf("%s[%d]", prefix, i), rv.Index(i).Interface())
}
default:
w.AppendRow(table.Row{prefix, m})
}
}

13
internal/tool/time.go Normal file
View File

@ -0,0 +1,13 @@
package tool
import "time"
// TodayMidnight 返回今日凌晨
func TodayMidnight() (midnight time.Time) {
now := time.Now()
year, month, day := now.Date()
midnight = time.Date(year, month, day, 0, 0, 0, 0, time.Local)
return
}