structure: 确定基本结构(保持基本形式, 采用组合)
This commit is contained in:
116
pkg/cache/cache_lru.go
vendored
Normal file
116
pkg/cache/cache_lru.go
vendored
Normal file
@@ -0,0 +1,116 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/hashicorp/golang-lru/v2/expirable"
|
||||
_ "github.com/hashicorp/golang-lru/v2/expirable"
|
||||
"time"
|
||||
)
|
||||
|
||||
var _ Cache = (*_lru)(nil)
|
||||
|
||||
type _lru struct {
|
||||
client *expirable.LRU[string, *_lru_value]
|
||||
}
|
||||
|
||||
type _lru_value struct {
|
||||
duration time.Duration
|
||||
last time.Time
|
||||
bs []byte
|
||||
}
|
||||
|
||||
func (l *_lru) Get(ctx context.Context, key string) ([]byte, error) {
|
||||
v, ok := l.client.Get(key)
|
||||
if !ok {
|
||||
return nil, ErrorKeyNotFound
|
||||
}
|
||||
|
||||
if v.duration == 0 {
|
||||
return v.bs, nil
|
||||
}
|
||||
|
||||
if time.Now().Sub(v.last) > v.duration {
|
||||
l.client.Remove(key)
|
||||
return nil, ErrorKeyNotFound
|
||||
}
|
||||
|
||||
return v.bs, nil
|
||||
}
|
||||
|
||||
func (l *_lru) GetScan(ctx context.Context, key string) Scanner {
|
||||
return newScanner(l.Get(ctx, key))
|
||||
}
|
||||
|
||||
func (l *_lru) GetEx(ctx context.Context, key string, duration time.Duration) ([]byte, error) {
|
||||
v, ok := l.client.Get(key)
|
||||
if !ok {
|
||||
return nil, ErrorKeyNotFound
|
||||
}
|
||||
|
||||
if v.duration == 0 {
|
||||
return v.bs, nil
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
if now.Sub(v.last) > v.duration {
|
||||
l.client.Remove(key)
|
||||
return nil, ErrorKeyNotFound
|
||||
}
|
||||
|
||||
l.client.Add(key, &_lru_value{
|
||||
duration: duration,
|
||||
last: now,
|
||||
bs: v.bs,
|
||||
})
|
||||
|
||||
return v.bs, nil
|
||||
}
|
||||
|
||||
func (l *_lru) GetExScan(ctx context.Context, key string, duration time.Duration) Scanner {
|
||||
return newScanner(l.GetEx(ctx, key, duration))
|
||||
}
|
||||
|
||||
func (l *_lru) Set(ctx context.Context, key string, value any) error {
|
||||
bs, err := handleValue(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
l.client.Add(key, &_lru_value{
|
||||
duration: 0,
|
||||
last: time.Now(),
|
||||
bs: bs,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *_lru) SetEx(ctx context.Context, key string, value any, duration time.Duration) error {
|
||||
bs, err := handleValue(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
l.client.Add(key, &_lru_value{
|
||||
duration: duration,
|
||||
last: time.Now(),
|
||||
bs: bs,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *_lru) Del(ctx context.Context, keys ...string) error {
|
||||
for _, key := range keys {
|
||||
l.client.Remove(key)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func newLRUCache() (Cache, error) {
|
||||
client := expirable.NewLRU[string, *_lru_value](1024*1024, nil, 0)
|
||||
|
||||
return &_lru{client: client}, nil
|
||||
}
|
||||
81
pkg/cache/cache_memory.go
vendored
Normal file
81
pkg/cache/cache_memory.go
vendored
Normal file
@@ -0,0 +1,81 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"gitea.com/loveuer/gredis"
|
||||
)
|
||||
|
||||
var _ Cache = (*_mem)(nil)
|
||||
|
||||
type _mem struct {
|
||||
client *gredis.Gredis
|
||||
}
|
||||
|
||||
func (m *_mem) GetScan(ctx context.Context, key string) Scanner {
|
||||
return newScanner(m.Get(ctx, key))
|
||||
}
|
||||
|
||||
func (m *_mem) GetExScan(ctx context.Context, key string, duration time.Duration) Scanner {
|
||||
return newScanner(m.GetEx(ctx, key, duration))
|
||||
}
|
||||
|
||||
func (m *_mem) Get(ctx context.Context, key string) ([]byte, error) {
|
||||
v, err := m.client.Get(key)
|
||||
if err != nil {
|
||||
if errors.Is(err, gredis.ErrKeyNotFound) {
|
||||
return nil, ErrorKeyNotFound
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bs, ok := v.([]byte)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid value type=%T", v)
|
||||
}
|
||||
|
||||
return bs, nil
|
||||
}
|
||||
|
||||
func (m *_mem) GetEx(ctx context.Context, key string, duration time.Duration) ([]byte, error) {
|
||||
v, err := m.client.GetEx(key, duration)
|
||||
if err != nil {
|
||||
if errors.Is(err, gredis.ErrKeyNotFound) {
|
||||
return nil, ErrorKeyNotFound
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bs, ok := v.([]byte)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid value type=%T", v)
|
||||
}
|
||||
|
||||
return bs, nil
|
||||
}
|
||||
|
||||
func (m *_mem) Set(ctx context.Context, key string, value any) error {
|
||||
bs, err := handleValue(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return m.client.Set(key, bs)
|
||||
}
|
||||
|
||||
func (m *_mem) SetEx(ctx context.Context, key string, value any, duration time.Duration) error {
|
||||
bs, err := handleValue(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return m.client.SetEx(key, bs, duration)
|
||||
}
|
||||
|
||||
func (m *_mem) Del(ctx context.Context, keys ...string) error {
|
||||
m.client.Delete(keys...)
|
||||
return nil
|
||||
}
|
||||
71
pkg/cache/cache_redis.go
vendored
Normal file
71
pkg/cache/cache_redis.go
vendored
Normal file
@@ -0,0 +1,71 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"time"
|
||||
)
|
||||
|
||||
type _redis struct {
|
||||
client *redis.Client
|
||||
}
|
||||
|
||||
func (r *_redis) Get(ctx context.Context, key string) ([]byte, error) {
|
||||
result, err := r.client.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return nil, ErrorKeyNotFound
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return []byte(result), nil
|
||||
}
|
||||
|
||||
func (r *_redis) GetScan(ctx context.Context, key string) Scanner {
|
||||
return newScanner(r.Get(ctx, key))
|
||||
}
|
||||
|
||||
func (r *_redis) GetEx(ctx context.Context, key string, duration time.Duration) ([]byte, error) {
|
||||
result, err := r.client.GetEx(ctx, key, duration).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return nil, ErrorKeyNotFound
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return []byte(result), nil
|
||||
}
|
||||
|
||||
func (r *_redis) GetExScan(ctx context.Context, key string, duration time.Duration) Scanner {
|
||||
return newScanner(r.GetEx(ctx, key, duration))
|
||||
}
|
||||
|
||||
func (r *_redis) Set(ctx context.Context, key string, value any) error {
|
||||
bs, err := handleValue(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = r.client.Set(ctx, key, bs, redis.KeepTTL).Result()
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *_redis) SetEx(ctx context.Context, key string, value any, duration time.Duration) error {
|
||||
bs, err := handleValue(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = r.client.SetEX(ctx, key, bs, duration).Result()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *_redis) Del(ctx context.Context, keys ...string) error {
|
||||
return r.client.Del(ctx, keys...).Err()
|
||||
}
|
||||
60
pkg/cache/client.go
vendored
Normal file
60
pkg/cache/client.go
vendored
Normal file
@@ -0,0 +1,60 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// todo: config this
|
||||
Prefix = "sys:uauth:"
|
||||
)
|
||||
|
||||
type Cache interface {
|
||||
Get(ctx context.Context, key string) ([]byte, error)
|
||||
GetScan(ctx context.Context, key string) Scanner
|
||||
GetEx(ctx context.Context, key string, duration time.Duration) ([]byte, error)
|
||||
GetExScan(ctx context.Context, key string, duration time.Duration) Scanner
|
||||
// Set value 会被序列化, 优先使用 MarshalBinary 方法, 没有则执行 json.Marshal
|
||||
Set(ctx context.Context, key string, value any) error
|
||||
// SetEx value 会被序列化, 优先使用 MarshalBinary 方法, 没有则执行 json.Marshal
|
||||
SetEx(ctx context.Context, key string, value any, duration time.Duration) error
|
||||
Del(ctx context.Context, keys ...string) error
|
||||
}
|
||||
|
||||
type Scanner interface {
|
||||
Scan(model any) error
|
||||
}
|
||||
|
||||
var (
|
||||
Client Cache
|
||||
)
|
||||
|
||||
type encoded_value interface {
|
||||
MarshalBinary() ([]byte, error)
|
||||
}
|
||||
|
||||
type decoded_value interface {
|
||||
UnmarshalBinary(bs []byte) error
|
||||
}
|
||||
|
||||
func handleValue(value any) ([]byte, error) {
|
||||
var (
|
||||
bs []byte
|
||||
err error
|
||||
)
|
||||
|
||||
switch value.(type) {
|
||||
case []byte:
|
||||
return value.([]byte), nil
|
||||
}
|
||||
|
||||
if imp, ok := value.(encoded_value); ok {
|
||||
bs, err = imp.MarshalBinary()
|
||||
} else {
|
||||
bs, err = json.Marshal(value)
|
||||
}
|
||||
|
||||
return bs, err
|
||||
}
|
||||
7
pkg/cache/error.go
vendored
Normal file
7
pkg/cache/error.go
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
package cache
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrorKeyNotFound = errors.New("key not found")
|
||||
)
|
||||
74
pkg/cache/init.go
vendored
Normal file
74
pkg/cache/init.go
vendored
Normal file
@@ -0,0 +1,74 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"gitea.com/loveuer/gredis"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"net/url"
|
||||
"strings"
|
||||
"uauth/tool"
|
||||
)
|
||||
|
||||
func Init(uri string) (err error) {
|
||||
Client, err = New(uri)
|
||||
return err
|
||||
}
|
||||
|
||||
func New(uri string) (Cache, error) {
|
||||
var (
|
||||
client Cache
|
||||
err error
|
||||
)
|
||||
|
||||
strs := strings.Split(uri, "::")
|
||||
|
||||
switch strs[0] {
|
||||
case "memory":
|
||||
gc := gredis.NewGredis(1024 * 1024)
|
||||
client = &_mem{client: gc}
|
||||
case "lru":
|
||||
if client, err = newLRUCache(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case "redis":
|
||||
var (
|
||||
ins *url.URL
|
||||
err error
|
||||
)
|
||||
|
||||
if len(strs) != 2 {
|
||||
return nil, fmt.Errorf("cache.Init: invalid cache uri: %s", uri)
|
||||
}
|
||||
|
||||
uri := strs[1]
|
||||
|
||||
if !strings.Contains(uri, "://") {
|
||||
uri = fmt.Sprintf("redis://%s", uri)
|
||||
}
|
||||
|
||||
if ins, err = url.Parse(uri); err != nil {
|
||||
return nil, fmt.Errorf("cache.Init: url parse cache uri: %s, err: %s", uri, err.Error())
|
||||
}
|
||||
|
||||
addr := ins.Host
|
||||
username := ins.User.Username()
|
||||
password, _ := ins.User.Password()
|
||||
|
||||
var rc *redis.Client
|
||||
rc = redis.NewClient(&redis.Options{
|
||||
Addr: addr,
|
||||
Username: username,
|
||||
Password: password,
|
||||
})
|
||||
|
||||
if err = rc.Ping(tool.Timeout(5)).Err(); err != nil {
|
||||
return nil, fmt.Errorf("cache.Init: redis ping err: %s", err.Error())
|
||||
}
|
||||
|
||||
client = &_redis{client: rc}
|
||||
default:
|
||||
return nil, fmt.Errorf("cache type %s not support", strs[0])
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
20
pkg/cache/scan.go
vendored
Normal file
20
pkg/cache/scan.go
vendored
Normal file
@@ -0,0 +1,20 @@
|
||||
package cache
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
type scanner struct {
|
||||
err error
|
||||
bs []byte
|
||||
}
|
||||
|
||||
func (s *scanner) Scan(model any) error {
|
||||
if s.err != nil {
|
||||
return s.err
|
||||
}
|
||||
|
||||
return json.Unmarshal(s.bs, model)
|
||||
}
|
||||
|
||||
func newScanner(bs []byte, err error) *scanner {
|
||||
return &scanner{bs: bs, err: err}
|
||||
}
|
||||
118
pkg/middleware/auth/auth.go
Normal file
118
pkg/middleware/auth/auth.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/loveuer/nf"
|
||||
"github.com/loveuer/nf/nft/log"
|
||||
"github.com/loveuer/nf/nft/resp"
|
||||
"net/http"
|
||||
"time"
|
||||
"uauth/model"
|
||||
"uauth/pkg/cache"
|
||||
"uauth/tool"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
IgnoreFn func(c *nf.Ctx) bool
|
||||
TokenFn func(c *nf.Ctx) (string, bool)
|
||||
GetUserFn func(c *nf.Ctx, token string) (*model.User, error)
|
||||
NextOnError bool
|
||||
}
|
||||
|
||||
var (
|
||||
defaultIgnoreFn = func(c *nf.Ctx) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
defaultTokenFn = func(c *nf.Ctx) (string, bool) {
|
||||
var token string
|
||||
|
||||
if token = c.Request.Header.Get("Authorization"); token != "" {
|
||||
return token, true
|
||||
}
|
||||
|
||||
if token = c.Query("access_token"); token != "" {
|
||||
return token, true
|
||||
}
|
||||
|
||||
if token = c.Cookies("access_token"); token != "" {
|
||||
return token, true
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
defaultGetUserFn = func(c *nf.Ctx, token string) (*model.User, error) {
|
||||
var (
|
||||
err error
|
||||
op = new(model.User)
|
||||
key = cache.Prefix + "token:" + token
|
||||
)
|
||||
|
||||
if err = cache.Client.GetExScan(tool.Timeout(3), key, 24*time.Hour).Scan(op); err != nil {
|
||||
if errors.Is(err, cache.ErrorKeyNotFound) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Error("[M] cache client get user by token key = %s, err = %s", key, err.Error())
|
||||
return nil, errors.New("Internal Server Error")
|
||||
}
|
||||
|
||||
return op, nil
|
||||
}
|
||||
)
|
||||
|
||||
func New(cfgs ...*Config) nf.HandlerFunc {
|
||||
var cfg = &Config{}
|
||||
|
||||
if len(cfgs) > 0 && cfgs[0] != nil {
|
||||
cfg = cfgs[0]
|
||||
}
|
||||
|
||||
if cfg.IgnoreFn == nil {
|
||||
cfg.IgnoreFn = defaultIgnoreFn
|
||||
}
|
||||
|
||||
if cfg.TokenFn == nil {
|
||||
cfg.TokenFn = defaultTokenFn
|
||||
}
|
||||
|
||||
if cfg.GetUserFn == nil {
|
||||
cfg.GetUserFn = defaultGetUserFn
|
||||
}
|
||||
|
||||
return func(c *nf.Ctx) error {
|
||||
if cfg.IgnoreFn(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
token, ok := cfg.TokenFn(c)
|
||||
if !ok {
|
||||
if cfg.NextOnError {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
return resp.Resp401(c, nil, "请登录")
|
||||
}
|
||||
|
||||
op, err := cfg.GetUserFn(c, token)
|
||||
if err != nil {
|
||||
if cfg.NextOnError {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
if errors.Is(err, cache.ErrorKeyNotFound) {
|
||||
return c.Status(http.StatusUnauthorized).JSON(map[string]any{
|
||||
"status": 500,
|
||||
"msg": "用户认证信息不存在或已过期, 请重新登录",
|
||||
})
|
||||
}
|
||||
|
||||
return c.Status(http.StatusInternalServerError).SendString("Internal Server Error")
|
||||
}
|
||||
|
||||
c.Locals("user", op)
|
||||
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
47
pkg/middleware/logger/logger.go
Normal file
47
pkg/middleware/logger/logger.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/loveuer/nf"
|
||||
"github.com/loveuer/nf/nft/log"
|
||||
"github.com/loveuer/nf/nft/resp"
|
||||
"strconv"
|
||||
"time"
|
||||
"uauth/tool"
|
||||
)
|
||||
|
||||
func New() nf.HandlerFunc {
|
||||
|
||||
return func(c *nf.Ctx) error {
|
||||
var (
|
||||
now = time.Now()
|
||||
logFn func(msg string, data ...any)
|
||||
ip = c.IP()
|
||||
)
|
||||
|
||||
traceId := c.Context().Value(nf.TraceKey)
|
||||
c.Locals(nf.TraceKey, traceId)
|
||||
|
||||
err := c.Next()
|
||||
|
||||
c.Writer.Header().Set(nf.TraceKey, fmt.Sprint(traceId))
|
||||
|
||||
status, _ := strconv.Atoi(c.Writer.Header().Get(resp.RealStatusHeader))
|
||||
duration := time.Since(now)
|
||||
|
||||
msg := fmt.Sprintf("%s | %15s | %d[%3d] | %s | %6s | %s", traceId, ip, c.StatusCode, status, tool.HumanDuration(duration.Nanoseconds()), c.Method(), c.Path())
|
||||
|
||||
switch {
|
||||
case status >= 500:
|
||||
logFn = log.Error
|
||||
case status >= 400:
|
||||
logFn = log.Warn
|
||||
default:
|
||||
logFn = log.Info
|
||||
}
|
||||
|
||||
logFn(msg)
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
110
pkg/rbac/create.go
Normal file
110
pkg/rbac/create.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package rbac
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/samber/lo"
|
||||
"strings"
|
||||
"uauth/model"
|
||||
)
|
||||
|
||||
func (u *RBAC) newScope(ctx context.Context, code, label, parent string) (*model.Scope, error) {
|
||||
s := &model.Scope{Code: code, Label: label, Parent: parent}
|
||||
if err := u.store.Session(ctx).Create(s).Error; err != nil {
|
||||
return s, err
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (u *RBAC) GetScopeGroup(ctx context.Context, name string) (*model.Scope, error) {
|
||||
scope := new(model.Scope)
|
||||
err := u.store.Session(ctx).Where("name = ?", name).Take(scope).Error
|
||||
|
||||
return scope, err
|
||||
}
|
||||
|
||||
func (u *RBAC) newRole(ctx context.Context, code, label, parent string, privileges ...*model.Privilege) (*model.Role, error) {
|
||||
ps := lo.FilterMap(
|
||||
privileges,
|
||||
func(p *model.Privilege, _ int) (string, bool) {
|
||||
if p == nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return p.Code, p.Code != ""
|
||||
},
|
||||
)
|
||||
|
||||
r := &model.Role{
|
||||
Code: code,
|
||||
Label: label,
|
||||
Parent: parent,
|
||||
PrivilegeCodes: ps,
|
||||
}
|
||||
|
||||
if err := u.store.Session(ctx).Create(r).Error; err != nil {
|
||||
return r, err
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (u *RBAC) GetRole(ctx context.Context, name string) (*model.Role, error) {
|
||||
var r model.Role
|
||||
if err := u.store.Session(ctx).Take(&r, "name = ?", name).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &r, nil
|
||||
}
|
||||
|
||||
func (u *RBAC) newPrivilege(ctx context.Context, code, label string, parent string, scope string) (*model.Privilege, error) {
|
||||
p := &model.Privilege{Code: code, Label: label, Parent: parent, Scope: scope}
|
||||
|
||||
codes := strings.SplitN(code, ":", 4)
|
||||
if len(codes) != 4 {
|
||||
return nil, fmt.Errorf("invalid code format")
|
||||
}
|
||||
|
||||
wailcard := false
|
||||
for _, item := range codes {
|
||||
if item == "*" {
|
||||
wailcard = true
|
||||
}
|
||||
|
||||
if wailcard && item != "*" {
|
||||
return nil, fmt.Errorf("invalid code format")
|
||||
}
|
||||
|
||||
if len(item) > 8 {
|
||||
return nil, fmt.Errorf("invalid code format: code snippet too long")
|
||||
}
|
||||
}
|
||||
|
||||
if codes[0] != "*" {
|
||||
if _, err := u.GetScopeGroup(ctx, codes[0]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := u.store.Session(ctx).Create(p).Error; err != nil {
|
||||
return p, err
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (u *RBAC) newUser(ctx context.Context, target *model.User) (*model.User, error) {
|
||||
result := u.store.Session(ctx).
|
||||
Create(target)
|
||||
if result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
if result.RowsAffected != 1 {
|
||||
return nil, fmt.Errorf("invalid rows affected")
|
||||
}
|
||||
|
||||
return target, nil
|
||||
}
|
||||
83
pkg/rbac/rbac.go
Normal file
83
pkg/rbac/rbac.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package rbac
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"uauth/model"
|
||||
"uauth/pkg/cache"
|
||||
"uauth/pkg/store"
|
||||
"uauth/tool"
|
||||
)
|
||||
|
||||
type RBAC struct {
|
||||
cache cache.Cache
|
||||
store store.Store
|
||||
}
|
||||
|
||||
var (
|
||||
Default *RBAC
|
||||
)
|
||||
|
||||
func New(store store.Store, cache cache.Cache) (*RBAC, error) {
|
||||
var (
|
||||
err error
|
||||
u = &RBAC{
|
||||
store: store,
|
||||
cache: cache,
|
||||
}
|
||||
rootPrivilege *model.Privilege
|
||||
rootRole *model.Role
|
||||
rootScope *model.Scope
|
||||
rootUser *model.User
|
||||
)
|
||||
|
||||
if err = u.store.Session(tool.Timeout()).AutoMigrate(
|
||||
&model.Scope{},
|
||||
&model.Privilege{},
|
||||
&model.Role{},
|
||||
&model.User{},
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("urbac migrate err: %w", err)
|
||||
}
|
||||
|
||||
if rootPrivilege, err = u.newPrivilege(tool.Timeout(), "*:*:*:*", "admin", "", "*"); err != nil {
|
||||
if !strings.Contains(strings.ToLower(err.Error()), "unique") {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if rootRole, err = u.newRole(tool.Timeout(), "admin", "管理员", "", rootPrivilege); err != nil {
|
||||
if !strings.Contains(strings.ToLower(err.Error()), "unique") {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if rootScope, err = u.newScope(tool.Timeout(), "*", "全部", ""); err != nil {
|
||||
if !strings.Contains(strings.ToLower(err.Error()), "unique") {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
rootUser = &model.User{
|
||||
Username: "admin",
|
||||
Password: tool.NewPassword("123456"),
|
||||
Status: model.StatusActive,
|
||||
Nickname: "管理员",
|
||||
RoleNames: []string{rootRole.Code},
|
||||
}
|
||||
|
||||
if _, err = u.newUser(tool.Timeout(3), rootUser); err != nil {
|
||||
if !strings.Contains(strings.ToLower(err.Error()), "unique") {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
_ = rootScope
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func Init(store store.Store, cache cache.Cache) (err error) {
|
||||
Default, err = New(store, cache)
|
||||
return err
|
||||
}
|
||||
9
pkg/sqlType/error.go
Normal file
9
pkg/sqlType/error.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package sqlType
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrConvertScanVal = errors.New("convert scan val to str err")
|
||||
ErrInvalidScanVal = errors.New("scan val invalid")
|
||||
ErrConvertVal = errors.New("convert err")
|
||||
)
|
||||
76
pkg/sqlType/jsonb.go
Normal file
76
pkg/sqlType/jsonb.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package sqlType
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgtype"
|
||||
)
|
||||
|
||||
type JSONB struct {
|
||||
Val pgtype.JSONB
|
||||
Valid bool
|
||||
}
|
||||
|
||||
func NewJSONB(v interface{}) JSONB {
|
||||
j := new(JSONB)
|
||||
j.Val = pgtype.JSONB{}
|
||||
if err := j.Val.Set(v); err == nil {
|
||||
j.Valid = true
|
||||
return *j
|
||||
}
|
||||
|
||||
return *j
|
||||
}
|
||||
|
||||
func (j *JSONB) Set(value interface{}) error {
|
||||
if err := j.Val.Set(value); err != nil {
|
||||
j.Valid = false
|
||||
return err
|
||||
}
|
||||
|
||||
j.Valid = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (j *JSONB) Bind(model interface{}) error {
|
||||
return j.Val.AssignTo(model)
|
||||
}
|
||||
|
||||
func (j *JSONB) Scan(value interface{}) error {
|
||||
j.Val = pgtype.JSONB{}
|
||||
if value == nil {
|
||||
j.Valid = false
|
||||
return nil
|
||||
}
|
||||
|
||||
j.Valid = true
|
||||
|
||||
return j.Val.Scan(value)
|
||||
}
|
||||
|
||||
func (j JSONB) Value() (driver.Value, error) {
|
||||
if j.Valid {
|
||||
return j.Val.Value()
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (j JSONB) MarshalJSON() ([]byte, error) {
|
||||
if j.Valid {
|
||||
return j.Val.MarshalJSON()
|
||||
}
|
||||
|
||||
return json.Marshal(nil)
|
||||
}
|
||||
|
||||
func (j *JSONB) UnmarshalJSON(b []byte) error {
|
||||
if string(b) == "null" {
|
||||
j.Valid = false
|
||||
return j.Val.UnmarshalJSON(b)
|
||||
}
|
||||
|
||||
return j.Val.UnmarshalJSON(b)
|
||||
}
|
||||
71
pkg/sqlType/num_slice.go
Normal file
71
pkg/sqlType/num_slice.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package sqlType
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
type NumSlice[T ~int | ~int64 | ~uint | ~uint64] []T
|
||||
|
||||
func (n *NumSlice[T]) Scan(val interface{}) error {
|
||||
str, ok := val.(string)
|
||||
if !ok {
|
||||
return ErrConvertScanVal
|
||||
}
|
||||
|
||||
length := len(str)
|
||||
|
||||
if length <= 0 {
|
||||
*n = make(NumSlice[T], 0)
|
||||
return nil
|
||||
}
|
||||
|
||||
if str[0] != '{' || str[length-1] != '}' {
|
||||
return ErrInvalidScanVal
|
||||
}
|
||||
|
||||
str = str[1 : length-1]
|
||||
if len(str) == 0 {
|
||||
*n = make(NumSlice[T], 0)
|
||||
return nil
|
||||
}
|
||||
|
||||
numStrs := strings.Split(str, ",")
|
||||
nums := make([]T, len(numStrs))
|
||||
|
||||
for idx := range numStrs {
|
||||
num, err := cast.ToInt64E(strings.TrimSpace(numStrs[idx]))
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: can't convert to %T", ErrConvertVal, T(0))
|
||||
}
|
||||
|
||||
nums[idx] = T(num)
|
||||
}
|
||||
|
||||
*n = nums
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n NumSlice[T]) Value() (driver.Value, error) {
|
||||
if n == nil {
|
||||
return "{}", nil
|
||||
}
|
||||
|
||||
if len(n) == 0 {
|
||||
return "{}", nil
|
||||
}
|
||||
|
||||
ss := make([]string, 0, len(n))
|
||||
for idx := range n {
|
||||
ss = append(ss, strconv.Itoa(int(n[idx])))
|
||||
}
|
||||
|
||||
s := strings.Join(ss, ", ")
|
||||
|
||||
return fmt.Sprintf("{%s}", s), nil
|
||||
}
|
||||
107
pkg/sqlType/string_slice.go
Normal file
107
pkg/sqlType/string_slice.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package sqlType
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type StrSlice []string
|
||||
|
||||
func (s *StrSlice) Scan(val interface{}) error {
|
||||
|
||||
str, ok := val.(string)
|
||||
if !ok {
|
||||
return ErrConvertScanVal
|
||||
}
|
||||
|
||||
if len(str) < 2 {
|
||||
return nil
|
||||
}
|
||||
|
||||
bs := make([]byte, 0, 128)
|
||||
bss := make([]byte, 0, 2*len(str))
|
||||
|
||||
quoteCount := 0
|
||||
|
||||
for idx := 1; idx < len(str)-1; idx++ {
|
||||
quote := str[idx]
|
||||
switch quote {
|
||||
case 44:
|
||||
if quote == 44 && str[idx-1] != 92 && quoteCount == 0 {
|
||||
if len(bs) > 0 {
|
||||
if !(bs[0] == 34 && bs[len(bs)-1] == 34) {
|
||||
bs = append([]byte{34}, bs...)
|
||||
bs = append(bs, 34)
|
||||
}
|
||||
|
||||
bss = append(bss, bs...)
|
||||
bss = append(bss, 44)
|
||||
}
|
||||
bs = bs[:0]
|
||||
} else {
|
||||
bs = append(bs, quote)
|
||||
}
|
||||
case 34:
|
||||
if str[idx-1] != 92 {
|
||||
quoteCount = (quoteCount + 1) % 2
|
||||
}
|
||||
bs = append(bs, quote)
|
||||
default:
|
||||
bs = append(bs, quote)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if len(bs) > 0 {
|
||||
if !(bs[0] == 34 && bs[len(bs)-1] == 34) {
|
||||
bs = append([]byte{34}, bs...)
|
||||
bs = append(bs, 34)
|
||||
}
|
||||
|
||||
bss = append(bss, bs...)
|
||||
} else {
|
||||
if len(bss) > 2 {
|
||||
bss = bss[:len(bss)-2]
|
||||
}
|
||||
}
|
||||
|
||||
bss = append([]byte{'['}, append(bss, ']')...)
|
||||
|
||||
if err := json.Unmarshal(bss, s); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s StrSlice) Value() (driver.Value, error) {
|
||||
if s == nil {
|
||||
return "{}", nil
|
||||
}
|
||||
|
||||
if len(s) == 0 {
|
||||
return "{}", nil
|
||||
}
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
encoder := json.NewEncoder(buf)
|
||||
encoder.SetEscapeHTML(false)
|
||||
|
||||
if err := encoder.Encode(s); err != nil {
|
||||
return "{}", err
|
||||
}
|
||||
|
||||
bs := buf.Bytes()
|
||||
|
||||
bs[0] = '{'
|
||||
|
||||
if bs[len(bs)-1] == 10 {
|
||||
bs = bs[:len(bs)-1]
|
||||
}
|
||||
|
||||
bs[len(bs)-1] = '}'
|
||||
|
||||
return string(bs), nil
|
||||
}
|
||||
46
pkg/store/client.go
Normal file
46
pkg/store/client.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"uauth/tool"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Store interface {
|
||||
Session(ctx context.Context) *gorm.DB
|
||||
}
|
||||
|
||||
var (
|
||||
Default *Client
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
ctx context.Context
|
||||
cli *gorm.DB
|
||||
ttype string
|
||||
debug bool
|
||||
}
|
||||
|
||||
func (c *Client) Type() string {
|
||||
return c.ttype
|
||||
}
|
||||
|
||||
func (c *Client) Session(ctx context.Context) *gorm.DB {
|
||||
if ctx == nil {
|
||||
ctx = tool.Timeout(30)
|
||||
}
|
||||
|
||||
session := c.cli.Session(&gorm.Session{Context: ctx})
|
||||
|
||||
if c.debug {
|
||||
session = session.Debug()
|
||||
}
|
||||
|
||||
return session
|
||||
}
|
||||
|
||||
func (c *Client) Close() {
|
||||
d, _ := c.cli.DB()
|
||||
d.Close()
|
||||
}
|
||||
59
pkg/store/init.go
Normal file
59
pkg/store/init.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Debug bool
|
||||
}
|
||||
|
||||
func New(uri string, configs ...Config) (*Client, error) {
|
||||
strs := strings.Split(uri, "::")
|
||||
|
||||
if len(strs) != 2 {
|
||||
return nil, fmt.Errorf("db.Init: opt db uri invalid: %s", uri)
|
||||
}
|
||||
|
||||
c := &Client{ttype: strs[0]}
|
||||
|
||||
if len(configs) > 0 && configs[0].Debug {
|
||||
c.debug = true
|
||||
}
|
||||
|
||||
var (
|
||||
err error
|
||||
dsn = strs[1]
|
||||
)
|
||||
|
||||
switch strs[0] {
|
||||
case "sqlite":
|
||||
c.cli, err = gorm.Open(sqlite.Open(dsn))
|
||||
case "mysql":
|
||||
c.cli, err = gorm.Open(mysql.Open(dsn))
|
||||
case "postgres":
|
||||
c.cli, err = gorm.Open(postgres.Open(dsn))
|
||||
default:
|
||||
return nil, fmt.Errorf("db type only support: [sqlite, mysql, postgres], unsupported db type: %s", strs[0])
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db.Init: open %s with dsn:%s, err: %w", strs[0], dsn, err)
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func Init(uri string, configs ...Config) (err error) {
|
||||
if Default, err = New(uri, configs...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user