Compare commits

...

2 Commits

Author SHA1 Message Date
loveuer
cef1775811 wip: 继续 2024-11-01 17:53:47 +08:00
loveuer
58fae2e090 dev: rbac 2024-11-01 17:47:33 +08:00
9 changed files with 459 additions and 0 deletions

View File

@ -0,0 +1,9 @@
package sqlType
import "errors"
var (
ErrConvertScanVal = errors.New("convert scan val to str err")
ErrInvalidScanVal = errors.New("scan val invalid")
ErrConvertVal = errors.New("convert err")
)

76
internal/sqlType/jsonb.go Normal file
View File

@ -0,0 +1,76 @@
package sqlType
import (
"database/sql/driver"
"encoding/json"
"github.com/jackc/pgtype"
)
type JSONB struct {
Val pgtype.JSONB
Valid bool
}
func NewJSONB(v interface{}) JSONB {
j := new(JSONB)
j.Val = pgtype.JSONB{}
if err := j.Val.Set(v); err == nil {
j.Valid = true
return *j
}
return *j
}
func (j *JSONB) Set(value interface{}) error {
if err := j.Val.Set(value); err != nil {
j.Valid = false
return err
}
j.Valid = true
return nil
}
func (j *JSONB) Bind(model interface{}) error {
return j.Val.AssignTo(model)
}
func (j *JSONB) Scan(value interface{}) error {
j.Val = pgtype.JSONB{}
if value == nil {
j.Valid = false
return nil
}
j.Valid = true
return j.Val.Scan(value)
}
func (j JSONB) Value() (driver.Value, error) {
if j.Valid {
return j.Val.Value()
}
return nil, nil
}
func (j JSONB) MarshalJSON() ([]byte, error) {
if j.Valid {
return j.Val.MarshalJSON()
}
return json.Marshal(nil)
}
func (j *JSONB) UnmarshalJSON(b []byte) error {
if string(b) == "null" {
j.Valid = false
return j.Val.UnmarshalJSON(b)
}
return j.Val.UnmarshalJSON(b)
}

View File

@ -0,0 +1,71 @@
package sqlType
import (
"database/sql/driver"
"fmt"
"strconv"
"strings"
"github.com/spf13/cast"
)
type NumSlice[T ~int | ~int64 | ~uint | ~uint64] []T
func (n *NumSlice[T]) Scan(val interface{}) error {
str, ok := val.(string)
if !ok {
return ErrConvertScanVal
}
length := len(str)
if length <= 0 {
*n = make(NumSlice[T], 0)
return nil
}
if str[0] != '{' || str[length-1] != '}' {
return ErrInvalidScanVal
}
str = str[1 : length-1]
if len(str) == 0 {
*n = make(NumSlice[T], 0)
return nil
}
numStrs := strings.Split(str, ",")
nums := make([]T, len(numStrs))
for idx := range numStrs {
num, err := cast.ToInt64E(strings.TrimSpace(numStrs[idx]))
if err != nil {
return fmt.Errorf("%w: can't convert to %T", ErrConvertVal, T(0))
}
nums[idx] = T(num)
}
*n = nums
return nil
}
func (n NumSlice[T]) Value() (driver.Value, error) {
if n == nil {
return "{}", nil
}
if len(n) == 0 {
return "{}", nil
}
ss := make([]string, 0, len(n))
for idx := range n {
ss = append(ss, strconv.Itoa(int(n[idx])))
}
s := strings.Join(ss, ", ")
return fmt.Sprintf("{%s}", s), nil
}

View File

