structure: 确定基本结构(保持基本形式, 采用组合)

This commit is contained in:
loveuer
2024-11-01 17:47:33 +08:00
parent 9e8a47a7c6
commit 56cfd42bb9
52 changed files with 1003 additions and 176 deletions

116
pkg/cache/cache_lru.go vendored Normal file
View File

@@ -0,0 +1,116 @@
package cache
import (
"context"
"github.com/hashicorp/golang-lru/v2/expirable"
_ "github.com/hashicorp/golang-lru/v2/expirable"
"time"
)
var _ Cache = (*_lru)(nil)
type _lru struct {
client *expirable.LRU[string, *_lru_value]
}
type _lru_value struct {
duration time.Duration
last time.Time
bs []byte
}
func (l *_lru) Get(ctx context.Context, key string) ([]byte, error) {
v, ok := l.client.Get(key)
if !ok {
return nil, ErrorKeyNotFound
}
if v.duration == 0 {
return v.bs, nil
}
if time.Now().Sub(v.last) > v.duration {
l.client.Remove(key)
return nil, ErrorKeyNotFound
}
return v.bs, nil
}
func (l *_lru) GetScan(ctx context.Context, key string) Scanner {
return newScanner(l.Get(ctx, key))
}
func (l *_lru) GetEx(ctx context.Context, key string, duration time.Duration) ([]byte, error) {
v, ok := l.client.Get(key)
if !ok {
return nil, ErrorKeyNotFound
}
if v.duration == 0 {
return v.bs, nil
}
now := time.Now()
if now.Sub(v.last) > v.duration {
l.client.Remove(key)
return nil, ErrorKeyNotFound
}
l.client.Add(key, &_lru_value{
duration: duration,
last: now,
bs: v.bs,
})
return v.bs, nil
}
func (l *_lru) GetExScan(ctx context.Context, key string, duration time.Duration) Scanner {
return newScanner(l.GetEx(ctx, key, duration))
}
func (l *_lru) Set(ctx context.Context, key string, value any) error {
bs, err := handleValue(value)
if err != nil {
return err
}
l.client.Add(key, &_lru_value{
duration: 0,
last: time.Now(),
bs: bs,
})
return nil
}
func (l *_lru) SetEx(ctx context.Context, key string, value any, duration time.Duration) error {
bs, err := handleValue(value)
if err != nil {
return err
}
l.client.Add(key, &_lru_value{
duration: duration,
last: time.Now(),
bs: bs,
})
return nil
}
func (l *_lru) Del(ctx context.Context, keys ...string) error {
for _, key := range keys {
l.client.Remove(key)
}
return nil
}
func newLRUCache() (Cache, error) {
client := expirable.NewLRU[string, *_lru_value](1024*1024, nil, 0)
return &_lru{client: client}, nil
}

81
pkg/cache/cache_memory.go vendored Normal file
View File

@@ -0,0 +1,81 @@
package cache
import (
"context"
"errors"
"fmt"
"time"
"gitea.com/loveuer/gredis"
)
var _ Cache = (*_mem)(nil)
type _mem struct {
client *gredis.Gredis
}
func (m *_mem) GetScan(ctx context.Context, key string) Scanner {
return newScanner(m.Get(ctx, key))
}
func (m *_mem) GetExScan(ctx context.Context, key string, duration time.Duration) Scanner {
return newScanner(m.GetEx(ctx, key, duration))
}
func (m *_mem) Get(ctx context.Context, key string) ([]byte, error) {
v, err := m.client.Get(key)
if err != nil {
if errors.Is(err, gredis.ErrKeyNotFound) {
return nil, ErrorKeyNotFound
}
return nil, err
}
bs, ok := v.([]byte)
if !ok {
return nil, fmt.Errorf("invalid value type=%T", v)
}
return bs, nil
}
func (m *_mem) GetEx(ctx context.Context, key string, duration time.Duration) ([]byte, error) {
v, err := m.client.GetEx(key, duration)
if err != nil {
if errors.Is(err, gredis.ErrKeyNotFound) {
return nil, ErrorKeyNotFound
}
return nil, err
}
bs, ok := v.([]byte)
if !ok {
return nil, fmt.Errorf("invalid value type=%T", v)
}
return bs, nil
}
func (m *_mem) Set(ctx context.Context, key string, value any) error {
bs, err := handleValue(value)
if err != nil {
return err
}
return m.client.Set(key, bs)
}
func (m *_mem) SetEx(ctx context.Context, key string, value any, duration time.Duration) error {
bs, err := handleValue(value)
if err != nil {
return err
}
return m.client.SetEx(key, bs, duration)
}
func (m *_mem) Del(ctx context.Context, keys ...string) error {
m.client.Delete(keys...)
return nil
}

