diff --git a/model/privilege.go b/model/privilege.go index 092273d..fdadf37 100644 --- a/model/privilege.go +++ b/model/privilege.go @@ -1,5 +1,10 @@ package model +import ( + "fmt" + "strings" +) + type Privilege struct { CreatedAt int64 `json:"created_at" gorm:"column:created_at;autoCreateTime:milli"` UpdatedAt int64 `json:"updated_at" gorm:"column:updated_at;autoUpdateTime:milli"` @@ -9,3 +14,39 @@ type Privilege struct { Parent string `json:"parent" gorm:"column:parent"` Scope string `json:"scope" gorm:"column:scope"` } + +func (u *Urbac) newPrivilege(ctx context.Context, code, label string, parentId uint64, scope string) (*Privilege, error) { + p := &Privilege{Code: code, Label: label, ParentId: parentId, Scope: scope} + + codes := strings.SplitN(code, ":", 4) + if len(codes) != 4 { + return nil, fmt.Errorf("invalid code format") + } + + wailcard := false + for _, item := range codes { + if item == "*" { + wailcard = true + } + + if wailcard && item != "*" { + return nil, fmt.Errorf("invalid code format") + } + + if len(item) > 8 { + return nil, fmt.Errorf("invalid code format: code snippet too long") + } + } + + if codes[0] != "*" { + if _, err := u.GetScopeGroup(ctx, codes[0]); err != nil { + return nil, err + } + } + + if err := u.store.Session(ctx).Create(p).Error; err != nil { + return nil, err + } + + return p, nil +} diff --git a/model/role.go b/model/role.go index 29bab75..1f10fe7 100644 --- a/model/role.go +++ b/model/role.go @@ -11,3 +11,38 @@ type Role struct { Parent string `json:"parent" gorm:"column:parent"` PrivilegeCodes sqlType.StrSlice `json:"privilege_codes" gorm:"column:privilege_codes"` } + +func (u *Urbac) newRole(ctx context.Context, name, label, parent string, privileges ...*Privilege) (*Role, error) { + ps := lo.FilterMap( + privileges, + func(p *Privilege, _ int) (string, bool) { + if p == nil { + return "", false + } + + return p.Code, p.Code != "" + }, + ) + + r := &Role{ + Name: name, + Label: label, + Parent: parent, + PrivilegeCodes: ps, + } + + if err := u.store.Session(ctx).Create(r).Error; err != nil { + return nil, err + } + + return r, nil +} + +func (u *Urbac) GetRole(ctx context.Context, name string) (*Role, error) { + var r Role + if err := u.store.Session(ctx).Take(&r, "name = ?", name).Error; err != nil { + return nil, err + } + + return &r, nil +} diff --git a/model/scope.go b/model/scope.go index 533bd3a..7129c53 100644 --- a/model/scope.go +++ b/model/scope.go @@ -9,3 +9,19 @@ type Scope struct { Label string `json:"label" gorm:"column:label;type:varchar(64)"` Parent string `json:"parent" gorm:"column:parent;type:varchar(8)"` } + +func (u *Urbac) newScope(ctx context.Context, name, label, parent string) (*Scope, error) { + s := &Scope{Name: name, Label: label, Parent: parent} + if err := u.store.Session(ctx).Create(s).Error; err != nil { + return nil, err + } + + return s, nil +} + +func (u *Urbac) GetScopeGroup(ctx context.Context, name string) (*Scope, error) { + scope := new(Scope) + err := u.store.Session(ctx).Where("name = ?", name).Take(scope).Error + + return scope, err +} diff --git a/rbac/rbac.go b/rbac/rbac.go new file mode 100644 index 0000000..f3cc169 --- /dev/null +++ b/rbac/rbac.go @@ -0,0 +1,64 @@ +package rbac + +import ( + "fmt" + "strings" + "uauth/internal/tool" +) + +type Urbac struct { + cache cache.Cache + store store.Store +} + +type Option func(u *Urbac) + +func New(opts ...Option) (*Urbac, error) { + var ( + err error + u = &Urbac{} + rootPrivilege *Privilege + rootRole *Role + rootScope *Scope + ) + + for _, opt := range opts { + opt(u) + } + + if u.store == nil { + if u.store, err = store.NewSqliteStore("sqlite.db"); err != nil { + return nil, err + } + } + + if u.cache == nil { + if u.cache, err = cache.NewRedisCache("redis://10.220.10.15:6379"); err != nil { + return nil, err + } + } + + if err = u.store.Session(tool.Timeout()).AutoMigrate(&Scope{}, &Privilege{}, &Role{}); err != nil { + return nil, fmt.Errorf("urbac migrate err: %w", err) + } + + if rootPrivilege, err = u.newPrivilege(tool.Timeout(), "*:*:*:*", "admin", 0, "*"); err != nil { + if !strings.Contains(strings.ToLower(err.Error()), "unique") { + return nil, err + } + } + + if rootRole, err = u.newRole(tool.Timeout(), "admin", "管理员", "", rootPrivilege); err != nil { + if !strings.Contains(strings.ToLower(err.Error()), "unique") { + return nil, err + } + } + + if rootScope, err = u.newScope(tool.Timeout(), "*", "全部", ""); err != nil { + if !strings.Contains(strings.ToLower(err.Error()), "unique") { + return nil, err + } + } + + return u, nil +}