@ -0,0 +1,107 @@
package sqlType
import (
"bytes"
"database/sql/driver"
"encoding/json"
)
type StrSlice []string
func (s *StrSlice) Scan(val interface{}) error {
str, ok := val.(string)
if !ok {
return ErrConvertScanVal
}
if len(str) < 2 {
return nil
}
bs := make([]byte, 0, 128)
bss := make([]byte, 0, 2*len(str))
quoteCount := 0
for idx := 1; idx < len(str)-1; idx++ {
quote := str[idx]
switch quote {
case 44:
if quote == 44 && str[idx-1] != 92 && quoteCount == 0 {
if len(bs) > 0 {
if !(bs[0] == 34 && bs[len(bs)-1] == 34) {
bs = append([]byte{34}, bs...)
bs = append(bs, 34)
}
bss = append(bss, bs...)
bss = append(bss, 44)
}
bs = bs[:0]
} else {
bs = append(bs, quote)
}
case 34:
if str[idx-1] != 92 {
quoteCount = (quoteCount + 1) % 2
}
bs = append(bs, quote)
default:
bs = append(bs, quote)
}
}
if len(bs) > 0 {
if !(bs[0] == 34 && bs[len(bs)-1] == 34) {
bs = append([]byte{34}, bs...)
bs = append(bs, 34)
}
bss = append(bss, bs...)
} else {
if len(bss) > 2 {
bss = bss[:len(bss)-2]
}
}
bss = append([]byte{'['}, append(bss, ']')...)
if err := json.Unmarshal(bss, s); err != nil {
return err
}
return nil
}
func (s StrSlice) Value() (driver.Value, error) {
if s == nil {
return "{}", nil
}
if len(s) == 0 {
return "{}", nil
}
buf := &bytes.Buffer{}
encoder := json.NewEncoder(buf)
encoder.SetEscapeHTML(false)
if err := encoder.Encode(s); err != nil {
return "{}", err
}
bs := buf.Bytes()
bs[0] = '{'
if bs[len(bs)-1] == 10 {
bs = bs[:len(bs)-1]
}
bs[len(bs)-1] = '}'
return string(bs), nil
}

52
model/privilege.go Normal file
View File

@ -0,0 +1,52 @@
package model
import (
"fmt"
"strings"
)
type Privilege struct {
CreatedAt int64 `json:"created_at" gorm:"column:created_at;autoCreateTime:milli"`
UpdatedAt int64 `json:"updated_at" gorm:"column:updated_at;autoUpdateTime:milli"`
DeletedAt int64 `json:"deleted_at" gorm:"index;column:deleted_at;default:0"`
Code string `json:"code" gorm:"column:code;primaryKey"`
Label string `json:"label" gorm:"column:label"`
Parent string `json:"parent" gorm:"column:parent"`
Scope string `json:"scope" gorm:"column:scope"`
}
func (u *Urbac) newPrivilege(ctx context.Context, code, label string, parentId uint64, scope string) (*Privilege, error) {
p := &Privilege{Code: code, Label: label, ParentId: parentId, Scope: scope}
codes := strings.SplitN(code, ":", 4)
if len(codes) != 4 {
return nil, fmt.Errorf("invalid code format")
}
wailcard := false
for _, item := range codes {
if item == "*" {
wailcard = true
}
if wailcard && item != "*" {
return nil, fmt.Errorf("invalid code format")
}
if len(item) > 8 {
return nil, fmt.Errorf("invalid code format: code snippet too long")
}
}
if codes[0] != "*" {
if _, err := u.GetScopeGroup(ctx, codes[0]); err != nil {
return nil, err
}
}
if err := u.store.Session(ctx).Create(p).Error; err != nil {
return nil, err
}
return p, nil
}

48
model/role.go Normal file
View File

@ -0,0 +1,48 @@
package model
import "uauth/internal/sqlType"
type Role struct {
CreatedAt int64 `json:"created_at" gorm:"column:created_at;autoCreateTime:milli"`
UpdatedAt int64 `json:"updated_at" gorm:"column:updated_at;autoUpdateTime:milli"`
DeletedAt int64 `json:"deleted_at" gorm:"index;column:deleted_at;default:0"`
Code string `json:"code" gorm:"primaryKey;column:code"`
Label string `json:"label" gorm:"column:label"`
Parent string `json:"parent" gorm:"column:parent"`
PrivilegeCodes sqlType.StrSlice `json:"privilege_codes" gorm:"column:privilege_codes"`
}
func (u *Urbac) newRole(ctx context.Context, name, label, parent string, privileges ...*Privilege) (*Role, error) {
ps := lo.FilterMap(
privileges,
func(p *Privilege, _ int) (string, bool) {
if p == nil {
return "", false
}
return p.Code, p.Code != ""
},
)
r := &Role{
Name: name,
Label: label,
Parent: parent,
PrivilegeCodes: ps,
}
if err := u.store.Session(ctx).Create(r).Error; err != nil {
return nil, err
}
return r, nil
}
func (u *Urbac) GetRole(ctx context.Context, name string) (*Role, error) {
var r Role
if err := u.store.Session(ctx).Take(&r, "name = ?", name).Error; err != nil {
return nil, err
}
return &r, nil
}