71
pkg/cache/cache_redis.go vendored Normal file
View File

@@ -0,0 +1,71 @@
package cache
import (
"context"
"errors"
"github.com/go-redis/redis/v8"
"time"
)
type _redis struct {
client *redis.Client
}
func (r *_redis) Get(ctx context.Context, key string) ([]byte, error) {
result, err := r.client.Get(ctx, key).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return nil, ErrorKeyNotFound
}
return nil, err
}
return []byte(result), nil
}
func (r *_redis) GetScan(ctx context.Context, key string) Scanner {
return newScanner(r.Get(ctx, key))
}
func (r *_redis) GetEx(ctx context.Context, key string, duration time.Duration) ([]byte, error) {
result, err := r.client.GetEx(ctx, key, duration).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return nil, ErrorKeyNotFound
}
return nil, err
}
return []byte(result), nil
}
func (r *_redis) GetExScan(ctx context.Context, key string, duration time.Duration) Scanner {
return newScanner(r.GetEx(ctx, key, duration))
}
func (r *_redis) Set(ctx context.Context, key string, value any) error {
bs, err := handleValue(value)
if err != nil {
return err
}
_, err = r.client.Set(ctx, key, bs, redis.KeepTTL).Result()
return err
}
func (r *_redis) SetEx(ctx context.Context, key string, value any, duration time.Duration) error {
bs, err := handleValue(value)
if err != nil {
return err
}
_, err = r.client.SetEX(ctx, key, bs, duration).Result()
return err
}
func (r *_redis) Del(ctx context.Context, keys ...string) error {
return r.client.Del(ctx, keys...).Err()
}

60
pkg/cache/client.go vendored Normal file
View File

@@ -0,0 +1,60 @@
package cache
import (
"context"
"encoding/json"
"time"
)
const (
// todo: config this
Prefix = "sys:uauth:"
)
type Cache interface {
Get(ctx context.Context, key string) ([]byte, error)
GetScan(ctx context.Context, key string) Scanner
GetEx(ctx context.Context, key string, duration time.Duration) ([]byte, error)
GetExScan(ctx context.Context, key string, duration time.Duration) Scanner
// Set value 会被序列化, 优先使用 MarshalBinary 方法, 没有则执行 json.Marshal
Set(ctx context.Context, key string, value any) error
// SetEx value 会被序列化, 优先使用 MarshalBinary 方法, 没有则执行 json.Marshal
SetEx(ctx context.Context, key string, value any, duration time.Duration) error
Del(ctx context.Context, keys ...string) error
}
type Scanner interface {
Scan(model any) error
}
var (
Client Cache
)
type encoded_value interface {
MarshalBinary() ([]byte, error)
}
type decoded_value interface {
UnmarshalBinary(bs []byte) error
}
func handleValue(value any) ([]byte, error) {
var (
bs []byte
err error
)
switch value.(type) {
case []byte:
return value.([]byte), nil
}
if imp, ok := value.(encoded_value); ok {
bs, err = imp.MarshalBinary()
} else {
bs, err = json.Marshal(value)
}
return bs, err
}

7
pkg/cache/error.go vendored Normal file
View File

@@ -0,0 +1,7 @@
package cache
import "errors"
var (
ErrorKeyNotFound = errors.New("key not found")
)

74
pkg/cache/init.go vendored Normal file
View File

@@ -0,0 +1,74 @@
package cache
import (
"fmt"
"gitea.com/loveuer/gredis"
"github.com/go-redis/redis/v8"
"net/url"
"strings"
"uauth/tool"
)
func Init(uri string) (err error) {
Client, err = New(uri)
return err
}
func New(uri string) (Cache, error) {
var (
client Cache
err error
)
strs := strings.Split(uri, "::")
switch strs[0] {
case "memory":
gc := gredis.NewGredis(1024 * 1024)
client = &_mem{client: gc}
case "lru":
if client, err = newLRUCache(); err != nil {
return nil, err
}
case "redis":
var (
ins *url.URL
err error
)
if len(strs) != 2 {
return nil, fmt.Errorf("cache.Init: invalid cache uri: %s", uri)
}
uri := strs[1]
if !strings.Contains(uri, "://") {
uri = fmt.Sprintf("redis://%s", uri)
}
if ins, err = url.Parse(uri); err != nil {
return nil, fmt.Errorf("cache.Init: url parse cache uri: %s, err: %s", uri, err.Error())
}
addr := ins.Host
username := ins.User.Username()
password, _ := ins.User.Password()
var rc *redis.Client
rc = redis.NewClient(&redis.Options{
Addr: addr,
Username: username,
Password: password,
})
if err = rc.Ping(tool.Timeout(5)).Err(); err != nil {
return nil, fmt.Errorf("cache.Init: redis ping err: %s", err.Error())
}
client = &_redis{client: rc}
default:
return nil, fmt.Errorf("cache type %s not support", strs[0])
}
return client, nil
}

20
pkg/cache/scan.go vendored Normal file
View File

@@ -0,0 +1,20 @@
package cache
import "encoding/json"
type scanner struct {
err error
bs []byte
}
func (s *scanner) Scan(model any) error {
if s.err != nil {
return s.err
}
return json.Unmarshal(s.bs, model)
}
func newScanner(bs []byte, err error) *scanner {
return &scanner{bs: bs, err: err}
}

118
pkg/middleware/auth/auth.go Normal file
View File

@@ -0,0 +1,118 @@
package auth
import (
"errors"
"github.com/loveuer/nf"
"github.com/loveuer/nf/nft/log"
"github.com/loveuer/nf/nft/resp"
"net/http"
"time"
"uauth/model"
"uauth/pkg/cache"
"uauth/tool"
)
type Config struct {
IgnoreFn func(c *nf.Ctx) bool
TokenFn func(c *nf.Ctx) (string, bool)
GetUserFn func(c *nf.Ctx, token string) (*model.User, error)
NextOnError bool
}
var (
defaultIgnoreFn = func(c *nf.Ctx) bool {
return false
}
defaultTokenFn = func(c *nf.Ctx) (string, bool) {
var token string
if token = c.Request.Header.Get("Authorization"); token != "" {
return token, true
}
if token = c.Query("access_token"); token != "" {
return token, true
}
if token = c.Cookies("access_token"); token != "" {
return token, true
}
return "", false
}
defaultGetUserFn = func(c *nf.Ctx, token string) (*model.User, error) {
var (
err error
op = new(model.User)
key = cache.Prefix + "token:" + token
)
if err = cache.Client.GetExScan(tool.Timeout(3), key, 24*time.Hour).Scan(op); err != nil {
if errors.Is(err, cache.ErrorKeyNotFound) {
return nil, err
}
log.Error("[M] cache client get user by token key = %s, err = %s", key, err.Error())
return nil, errors.New("Internal Server Error")
}
return op, nil
}
)
func New(cfgs ...*Config) nf.HandlerFunc {
var cfg = &Config{}
if len(cfgs) > 0 && cfgs[0] != nil {
cfg = cfgs[0]
}
if cfg.IgnoreFn == nil {
cfg.IgnoreFn = defaultIgnoreFn
}
if cfg.TokenFn == nil {
cfg.TokenFn = defaultTokenFn
}
if cfg.GetUserFn == nil {
cfg.GetUserFn = defaultGetUserFn
}
return func(c *nf.Ctx) error {
if cfg.IgnoreFn(c) {
return c.Next()
}
token, ok := cfg.TokenFn(c)
if !ok {
if cfg.NextOnError {
return c.Next()
}
return resp.Resp401(c, nil, "请登录")
}
op, err := cfg.GetUserFn(c, token)
if err != nil {
if cfg.NextOnError {
return c.Next()
}
if errors.Is(err, cache.ErrorKeyNotFound) {
return c.Status(http.StatusUnauthorized).JSON(map[string]any{
"status": 500,
"msg": "用户认证信息不存在或已过期, 请重新登录",
})
}
return c.Status(http.StatusInternalServerError).SendString("Internal Server Error")
}
c.Locals("user", op)
return c.Next()
}
}

View File