27
model/scope.go Normal file
View File

@ -0,0 +1,27 @@
package model
// 用户权限作用域
type Scope struct {
Code string `json:"code" gorm:"column:code;type:varchar(8);not null;primaryKey"`
CreatedAt int64 `json:"created_at" gorm:"column:created_at;autoCreateTime:milli"`
UpdatedAt int64 `json:"updated_at" gorm:"column:updated_at;autoUpdateTime:milli"`
DeletedAt int64 `json:"deleted_at" gorm:"index;column:deleted_at;default:0"`
Label string `json:"label" gorm:"column:label;type:varchar(64)"`
Parent string `json:"parent" gorm:"column:parent;type:varchar(8)"`
}
func (u *Urbac) newScope(ctx context.Context, name, label, parent string) (*Scope, error) {
s := &Scope{Name: name, Label: label, Parent: parent}
if err := u.store.Session(ctx).Create(s).Error; err != nil {
return nil, err
}
return s, nil
}
func (u *Urbac) GetScopeGroup(ctx context.Context, name string) (*Scope, error) {
scope := new(Scope)
err := u.store.Session(ctx).Where("name = ?", name).Take(scope).Error
return scope, err
}

View File

@ -6,6 +6,7 @@ import (
"github.com/loveuer/nf/nft/log" "github.com/loveuer/nf/nft/log"
"time" "time"
"uauth/internal/opt" "uauth/internal/opt"
"uauth/internal/sqlType"
) )
type Status int64 type Status int64
@ -29,6 +30,9 @@ type User struct {
CreatedByName string `json:"created_by_name" gorm:"column:created_by_name;type:varchar(64)"` CreatedByName string `json:"created_by_name" gorm:"column:created_by_name;type:varchar(64)"`
LoginAt int64 `json:"login_at" gorm:"-"` LoginAt int64 `json:"login_at" gorm:"-"`
Roles []*Role `json:"-" gorm:"-"`
RoleNames sqlType.StrSlice `json:"role_names" column:"role_names"`
} }
func (u *User) JwtEncode() (token string, err error) { func (u *User) JwtEncode() (token string, err error) {
@ -40,6 +44,7 @@ func (u *User) JwtEncode() (token string, err error) {
"username": u.Username, "username": u.Username,
"nickname": u.Nickname, "nickname": u.Nickname,
"status": u.Status, "status": u.Status,
"avatar": u.Avatar,
"login_at": now.UnixMilli(), "login_at": now.UnixMilli(),
}) })

64
rbac/rbac.go Normal file
View File

@ -0,0 +1,64 @@
package rbac
import (
"fmt"
"strings"
"uauth/internal/tool"
)
type Urbac struct {
cache cache.Cache
store store.Store
}
type Option func(u *Urbac)
func New(opts ...Option) (*Urbac, error) {
var (
err error
u = &Urbac{}
rootPrivilege *Privilege
rootRole *Role
rootScope *Scope
)
for _, opt := range opts {
opt(u)
}
if u.store == nil {
if u.store, err = store.NewSqliteStore("sqlite.db"); err != nil {
return nil, err
}
}
if u.cache == nil {
if u.cache, err = cache.NewRedisCache("redis://10.220.10.15:6379"); err != nil {
return nil, err
}
}
if err = u.store.Session(tool.Timeout()).AutoMigrate(&Scope{}, &Privilege{}, &Role{}); err != nil {
return nil, fmt.Errorf("urbac migrate err: %w", err)
}
if rootPrivilege, err = u.newPrivilege(tool.Timeout(), "*:*:*:*", "admin", 0, "*"); err != nil {
if !strings.Contains(strings.ToLower(err.Error()), "unique") {
return nil, err
}
}
if rootRole, err = u.newRole(tool.Timeout(), "admin", "管理员", "", rootPrivilege); err != nil {
if !strings.Contains(strings.ToLower(err.Error()), "unique") {
return nil, err
}
}
if rootScope, err = u.newScope(tool.Timeout(), "*", "全部", ""); err != nil {
if !strings.Contains(strings.ToLower(err.Error()), "unique") {
return nil, err
}
}
return u, nil
}