@@ -0,0 +1,47 @@
package logger
import (
"fmt"
"github.com/loveuer/nf"
"github.com/loveuer/nf/nft/log"
"github.com/loveuer/nf/nft/resp"
"strconv"
"time"
"uauth/tool"
)
func New() nf.HandlerFunc {
return func(c *nf.Ctx) error {
var (
now = time.Now()
logFn func(msg string, data ...any)
ip = c.IP()
)
traceId := c.Context().Value(nf.TraceKey)
c.Locals(nf.TraceKey, traceId)
err := c.Next()
c.Writer.Header().Set(nf.TraceKey, fmt.Sprint(traceId))
status, _ := strconv.Atoi(c.Writer.Header().Get(resp.RealStatusHeader))
duration := time.Since(now)
msg := fmt.Sprintf("%s | %15s | %d[%3d] | %s | %6s | %s", traceId, ip, c.StatusCode, status, tool.HumanDuration(duration.Nanoseconds()), c.Method(), c.Path())
switch {
case status >= 500:
logFn = log.Error
case status >= 400:
logFn = log.Warn
default:
logFn = log.Info
}
logFn(msg)
return err
}
}

110
pkg/rbac/create.go Normal file
View File

@@ -0,0 +1,110 @@
package rbac
import (
"context"
"fmt"
"github.com/samber/lo"
"strings"
"uauth/model"
)
func (u *RBAC) newScope(ctx context.Context, code, label, parent string) (*model.Scope, error) {
s := &model.Scope{Code: code, Label: label, Parent: parent}
if err := u.store.Session(ctx).Create(s).Error; err != nil {
return s, err
}
return s, nil
}
func (u *RBAC) GetScopeGroup(ctx context.Context, name string) (*model.Scope, error) {
scope := new(model.Scope)
err := u.store.Session(ctx).Where("name = ?", name).Take(scope).Error
return scope, err
}
func (u *RBAC) newRole(ctx context.Context, code, label, parent string, privileges ...*model.Privilege) (*model.Role, error) {
ps := lo.FilterMap(
privileges,
func(p *model.Privilege, _ int) (string, bool) {
if p == nil {
return "", false
}
return p.Code, p.Code != ""
},
)
r := &model.Role{
Code: code,
Label: label,
Parent: parent,
PrivilegeCodes: ps,
}
if err := u.store.Session(ctx).Create(r).Error; err != nil {
return r, err
}
return r, nil
}
func (u *RBAC) GetRole(ctx context.Context, name string) (*model.Role, error) {
var r model.Role
if err := u.store.Session(ctx).Take(&r, "name = ?", name).Error; err != nil {
return nil, err
}
return &r, nil
}
func (u *RBAC) newPrivilege(ctx context.Context, code, label string, parent string, scope string) (*model.Privilege, error) {
p := &model.Privilege{Code: code, Label: label, Parent: parent, Scope: scope}
codes := strings.SplitN(code, ":", 4)
if len(codes) != 4 {
return nil, fmt.Errorf("invalid code format")
}
wailcard := false
for _, item := range codes {
if item == "*" {
wailcard = true
}
if wailcard && item != "*" {
return nil, fmt.Errorf("invalid code format")
}
if len(item) > 8 {
return nil, fmt.Errorf("invalid code format: code snippet too long")
}
}
if codes[0] != "*" {
if _, err := u.GetScopeGroup(ctx, codes[0]); err != nil {
return nil, err
}
}
if err := u.store.Session(ctx).Create(p).Error; err != nil {
return p, err
}
return p, nil
}
func (u *RBAC) newUser(ctx context.Context, target *model.User) (*model.User, error) {
result := u.store.Session(ctx).
Create(target)
if result.Error != nil {
return nil, result.Error
}
if result.RowsAffected != 1 {
return nil, fmt.Errorf("invalid rows affected")
}
return target, nil
}

83
pkg/rbac/rbac.go Normal file
View File

@@ -0,0 +1,83 @@
package rbac
import (
"fmt"
"strings"
"uauth/model"
"uauth/pkg/cache"
"uauth/pkg/store"
"uauth/tool"
)
type RBAC struct {
cache cache.Cache
store store.Store
}
var (
Default *RBAC
)
func New(store store.Store, cache cache.Cache) (*RBAC, error) {
var (
err error
u = &RBAC{
store: store,
cache: cache,
}
rootPrivilege *model.Privilege
rootRole *model.Role
rootScope *model.Scope
rootUser *model.User
)
if err = u.store.Session(tool.Timeout()).AutoMigrate(
&model.Scope{},
&model.Privilege{},
&model.Role{},
&model.User{},
); err != nil {
return nil, fmt.Errorf("urbac migrate err: %w", err)
}
if rootPrivilege, err = u.newPrivilege(tool.Timeout(), "*:*:*:*", "admin", "", "*"); err != nil {
if !strings.Contains(strings.ToLower(err.Error()), "unique") {
return nil, err
}
}
if rootRole, err = u.newRole(tool.Timeout(), "admin", "管理员", "", rootPrivilege); err != nil {
if !strings.Contains(strings.ToLower(err.Error()), "unique") {
return nil, err
}
}
if rootScope, err = u.newScope(tool.Timeout(), "*", "全部", ""); err != nil {
if !strings.Contains(strings.ToLower(err.Error()), "unique") {
return nil, err
}
}
rootUser = &model.User{
Username: "admin",
Password: tool.NewPassword("123456"),
Status: model.StatusActive,
Nickname: "管理员",
RoleNames: []string{rootRole.Code},
}
if _, err = u.newUser(tool.Timeout(3), rootUser); err != nil {
if !strings.Contains(strings.ToLower(err.Error()), "unique") {
return nil, err
}
}
_ = rootScope
return u, nil
}
func Init(store store.Store, cache cache.Cache) (err error) {
Default, err = New(store, cache)
return err
}

9
pkg/sqlType/error.go Normal file
View File

@@ -0,0 +1,9 @@
package sqlType
import "errors"
var (
ErrConvertScanVal = errors.New("convert scan val to str err")
ErrInvalidScanVal = errors.New("scan val invalid")
ErrConvertVal = errors.New("convert err")
)

76
pkg/sqlType/jsonb.go Normal file
View File

@@ -0,0 +1,76 @@
package sqlType
import (
"database/sql/driver"
"encoding/json"
"github.com/jackc/pgtype"
)
type JSONB struct {
Val pgtype.JSONB
Valid bool
}
func NewJSONB(v interface{}) JSONB {
j := new(JSONB)
j.Val = pgtype.JSONB{}
if err := j.Val.Set(v); err == nil {
j.Valid = true
return *j
}
return *j
}
func (j *JSONB) Set(value interface{}) error {
if err := j.Val.Set(value); err != nil {
j.Valid = false
return err
}
j.Valid = true
return nil
}
func (j *JSONB) Bind(model interface{}) error {
return j.Val.AssignTo(model)
}
func (j *JSONB) Scan(value interface{}) error {
j.Val = pgtype.JSONB{}
if value == nil {
j.Valid = false
return nil
}
j.Valid = true
return j.Val.Scan(value)
}
func (j JSONB) Value() (driver.Value, error) {
if j.Valid {
return j.Val.Value()
}
return nil, nil
}
func (j JSONB) MarshalJSON() ([]byte, error) {
if j.Valid {
return j.Val.MarshalJSON()
}
return json.Marshal(nil)
}
func (j *JSONB) UnmarshalJSON(b []byte) error {
if string(b) == "null" {
j.Valid = false
return j.Val.UnmarshalJSON(b)
}
return j.Val.UnmarshalJSON(b)
}

71
pkg/sqlType/num_slice.go Normal file
View File

@@ -0,0 +1,71 @@
package sqlType
import (
"database/sql/driver"
"fmt"
"strconv"
"strings"
"github.com/spf13/cast"
)
type NumSlice[T ~int | ~int64 | ~uint | ~uint64] []T
func (n *NumSlice[T]) Scan(val interface{}) error {
str, ok := val.(string)
if !ok {
return ErrConvertScanVal
}
length := len(str)
if length <= 0 {
*n = make(NumSlice[T], 0)
return nil
}
if str[0] != '{' || str[length-1] != '}' {
return ErrInvalidScanVal
}
str = str[1 : length-1]
if len(str) == 0 {
*n = make(NumSlice[T], 0)
return nil
}
numStrs := strings.Split(str, ",")
nums := make([]T, len(numStrs))
for idx := range numStrs {
num, err := cast.ToInt64E(strings.TrimSpace(numStrs[idx]))
if err != nil {
return fmt.Errorf("%w: can't convert to %T", ErrConvertVal, T(0))
}
nums[idx] = T(num)
}
*n = nums
return nil
}
func (n NumSlice[T]) Value() (driver.Value, error) {
if n == nil {
return "{}", nil
}
if len(n) == 0 {
return "{}", nil
}
ss := make([]string, 0, len(n))
for idx := range n {
ss = append(ss, strconv.Itoa(int(n[idx])))
}
s := strings.Join(ss, ", ")
return fmt.Sprintf("{%s}", s), nil
}

107
pkg/sqlType/string_slice.go Normal file
View File

@@ -0,0 +1,107 @@
package sqlType
import (
"bytes"
"database/sql/driver"
"encoding/json"
)
type StrSlice []string
func (s *StrSlice) Scan(val interface{}) error {
str, ok := val.(string)
if !ok {
return ErrConvertScanVal
}
if len(str) < 2 {
return nil
}
bs := make([]byte, 0, 128)
bss := make([]byte, 0, 2*len(str))
quoteCount := 0
for idx := 1; idx < len(str)-1; idx++ {
quote := str[idx]
switch quote {
case 44:
if quote == 44 && str[idx-1] != 92 && quoteCount == 0 {
if len(bs) > 0 {
if !(bs[0] == 34 && bs[len(bs)-1] == 34) {
bs = append([]byte{34}, bs...)
bs = append(bs, 34)
}
bss = append(bss, bs...)
bss = append(bss, 44)
}
bs = bs[:0]
} else {
bs = append(bs, quote)
}
case 34:
if str[idx-1] != 92 {
quoteCount = (quoteCount + 1) % 2
}
bs = append(bs, quote)
default:
bs = append(bs, quote)
}
}
if len(bs) > 0 {
if !(bs[0] == 34 && bs[len(bs)-1] == 34) {
bs = append([]byte{34}, bs...)
bs = append(bs, 34)
}
bss = append(bss, bs...)
} else {
if len(bss) > 2 {
bss = bss[:len(bss)-2]
}
}
bss = append([]byte{'['}, append(bss, ']')...)
if err := json.Unmarshal(bss, s); err != nil {
return err
}
return nil
}
func (s StrSlice) Value() (driver.Value, error) {
if s == nil {
return "{}", nil
}
if len(s) == 0 {
return "{}", nil
}
buf := &bytes.Buffer{}
encoder := json.NewEncoder(buf)
encoder.SetEscapeHTML(false)
if err := encoder.Encode(s); err != nil {
return "{}", err
}
bs := buf.Bytes()
bs[0] = '{'
if bs[len(bs)-1] == 10 {
bs = bs[:len(bs)-1]
}
bs[len(bs)-1] = '}'
return string(bs), nil
}

46
pkg/store/client.go Normal file
View File

@@ -0,0 +1,46 @@
package store
import (
"context"
"uauth/tool"
"gorm.io/gorm"
)
type Store interface {
Session(ctx context.Context) *gorm.DB
}
var (
Default *Client
)
type Client struct {
ctx context.Context
cli *gorm.DB
ttype string
debug bool
}
func (c *Client) Type() string {
return c.ttype
}
func (c *Client) Session(ctx context.Context) *gorm.DB {
if ctx == nil {
ctx = tool.Timeout(30)
}
session := c.cli.Session(&gorm.Session{Context: ctx})
if c.debug {
session = session.Debug()
}
return session
}
func (c *Client) Close() {
d, _ := c.cli.DB()
d.Close()
}

59
pkg/store/init.go Normal file
View File

@@ -0,0 +1,59 @@
package store
import (
"fmt"
"strings"
"github.com/glebarez/sqlite"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)
type Config struct {
Debug bool
}
func New(uri string, configs ...Config) (*Client, error) {
strs := strings.Split(uri, "::")
if len(strs) != 2 {
return nil, fmt.Errorf("db.Init: opt db uri invalid: %s", uri)
}
c := &Client{ttype: strs[0]}
if len(configs) > 0 && configs[0].Debug {
c.debug = true
}
var (
err error
dsn = strs[1]
)
switch strs[0] {
case "sqlite":
c.cli, err = gorm.Open(sqlite.Open(dsn))
case "mysql":
c.cli, err = gorm.Open(mysql.Open(dsn))
case "postgres":
c.cli, err = gorm.Open(postgres.Open(dsn))
default:
return nil, fmt.Errorf("db type only support: [sqlite, mysql, postgres], unsupported db type: %s", strs[0])
}
if err != nil {
return nil, fmt.Errorf("db.Init: open %s with dsn:%s, err: %w", strs[0], dsn, err)
}
return c, nil
}
func Init(uri string, configs ...Config) (err error) {
if Default, err = New(uri, configs...); err != nil {
return err
}
return nil
}