wip: 登录和认证
This commit is contained in:
98
pkg/database/cache/cache.go
vendored
Normal file
98
pkg/database/cache/cache.go
vendored
Normal file
@ -0,0 +1,98 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type encoded_value interface {
|
||||
MarshalBinary() ([]byte, error)
|
||||
}
|
||||
|
||||
type decoded_value interface {
|
||||
UnmarshalBinary(bs []byte) error
|
||||
}
|
||||
|
||||
type Scanner interface {
|
||||
Scan(model any) error
|
||||
}
|
||||
|
||||
type scan struct {
|
||||
err error
|
||||
bs []byte
|
||||
}
|
||||
|
||||
func newScan(bs []byte, err error) *scan {
|
||||
return &scan{bs: bs, err: err}
|
||||
}
|
||||
|
||||
func (s *scan) Scan(model any) error {
|
||||
if s.err != nil {
|
||||
return s.err
|
||||
}
|
||||
|
||||
return unmarshaler(s.bs, model)
|
||||
}
|
||||
|
||||
type Cache interface {
|
||||
Get(ctx context.Context, key string) ([]byte, error)
|
||||
Gets(ctx context.Context, keys ...string) ([][]byte, error)
|
||||
GetScan(ctx context.Context, key string) Scanner
|
||||
GetEx(ctx context.Context, key string, duration time.Duration) ([]byte, error)
|
||||
GetExScan(ctx context.Context, key string, duration time.Duration) Scanner
|
||||
// Set value 会被序列化, 优先使用 MarshalBinary 方法, 没有则执行 json.Marshal
|
||||
Set(ctx context.Context, key string, value any) error
|
||||
Sets(ctx context.Context, vm map[string]any) error
|
||||
// SetEx value 会被序列化, 优先使用 MarshalBinary 方法, 没有则执行 json.Marshal
|
||||
SetEx(ctx context.Context, key string, value any, duration time.Duration) error
|
||||
Del(ctx context.Context, keys ...string) error
|
||||
GetDel(ctx context.Context, key string) ([]byte, error)
|
||||
GetDelScan(ctx context.Context, key string) Scanner
|
||||
Close()
|
||||
Client() any
|
||||
}
|
||||
|
||||
var (
|
||||
lock = &sync.Mutex{}
|
||||
marshaler func(data any) ([]byte, error) = json.Marshal
|
||||
unmarshaler func(data []byte, model any) error = json.Unmarshal
|
||||
ErrorKeyNotFound = errors.New("key not found")
|
||||
ErrorStoreFailed = errors.New("store failed")
|
||||
Default Cache
|
||||
)
|
||||
|
||||
func handleValue(value any) ([]byte, error) {
|
||||
var (
|
||||
bs []byte
|
||||
err error
|
||||
)
|
||||
|
||||
switch val := value.(type) {
|
||||
case []byte:
|
||||
return val, nil
|
||||
}
|
||||
|
||||
if imp, ok := value.(encoded_value); ok {
|
||||
bs, err = imp.MarshalBinary()
|
||||
} else {
|
||||
bs, err = marshaler(value)
|
||||
}
|
||||
|
||||
return bs, err
|
||||
}
|
||||
|
||||
func SetMarshaler(fn func(data any) ([]byte, error)) {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
marshaler = fn
|
||||
}
|
||||
|
||||
func SetUnmarshaler(fn func(data []byte, model any) error) {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
unmarshaler = fn
|
||||
}
|
155
pkg/database/cache/memory.go
vendored
Normal file
155
pkg/database/cache/memory.go
vendored
Normal file
@ -0,0 +1,155 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/dgraph-io/ristretto/v2"
|
||||
)
|
||||
|
||||
var _ Cache = (*_mem)(nil)
|
||||
|
||||
type _mem struct {
|
||||
ctx context.Context
|
||||
cache *ristretto.Cache[string, []byte]
|
||||
}
|
||||
|
||||
func newMemory(ctx context.Context, ins *ristretto.Cache[string, []byte]) Cache {
|
||||
return &_mem{
|
||||
ctx: ctx,
|
||||
cache: ins,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *_mem) Client() any {
|
||||
return m.cache
|
||||
}
|
||||
|
||||
func (c *_mem) Close() {
|
||||
c.cache.Close()
|
||||
}
|
||||
|
||||
func (c *_mem) Del(ctx context.Context, keys ...string) error {
|
||||
for _, key := range keys {
|
||||
c.cache.Del(key)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *_mem) Get(ctx context.Context, key string) ([]byte, error) {
|
||||
val, ok := c.cache.Get(key)
|
||||
if !ok {
|
||||
return val, ErrorKeyNotFound
|
||||
}
|
||||
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func (c *_mem) GetDel(ctx context.Context, key string) ([]byte, error) {
|
||||
val, err := c.Get(ctx, key)
|
||||
if err != nil {
|
||||
return val, err
|
||||
}
|
||||
|
||||
c.cache.Del(key)
|
||||
|
||||
return val, err
|
||||
}
|
||||
|
||||
func (c *_mem) GetDelScan(ctx context.Context, key string) Scanner {
|
||||
val, err := c.GetDel(ctx, key)
|
||||
return newScan(val, err)
|
||||
}
|
||||
|
||||
func (c *_mem) GetEx(ctx context.Context, key string, duration time.Duration) ([]byte, error) {
|
||||
val, err := c.Get(ctx, key)
|
||||
if err != nil {
|
||||
return val, err
|
||||
}
|
||||
|
||||
c.cache.SetWithTTL(key, val, 1, duration)
|
||||
|
||||
return val, err
|
||||
}
|
||||
|
||||
func (m *_mem) GetExScan(ctx context.Context, key string, duration time.Duration) Scanner {
|
||||
val, err := m.GetEx(ctx, key, duration)
|
||||
return newScan(val, err)
|
||||
}
|
||||
|
||||
func (m *_mem) GetScan(ctx context.Context, key string) Scanner {
|
||||
val, err := m.Get(ctx, key)
|
||||
return newScan(val, err)
|
||||
}
|
||||
|
||||
func (m *_mem) Gets(ctx context.Context, keys ...string) ([][]byte, error) {
|
||||
vals := make([][]byte, 0, len(keys))
|
||||
|
||||
for _, key := range keys {
|
||||
val, err := m.Get(ctx, key)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrorKeyNotFound) {
|
||||
continue
|
||||
}
|
||||
|
||||
return vals, err
|
||||
}
|
||||
|
||||
vals = append(vals, val)
|
||||
}
|
||||
|
||||
if len(vals) != len(keys) {
|
||||
return vals, ErrorKeyNotFound
|
||||
}
|
||||
|
||||
return vals, nil
|
||||
}
|
||||
|
||||
func (m *_mem) Set(ctx context.Context, key string, value any) error {
|
||||
val, err := handleValue(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if ok := m.cache.Set(key, val, 1); !ok {
|
||||
return ErrorStoreFailed
|
||||
}
|
||||
|
||||
m.cache.Wait()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *_mem) SetEx(ctx context.Context, key string, value any, duration time.Duration) error {
|
||||
val, err := handleValue(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if ok := m.cache.SetWithTTL(key, val, 1, duration); !ok {
|
||||
return ErrorStoreFailed
|
||||
}
|
||||
|
||||
m.cache.Wait()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *_mem) Sets(ctx context.Context, vm map[string]any) error {
|
||||
for key, value := range vm {
|
||||
val, err := handleValue(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if ok := m.cache.Set(key, val, 1); !ok {
|
||||
return ErrorStoreFailed
|
||||
}
|
||||
}
|
||||
|
||||
m.cache.Wait()
|
||||
|
||||
return nil
|
||||
}
|
84
pkg/database/cache/new.go
vendored
Normal file
84
pkg/database/cache/new.go
vendored
Normal file
@ -0,0 +1,84 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
"gitea.loveuer.com/yizhisec/packages/tool"
|
||||
"github.com/dgraph-io/ristretto/v2"
|
||||
"github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
func New(opts ...Option) (Cache, error) {
|
||||
var (
|
||||
err error
|
||||
cfg = &option{
|
||||
ctx: context.Background(),
|
||||
}
|
||||
)
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(cfg)
|
||||
}
|
||||
|
||||
if cfg.redis != nil {
|
||||
var (
|
||||
ins *url.URL
|
||||
client *redis.Client
|
||||
)
|
||||
|
||||
if ins, err = url.Parse(*cfg.redis); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
username := ins.User.Username()
|
||||
password, _ := ins.User.Password()
|
||||
|
||||
client = redis.NewClient(&redis.Options{
|
||||
Addr: ins.Host,
|
||||
Username: username,
|
||||
Password: password,
|
||||
})
|
||||
|
||||
if err = client.Ping(tool.CtxTimeout(cfg.ctx, 5)).Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newRedis(cfg.ctx, client), nil
|
||||
}
|
||||
|
||||
if cfg.memory {
|
||||
var (
|
||||
ins *ristretto.Cache[string, []byte]
|
||||
)
|
||||
|
||||
if ins, err = ristretto.NewCache(&ristretto.Config[string, []byte]{
|
||||
NumCounters: 1e7, // number of keys to track frequency of (10M).
|
||||
MaxCost: 1 << 30, // maximum cost of cache (1GB).
|
||||
BufferItems: 64, // number of keys per Get buffer.
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newMemory(cfg.ctx, ins), nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("invalid cache option")
|
||||
}
|
||||
|
||||
func Init(opts ...Option) (err error) {
|
||||
opt := &option{}
|
||||
|
||||
for _, optFn := range opts {
|
||||
optFn(opt)
|
||||
}
|
||||
|
||||
if opt.memory {
|
||||
Default, err = New(opts...)
|
||||
return err
|
||||
}
|
||||
|
||||
Default, err = New(opts...)
|
||||
return err
|
||||
}
|
108
pkg/database/cache/new_test.go
vendored
Normal file
108
pkg/database/cache/new_test.go
vendored
Normal file
@ -0,0 +1,108 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
/* if err := Init(WithRedis("127.0.0.1", 6379, "", "MyPassw0rd")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
type User struct {
|
||||
Name string `json:"name"`
|
||||
Age int `json:"age"`
|
||||
}
|
||||
|
||||
if err := Default.Set(t.Context(), "zyp:haha", &User{
|
||||
Name: "cache",
|
||||
Age: 18,
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
s := Default.GetDelScan(t.Context(), "zyp:haha")
|
||||
u := new(User)
|
||||
|
||||
if err := s.Scan(u); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Logf("%#v", *u)
|
||||
|
||||
if err := Default.SetEx(t.Context(), "zyp:haha", &User{
|
||||
Name: "redis",
|
||||
Age: 2,
|
||||
}, time.Hour); err != nil {
|
||||
t.Fatal(err)
|
||||
}*/
|
||||
}
|
||||
|
||||
func TestNoAuth(t *testing.T) {
|
||||
//if err := Init(WithRedis("10.125.1.28", 6379, "", "")); err != nil {
|
||||
// t.Fatal(err)
|
||||
//}
|
||||
//
|
||||
//type User struct {
|
||||
// Name string `json:"name"`
|
||||
// Age int `json:"age"`
|
||||
//}
|
||||
//
|
||||
//if err := Default.Set(t.Context(), "zyp:haha", &User{
|
||||
// Name: "cache",
|
||||
// Age: 18,
|
||||
//}); err != nil {
|
||||
// t.Fatal(err)
|
||||
//}
|
||||
//
|
||||
//s := Default.GetDelScan(t.Context(), "zyp:haha")
|
||||
//u := new(User)
|
||||
//
|
||||
//if err := s.Scan(u); err != nil {
|
||||
// t.Fatal(err)
|
||||
//}
|
||||
//
|
||||
//t.Logf("%#v", *u)
|
||||
//
|
||||
//if err := Default.SetEx(t.Context(), "zyp:haha", &User{
|
||||
// Name: "redis",
|
||||
// Age: 2,
|
||||
//}, time.Hour); err != nil {
|
||||
// t.Fatal(err)
|
||||
//}
|
||||
}
|
||||
|
||||
func TestMemoryDefault(t *testing.T) {
|
||||
if err := Init(WithMemory()); err != nil {
|
||||
t.Fatal("init err:", err)
|
||||
}
|
||||
|
||||
if err := Default.Set(t.Context(), "123", "123"); err != nil {
|
||||
t.Fatal("set err:", err)
|
||||
}
|
||||
|
||||
val, err := Default.Get(t.Context(), "123")
|
||||
if err != nil {
|
||||
t.Fatal("get err:", err)
|
||||
}
|
||||
|
||||
t.Logf("%s", val)
|
||||
}
|
||||
|
||||
func TestMemoryNew(t *testing.T) {
|
||||
client, err := New(WithMemory())
|
||||
if err != nil {
|
||||
t.Fatal("init err:", err)
|
||||
}
|
||||
|
||||
if err := client.Set(t.Context(), "123", "123"); err != nil {
|
||||
t.Fatal("set err:", err)
|
||||
}
|
||||
|
||||
val, err := client.Get(t.Context(), "123")
|
||||
if err != nil {
|
||||
t.Fatal("get err:", err)
|
||||
}
|
||||
|
||||
t.Logf("%s", val)
|
||||
}
|
55
pkg/database/cache/option.go
vendored
Normal file
55
pkg/database/cache/option.go
vendored
Normal file
@ -0,0 +1,55 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
type option struct {
|
||||
ctx context.Context
|
||||
redis *string
|
||||
memory bool
|
||||
}
|
||||
|
||||
type Option func(*option)
|
||||
|
||||
func WithCtx(ctx context.Context) Option {
|
||||
return func(c *option) {
|
||||
if ctx != nil {
|
||||
c.ctx = ctx
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func WithRedis(host string, port int, username, password string) Option {
|
||||
return func(c *option) {
|
||||
uri := fmt.Sprintf("redis://%s:%d", host, port)
|
||||
if username != "" || password != "" {
|
||||
uri = fmt.Sprintf("redis://%s:%s@%s:%d", username, password, host, port)
|
||||
}
|
||||
|
||||
c.redis = &uri
|
||||
}
|
||||
}
|
||||
|
||||
func WithRedisURI(uri string) Option {
|
||||
return func(c *option) {
|
||||
ins, err := url.Parse(uri)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if ins.Scheme != "redis" {
|
||||
return
|
||||
}
|
||||
|
||||
c.redis = &uri
|
||||
}
|
||||
}
|
||||
|
||||
func WithMemory() Option {
|
||||
return func(c *option) {
|
||||
c.memory = true
|
||||
}
|
||||
}
|
153
pkg/database/cache/redis.go
vendored
Normal file
153
pkg/database/cache/redis.go
vendored
Normal file
@ -0,0 +1,153 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gitea.loveuer.com/yizhisec/packages/tool"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
var _ Cache = (*_redis)(nil)
|
||||
|
||||
type _redis struct {
|
||||
sync.Mutex
|
||||
ctx context.Context
|
||||
client *redis.Client
|
||||
}
|
||||
|
||||
func (r *_redis) Client() any {
|
||||
return r.client
|
||||
}
|
||||
|
||||
func newRedis(ctx context.Context, client *redis.Client) Cache {
|
||||
r := &_redis{ctx: ctx, client: client}
|
||||
|
||||
go func() {
|
||||
<-r.ctx.Done()
|
||||
if client != nil {
|
||||
r.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *_redis) GetDel(ctx context.Context, key string) ([]byte, error) {
|
||||
s, err := r.client.GetDel(ctx, key).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return nil, ErrorKeyNotFound
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tool.StringToBytes(s), nil
|
||||
}
|
||||
|
||||
func (r *_redis) GetDelScan(ctx context.Context, key string) Scanner {
|
||||
bs, err := r.GetDel(ctx, key)
|
||||
return newScan(bs, err)
|
||||
}
|
||||
|
||||
func (r *_redis) Get(ctx context.Context, key string) ([]byte, error) {
|
||||
result, err := r.client.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return nil, ErrorKeyNotFound
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tool.StringToBytes(result), nil
|
||||
}
|
||||
|
||||
func (r *_redis) Gets(ctx context.Context, keys ...string) ([][]byte, error) {
|
||||
result, err := r.client.MGet(ctx, keys...).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return nil, ErrorKeyNotFound
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tool.Map(
|
||||
result,
|
||||
func(item any, index int) []byte {
|
||||
return tool.StringToBytes(cast.ToString(item))
|
||||
},
|
||||
), nil
|
||||
}
|
||||
|
||||
func (r *_redis) GetScan(ctx context.Context, key string) Scanner {
|
||||
return newScan(r.Get(ctx, key))
|
||||
}
|
||||
|
||||
func (r *_redis) GetEx(ctx context.Context, key string, duration time.Duration) ([]byte, error) {
|
||||
result, err := r.client.GetEx(ctx, key, duration).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return nil, ErrorKeyNotFound
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tool.StringToBytes(result), nil
|
||||
}
|
||||
|
||||
func (r *_redis) GetExScan(ctx context.Context, key string, duration time.Duration) Scanner {
|
||||
return newScan(r.GetEx(ctx, key, duration))
|
||||
}
|
||||
|
||||
func (r *_redis) Set(ctx context.Context, key string, value any) error {
|
||||
bs, err := handleValue(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = r.client.Set(ctx, key, bs, redis.KeepTTL).Result()
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *_redis) Sets(ctx context.Context, values map[string]any) error {
|
||||
vm := make(map[string]any)
|
||||
for k, v := range values {
|
||||
bs, err := handleValue(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
vm[k] = bs
|
||||
}
|
||||
|
||||
return r.client.MSet(ctx, vm).Err()
|
||||
}
|
||||
|
||||
func (r *_redis) SetEx(ctx context.Context, key string, value any, duration time.Duration) error {
|
||||
bs, err := handleValue(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = r.client.SetEX(ctx, key, bs, duration).Result()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *_redis) Del(ctx context.Context, keys ...string) error {
|
||||
return r.client.Del(ctx, keys...).Err()
|
||||
}
|
||||
|
||||
func (r *_redis) Close() {
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
_ = r.client.Close()
|
||||
r.client = nil
|
||||
}
|
49
pkg/database/db/db.go
Normal file
49
pkg/database/db/db.go
Normal file
@ -0,0 +1,49 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Debug bool
|
||||
DryRun bool
|
||||
}
|
||||
|
||||
type DB interface {
|
||||
Session(ctx context.Context, configs ...Config) *gorm.DB
|
||||
}
|
||||
|
||||
type db struct {
|
||||
tx *gorm.DB
|
||||
}
|
||||
|
||||
var (
|
||||
Default DB
|
||||
)
|
||||
|
||||
func (db *db) Session(ctx context.Context, configs ...Config) *gorm.DB {
|
||||
var (
|
||||
sc = &gorm.Session{Context: ctx}
|
||||
session *gorm.DB
|
||||
)
|
||||
|
||||
if len(configs) == 0 {
|
||||
session = db.tx.Session(sc)
|
||||
return session
|
||||
}
|
||||
|
||||
cfg := configs[0]
|
||||
|
||||
if cfg.DryRun {
|
||||
sc.DryRun = true
|
||||
}
|
||||
|
||||
session = db.tx.Session(sc)
|
||||
|
||||
if cfg.Debug {
|
||||
session = session.Debug()
|
||||
}
|
||||
|
||||
return session
|
||||
}
|
48
pkg/database/db/new.go
Normal file
48
pkg/database/db/new.go
Normal file
@ -0,0 +1,48 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"github.com/glebarez/sqlite"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var defaultSqlite = "data.db"
|
||||
|
||||
func New(opts ...OptionFn) (DB, error) {
|
||||
var (
|
||||
err error
|
||||
conf = &config{
|
||||
sqlite: &defaultSqlite,
|
||||
}
|
||||
tx *gorm.DB
|
||||
)
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(conf)
|
||||
}
|
||||
|
||||
if conf.mysql != nil {
|
||||
tx, err = gorm.Open(mysql.Open(*conf.mysql))
|
||||
goto CHECK
|
||||
}
|
||||
|
||||
if conf.pg != nil {
|
||||
tx, err = gorm.Open(postgres.Open(*conf.pg))
|
||||
goto CHECK
|
||||
}
|
||||
|
||||
tx, err = gorm.Open(sqlite.Open(*conf.sqlite))
|
||||
|
||||
CHECK:
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &db{tx: tx}, nil
|
||||
}
|
||||
|
||||
func Init(opts ...OptionFn) (err error) {
|
||||
Default, err = New(opts...)
|
||||
return err
|
||||
}
|
25
pkg/database/db/new_test.go
Normal file
25
pkg/database/db/new_test.go
Normal file
@ -0,0 +1,25 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
//mdb, err := New(WithMysql("127.0.0.1", 3306, "root", "MyPassw0rd", "mydb"))
|
||||
//if err != nil {
|
||||
// t.Fatal(err)
|
||||
//}
|
||||
//
|
||||
//type User struct {
|
||||
// Id uint64 `gorm:"primaryKey"`
|
||||
// Username string `gorm:"unique"`
|
||||
//}
|
||||
//
|
||||
//if err = mdb.Session(t.Context()).AutoMigrate(&User{}); err != nil {
|
||||
// t.Fatal(err)
|
||||
//}
|
||||
//
|
||||
//if err = mdb.Session(t.Context()).Create(&User{Username: "zyp"}).Error; err != nil {
|
||||
// t.Fatal(err)
|
||||
//}
|
||||
}
|
45
pkg/database/db/option.go
Normal file
45
pkg/database/db/option.go
Normal file
@ -0,0 +1,45 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type config struct {
|
||||
ctx context.Context
|
||||
mysql *string
|
||||
pg *string
|
||||
sqlite *string
|
||||
}
|
||||
|
||||
type OptionFn func(*config)
|
||||
|
||||
func WithCtx(ctx context.Context) OptionFn {
|
||||
return func(c *config) {
|
||||
if ctx != nil {
|
||||
c.ctx = ctx
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func WithMysql(host string, port int, user string, password string, database string) OptionFn {
|
||||
return func(c *config) {
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local", user, password, host, port, database)
|
||||
c.mysql = &dsn
|
||||
}
|
||||
}
|
||||
|
||||
func WithPg(host string, port int, user string, password string, database string) OptionFn {
|
||||
return func(c *config) {
|
||||
dsn := fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%d sslmode=disable TimeZone=Asia/Shanghai", host, user, password, database, port)
|
||||
c.pg = &dsn
|
||||
}
|
||||
}
|
||||
|
||||
func WithSqlite(path string) OptionFn {
|
||||
return func(c *config) {
|
||||
if path != "" {
|
||||
c.sqlite = &path
|
||||
}
|
||||
}
|
||||
}
|
@ -1,9 +1,53 @@
|
||||
package logger
|
||||
|
||||
import "github.com/gofiber/fiber/v3"
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gofiber/fiber/v3"
|
||||
"github.com/spf13/cast"
|
||||
"loveuer/utodo/pkg/logger"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
func New() fiber.Handler {
|
||||
pool := sync.Pool{
|
||||
New: func() any {
|
||||
return &strings.Builder{}
|
||||
},
|
||||
}
|
||||
|
||||
return func(c fiber.Ctx) error {
|
||||
return c.Next()
|
||||
start := time.Now()
|
||||
err := c.Next()
|
||||
|
||||
duration := time.Since(start)
|
||||
method := c.Method()
|
||||
path := c.Path()
|
||||
status := c.Response().StatusCode()
|
||||
traceId := c.Context().Value(logger.CtxKey)
|
||||
|
||||
buf := pool.Get().(*strings.Builder)
|
||||
defer pool.Put(buf)
|
||||
|
||||
buf.Reset()
|
||||
|
||||
buf.WriteString("API | ")
|
||||
buf.WriteString(start.Format("2006-01-02T15:04:05"))
|
||||
buf.WriteString(" | ")
|
||||
buf.WriteString(method)
|
||||
buf.WriteString(" | ")
|
||||
buf.WriteString(path)
|
||||
buf.WriteString(" | ")
|
||||
buf.WriteString(duration.String())
|
||||
buf.WriteString(" | ")
|
||||
buf.WriteString(strconv.Itoa(status))
|
||||
buf.WriteString(" | ")
|
||||
buf.WriteString(cast.ToString(traceId))
|
||||
|
||||
fmt.Println(buf.String())
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -1 +1,56 @@
|
||||
package resp
|
||||
|
||||
import "net/http"
|
||||
|
||||
type Error struct {
|
||||
Status int `json:"status"`
|
||||
Msg string `json:"msg"`
|
||||
Err error `json:"err"`
|
||||
Data any `json:"data"`
|
||||
}
|
||||
|
||||
func (e *Error) Error() string {
|
||||
return e.Err.Error()
|
||||
}
|
||||
|
||||
func (e *Error) _r() *res {
|
||||
data := &res{
|
||||
Status: e.Status,
|
||||
Msg: e.Msg,
|
||||
Data: e.Data,
|
||||
Err: e.Err,
|
||||
}
|
||||
|
||||
if data.Status < 0 || data.Status > 999 {
|
||||
data.Status = 500
|
||||
}
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
func NewError(err error, args ...any) *Error {
|
||||
e := &Error{
|
||||
Status: http.StatusInternalServerError,
|
||||
Err: err,
|
||||
}
|
||||
|
||||
if len(args) > 0 {
|
||||
if status, ok := args[0].(int); ok {
|
||||
e.Status = status
|
||||
}
|
||||
}
|
||||
|
||||
e.Msg = Msg(e.Status)
|
||||
|
||||
if len(args) > 1 {
|
||||
if msg, ok := args[1].(string); ok {
|
||||
e.Msg = msg
|
||||
}
|
||||
}
|
||||
|
||||
if len(args) > 2 {
|
||||
e.Data = args[2]
|
||||
}
|
||||
|
||||
return e
|
||||
}
|
||||
|
@ -1 +1,34 @@
|
||||
package resp
|
||||
|
||||
const (
|
||||
Msg200 = "操作成功"
|
||||
Msg400 = "参数错误"
|
||||
Msg401 = "该账号登录已失效, 请重新登录"
|
||||
Msg401NoMulti = "用户已在其他地方登录"
|
||||
Msg403 = "权限不足"
|
||||
Msg404 = "资源不存在"
|
||||
Msg500 = "服务器开小差了"
|
||||
Msg501 = "服务不可用"
|
||||
Msg503 = "服务不可用或正在升级, 请联系管理员"
|
||||
)
|
||||
|
||||
func Msg(status int) string {
|
||||
switch status {
|
||||
case 400:
|
||||
return Msg400
|
||||
case 401:
|
||||
return Msg401
|
||||
case 403:
|
||||
return Msg403
|
||||
case 404:
|
||||
return Msg404
|
||||
case 500:
|
||||
return Msg500
|
||||
case 501:
|
||||
return Msg501
|
||||
case 503:
|
||||
return Msg503
|
||||
}
|
||||
|
||||
return "未知错误"
|
||||
}
|
||||
|
104
pkg/resp/resp.go
104
pkg/resp/resp.go
@ -1 +1,105 @@
|
||||
package resp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/gofiber/fiber/v3"
|
||||
)
|
||||
|
||||
type res struct {
|
||||
Status int `json:"status"`
|
||||
Msg string `json:"msg"`
|
||||
Data any `json:"data"`
|
||||
Err any `json:"err"`
|
||||
}
|
||||
|
||||
func R200(c fiber.Ctx, data any, msgs ...string) error {
|
||||
r := &res{
|
||||
Status: 200,
|
||||
Msg: Msg200,
|
||||
Data: data,
|
||||
}
|
||||
|
||||
if len(msgs) > 0 && msgs[0] != "" {
|
||||
r.Msg = msgs[0]
|
||||
}
|
||||
|
||||
return c.JSON(r)
|
||||
}
|
||||
|
||||
func RC(c fiber.Ctx, status int, args ...any) error {
|
||||
return _r(c, &res{Status: status}, args...)
|
||||
}
|
||||
|
||||
func RE(c fiber.Ctx, err error) error {
|
||||
var re *Error
|
||||
|
||||
if errors.As(err, &re) {
|
||||
return _r(c, re._r())
|
||||
}
|
||||
|
||||
return R500(c, "", nil, err)
|
||||
}
|
||||
|
||||
func _r(c fiber.Ctx, r *res, args ...any) error {
|
||||
length := len(args)
|
||||
switch length {
|
||||
case 0:
|
||||
break
|
||||
case 1:
|
||||
if msg, ok := args[0].(string); ok {
|
||||
r.Msg = msg
|
||||
} else {
|
||||
r.Data = args[0]
|
||||
}
|
||||
case 2:
|
||||
r.Data = args[1]
|
||||
case 3:
|
||||
r.Err = args[2]
|
||||
}
|
||||
|
||||
if r.Msg == "" {
|
||||
r.Msg = Msg(r.Status)
|
||||
}
|
||||
|
||||
return c.Status(r.Status).JSON(r)
|
||||
}
|
||||
|
||||
func R400(c fiber.Ctx, args ...any) error {
|
||||
r := &res{
|
||||
Status: 400,
|
||||
}
|
||||
|
||||
return _r(c, r, args...)
|
||||
}
|
||||
|
||||
func R401(c fiber.Ctx, args ...any) error {
|
||||
r := &res{
|
||||
Status: 401,
|
||||
}
|
||||
|
||||
return _r(c, r, args...)
|
||||
}
|
||||
|
||||
func R403(c fiber.Ctx, args ...any) error {
|
||||
r := &res{
|
||||
Status: 403,
|
||||
}
|
||||
|
||||
return _r(c, r, args...)
|
||||
}
|
||||
|
||||
func R500(c fiber.Ctx, args ...any) error {
|
||||
r := &res{
|
||||
Status: 500,
|
||||
}
|
||||
|
||||
return _r(c, r, args...)
|
||||
}
|
||||
|
||||
func R501(c fiber.Ctx, args ...any) error {
|
||||
r := &res{
|
||||
Status: 501,
|
||||
}
|
||||
|
||||
return _r(c, r, args...)
|
||||
}
|
||||
|
44
pkg/tool/ctx.go
Normal file
44
pkg/tool/ctx.go
Normal file
@ -0,0 +1,44 @@
|
||||
package tool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"gitea.loveuer.com/yizhisec/packages/opt"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Timeout(seconds ...int) (ctx context.Context) {
|
||||
var (
|
||||
duration time.Duration
|
||||
)
|
||||
|
||||
if len(seconds) > 0 && seconds[0] > 0 {
|
||||
duration = time.Duration(seconds[0]) * time.Second
|
||||
} else {
|
||||
duration = time.Duration(30) * time.Second
|
||||
}
|
||||
|
||||
ctx, _ = context.WithTimeout(context.Background(), duration)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func CtxTimeout(ctx context.Context, seconds ...int) context.Context {
|
||||
var (
|
||||
duration time.Duration
|
||||
)
|
||||
|
||||
if len(seconds) > 0 && seconds[0] > 0 {
|
||||
duration = time.Duration(seconds[0]) * time.Second
|
||||
} else {
|
||||
duration = time.Duration(30) * time.Second
|
||||
}
|
||||
|
||||
nctx, _ := context.WithTimeout(ctx, duration)
|
||||
|
||||
return nctx
|
||||
}
|
||||
|
||||
func CtxTrace(ctx context.Context, key string) context.Context {
|
||||
return context.WithValue(ctx, opt.TraceKey, fmt.Sprintf("%36s", key))
|
||||
}
|
12
pkg/tool/gin.go
Normal file
12
pkg/tool/gin.go
Normal file
@ -0,0 +1,12 @@
|
||||
package tool
|
||||
|
||||
import "github.com/gin-gonic/gin"
|
||||
|
||||
func Local(c *gin.Context, key string) any {
|
||||
data, ok := c.Get(key)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return data
|
||||
}
|
50
pkg/tool/human.go
Normal file
50
pkg/tool/human.go
Normal file
@ -0,0 +1,50 @@
|
||||
package tool
|
||||
|
||||
import "fmt"
|
||||
|
||||
func HumanDuration(nano int64) string {
|
||||
duration := float64(nano)
|
||||
unit := "ns"
|
||||
if duration >= 1000 {
|
||||
duration /= 1000
|
||||
unit = "us"
|
||||
}
|
||||
|
||||
if duration >= 1000 {
|
||||
duration /= 1000
|
||||
unit = "ms"
|
||||
}
|
||||
|
||||
if duration >= 1000 {
|
||||
duration /= 1000
|
||||
unit = " s"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%6.2f%s", duration, unit)
|
||||
}
|
||||
|
||||
func HumanSize(size int64) string {
|
||||
const (
|
||||
_ = iota
|
||||
KB = 1 << (10 * iota) // 1 KB = 1024 bytes
|
||||
MB // 1 MB = 1024 KB
|
||||
GB // 1 GB = 1024 MB
|
||||
TB // 1 TB = 1024 GB
|
||||
PB // 1 PB = 1024 TB
|
||||
)
|
||||
|
||||
switch {
|
||||
case size >= PB:
|
||||
return fmt.Sprintf("%.2f PB", float64(size)/PB)
|
||||
case size >= TB:
|
||||
return fmt.Sprintf("%.2f TB", float64(size)/TB)
|
||||
case size >= GB:
|
||||
return fmt.Sprintf("%.2f GB", float64(size)/GB)
|
||||
case size >= MB:
|
||||
return fmt.Sprintf("%.2f MB", float64(size)/MB)
|
||||
case size >= KB:
|
||||
return fmt.Sprintf("%.2f KB", float64(size)/KB)
|
||||
default:
|
||||
return fmt.Sprintf("%d bytes", size)
|
||||
}
|
||||
}
|
59
pkg/tool/ip.go
Normal file
59
pkg/tool/ip.go
Normal file
@ -0,0 +1,59 @@
|
||||
package tool
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
var (
|
||||
privateIPv4Blocks []*net.IPNet
|
||||
privateIPv6Blocks []*net.IPNet
|
||||
)
|
||||
|
||||
func init() {
|
||||
// IPv4私有地址段
|
||||
for _, cidr := range []string{
|
||||
"10.0.0.0/8", // A类私有地址
|
||||
"172.16.0.0/12", // B类私有地址
|
||||
"192.168.0.0/16", // C类私有地址
|
||||
"169.254.0.0/16", // 链路本地地址
|
||||
"127.0.0.0/8", // 环回地址
|
||||
} {
|
||||
_, block, _ := net.ParseCIDR(cidr)
|
||||
privateIPv4Blocks = append(privateIPv4Blocks, block)
|
||||
}
|
||||
|
||||
// IPv6私有地址段
|
||||
for _, cidr := range []string{
|
||||
"fc00::/7", // 唯一本地地址
|
||||
"fe80::/10", // 链路本地地址
|
||||
"::1/128", // 环回地址
|
||||
} {
|
||||
_, block, _ := net.ParseCIDR(cidr)
|
||||
privateIPv6Blocks = append(privateIPv6Blocks, block)
|
||||
}
|
||||
}
|
||||
|
||||
func IsPrivateIP(ipStr string) bool {
|
||||
ip := net.ParseIP(ipStr)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 处理IPv4和IPv4映射的IPv6地址
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
for _, block := range privateIPv4Blocks {
|
||||
if block.Contains(ip4) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// 处理IPv6地址
|
||||
for _, block := range privateIPv6Blocks {
|
||||
if block.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
76
pkg/tool/loadash.go
Normal file
76
pkg/tool/loadash.go
Normal file
@ -0,0 +1,76 @@
|
||||
package tool
|
||||
|
||||
import "math"
|
||||
|
||||
func Map[T, R any](vals []T, fn func(item T, index int) R) []R {
|
||||
var result = make([]R, len(vals))
|
||||
for idx, v := range vals {
|
||||
result[idx] = fn(v, idx)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func Chunk[T any](vals []T, size int) [][]T {
|
||||
if size <= 0 {
|
||||
panic("Second parameter must be greater than 0")
|
||||
}
|
||||
|
||||
chunksNum := len(vals) / size
|
||||
if len(vals)%size != 0 {
|
||||
chunksNum += 1
|
||||
}
|
||||
|
||||
result := make([][]T, 0, chunksNum)
|
||||
|
||||
for i := 0; i < chunksNum; i++ {
|
||||
last := (i + 1) * size
|
||||
if last > len(vals) {
|
||||
last = len(vals)
|
||||
}
|
||||
result = append(result, vals[i*size:last:last])
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// 对 vals 取样 x 个
|
||||
func Sample[T any](vals []T, x int) []T {
|
||||
if x < 0 {
|
||||
panic("Second parameter can't be negative")
|
||||
}
|
||||
|
||||
n := len(vals)
|
||||
if n == 0 {
|
||||
return []T{}
|
||||
}
|
||||
|
||||
if x >= n {
|
||||
return vals
|
||||
}
|
||||
|
||||
// 处理x=1的特殊情况
|
||||
if x == 1 {
|
||||
return []T{vals[(n-1)/2]}
|
||||
}
|
||||
|
||||
// 计算采样步长并生成结果数组
|
||||
step := float64(n-1) / float64(x-1)
|
||||
result := make([]T, x)
|
||||
|
||||
for i := 0; i < x; i++ {
|
||||
// 计算采样位置并四舍五入
|
||||
pos := float64(i) * step
|
||||
index := int(math.Round(pos))
|
||||
result[i] = vals[index]
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func If[T any](cond bool, trueVal, falseVal T) T {
|
||||
if cond {
|
||||
return trueVal
|
||||
}
|
||||
|
||||
return falseVal
|
||||
}
|
53
pkg/tool/must.go
Normal file
53
pkg/tool/must.go
Normal file
@ -0,0 +1,53 @@
|
||||
package tool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"gitea.loveuer.com/yizhisec/packages/logger"
|
||||
"sync"
|
||||
)
|
||||
|
||||
func Must(errs ...error) {
|
||||
for _, err := range errs {
|
||||
if err != nil {
|
||||
logger.Panic(err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func MustWithData[T any](data T, err error) T {
|
||||
Must(err)
|
||||
return data
|
||||
}
|
||||
|
||||
func MustStop(ctx context.Context, stopFns ...func(ctx context.Context) error) {
|
||||
if len(stopFns) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
ok := make(chan struct{})
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(len(stopFns))
|
||||
|
||||
for _, fn := range stopFns {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
if err := fn(ctx); err != nil {
|
||||
logger.ErrorCtx(ctx, "stop function failed, err = %s", err.Error())
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.FatalCtx(ctx, "stop function timeout, force down")
|
||||
case _, _ = <-ok:
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
close(ok)
|
||||
}
|
84
pkg/tool/password.go
Normal file
84
pkg/tool/password.go
Normal file
@ -0,0 +1,84 @@
|
||||
package tool
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"gitea.loveuer.com/yizhisec/packages/logger"
|
||||
"golang.org/x/crypto/pbkdf2"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
EncryptHeader string = "pbkdf2:sha256" // 用户密码加密
|
||||
)
|
||||
|
||||
func NewPassword(password string) string {
|
||||
return EncryptPassword(password, RandomString(8), int(RandomInt(50000)+100000))
|
||||
}
|
||||
|
||||
func ComparePassword(in, db string) bool {
|
||||
strs := strings.Split(db, "$")
|
||||
if len(strs) != 3 {
|
||||
logger.Error("password in db invalid: %s", db)
|
||||
return false
|
||||
}
|
||||
|
||||
encs := strings.Split(strs[0], ":")
|
||||
if len(encs) != 3 {
|
||||
logger.Error("password in db invalid: %s", db)
|
||||
return false
|
||||
}
|
||||
|
||||
encIteration, err := strconv.Atoi(encs[2])
|
||||
if err != nil {
|
||||
logger.Error("password in db invalid: %s, convert iter err: %s", db, err)
|
||||
return false
|
||||
}
|
||||
|
||||
return EncryptPassword(in, strs[1], encIteration) == db
|
||||
}
|
||||
|
||||
func EncryptPassword(password, salt string, iter int) string {
|
||||
hash := pbkdf2.Key([]byte(password), []byte(salt), iter, 32, sha256.New)
|
||||
encrypted := hex.EncodeToString(hash)
|
||||
return fmt.Sprintf("%s:%d$%s$%s", EncryptHeader, iter, salt, encrypted)
|
||||
}
|
||||
|
||||
func CheckPassword(password string) error {
|
||||
if len(password) < 8 || len(password) > 32 {
|
||||
return errors.New("密码长度不符合")
|
||||
}
|
||||
|
||||
var (
|
||||
err error
|
||||
match bool
|
||||
patternList = []string{`[0-9]+`, `[a-z]+`, `[A-Z]+`, `[!@#%]+`} //, `[~!@#$%^&*?_-]+`}
|
||||
matchAccount = 0
|
||||
tips = []string{"缺少数字", "缺少小写字母", "缺少大写字母", "缺少'!@#%'"}
|
||||
locktips = make([]string, 0)
|
||||
)
|
||||
|
||||
for idx, pattern := range patternList {
|
||||
match, err = regexp.MatchString(pattern, password)
|
||||
if err != nil {
|
||||
logger.Warn("regex match string err, reg_str: %s, err: %v", pattern, err)
|
||||
return errors.New("密码强度不够")
|
||||
}
|
||||
|
||||
if match {
|
||||
matchAccount++
|
||||
} else {
|
||||
locktips = append(locktips, tips[idx])
|
||||
}
|
||||
}
|
||||
|
||||
if matchAccount < 3 {
|
||||
return fmt.Errorf("密码强度不够, 可能 %s", strings.Join(locktips, ", "))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
20
pkg/tool/password_test.go
Normal file
20
pkg/tool/password_test.go
Normal file
@ -0,0 +1,20 @@
|
||||
package tool
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestEncPassword(t *testing.T) {
|
||||
password := "123456"
|
||||
|
||||
result := EncryptPassword(password, RandomString(8), 50000)
|
||||
|
||||
t.Logf("sum => %s", result)
|
||||
}
|
||||
|
||||
func TestPassword(t *testing.T) {
|
||||
p := "wahaha@123"
|
||||
p = NewPassword(p)
|
||||
t.Logf("password => %s", p)
|
||||
|
||||
result := ComparePassword("wahaha@123", p)
|
||||
t.Logf("compare result => %v", result)
|
||||
}
|
75
pkg/tool/random.go
Normal file
75
pkg/tool/random.go
Normal file
@ -0,0 +1,75 @@
|
||||
package tool
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"math/big"
|
||||
mrand "math/rand"
|
||||
)
|
||||
|
||||
var (
|
||||
letters = []byte("0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
|
||||
letterNum = []byte("0123456789")
|
||||
letterLow = []byte("abcdefghijklmnopqrstuvwxyz")
|
||||
letterCap = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
|
||||
letterSyb = []byte("!@#$%^&*()_+-=")
|
||||
adjectives = []string{
|
||||
"开心的", "灿烂的", "温暖的", "阳光的", "活泼的",
|
||||
"聪明的", "优雅的", "幸运的", "甜蜜的", "勇敢的",
|
||||
"宁静的", "热情的", "温柔的", "幽默的", "坚强的",
|
||||
"迷人的", "神奇的", "快乐的", "健康的", "自由的",
|
||||
"梦幻的", "勤劳的", "真诚的", "浪漫的", "自信的",
|
||||
}
|
||||
|
||||
plants = []string{
|
||||
"苹果", "香蕉", "橘子", "葡萄", "草莓",
|
||||
"西瓜", "樱桃", "菠萝", "柠檬", "蜜桃",
|
||||
"蓝莓", "芒果", "石榴", "甜瓜", "雪梨",
|
||||
"番茄", "南瓜", "土豆", "青椒", "洋葱",
|
||||
"黄瓜", "萝卜", "豌豆", "玉米", "蘑菇",
|
||||
"菠菜", "茄子", "芹菜", "莲藕", "西兰花",
|
||||
}
|
||||
)
|
||||
|
||||
func RandomInt(max int64) int64 {
|
||||
num, _ := rand.Int(rand.Reader, big.NewInt(max))
|
||||
return num.Int64()
|
||||
}
|
||||
|
||||
func RandomString(length int) string {
|
||||
result := make([]byte, length)
|
||||
for i := 0; i < length; i++ {
|
||||
num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
|
||||
result[i] = letters[num.Int64()]
|
||||
}
|
||||
return string(result)
|
||||
}
|
||||
|
||||
func RandomPassword(length int, withSymbol bool) string {
|
||||
result := make([]byte, length)
|
||||
kind := 3
|
||||
if withSymbol {
|
||||
kind++
|
||||
}
|
||||
|
||||
for i := 0; i < length; i++ {
|
||||
switch i % kind {
|
||||
case 0:
|
||||
num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letterNum))))
|
||||
result[i] = letterNum[num.Int64()]
|
||||
case 1:
|
||||
num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letterLow))))
|
||||
result[i] = letterLow[num.Int64()]
|
||||
case 2:
|
||||
num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letterCap))))
|
||||
result[i] = letterCap[num.Int64()]
|
||||
case 3:
|
||||
num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letterSyb))))
|
||||
result[i] = letterSyb[num.Int64()]
|
||||
}
|
||||
}
|
||||
return string(result)
|
||||
}
|
||||
|
||||
func RandomName() string {
|
||||
return adjectives[mrand.Intn(len(adjectives))] + plants[mrand.Intn(len(plants))]
|
||||
}
|
11
pkg/tool/string.go
Normal file
11
pkg/tool/string.go
Normal file
@ -0,0 +1,11 @@
|
||||
package tool
|
||||
|
||||
import "unsafe"
|
||||
|
||||
func BytesToString(b []byte) string {
|
||||
return unsafe.String(unsafe.SliceData(b), len(b))
|
||||
}
|
||||
|
||||
func StringToBytes(s string) []byte {
|
||||
return unsafe.Slice(unsafe.StringData(s), len(s))
|
||||
}
|
124
pkg/tool/table.go
Normal file
124
pkg/tool/table.go
Normal file
@ -0,0 +1,124 @@
|
||||
package tool
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gitea.loveuer.com/yizhisec/packages/logger"
|
||||
"github.com/jedib0t/go-pretty/v6/table"
|
||||
"io"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func TablePrinter(data any, writers ...io.Writer) {
|
||||
var w io.Writer = os.Stdout
|
||||
if len(writers) > 0 && writers[0] != nil {
|
||||
w = writers[0]
|
||||
}
|
||||
|
||||
t := table.NewWriter()
|
||||
structPrinter(t, "", data)
|
||||
_, _ = fmt.Fprintln(w, t.Render())
|
||||
}
|
||||
|
||||
func structPrinter(w table.Writer, prefix string, item any) {
|
||||
Start:
|
||||
rv := reflect.ValueOf(item)
|
||||
if rv.IsZero() {
|
||||
return
|
||||
}
|
||||
|
||||
for rv.Type().Kind() == reflect.Pointer {
|
||||
rv = rv.Elem()
|
||||
}
|
||||
|
||||
switch rv.Type().Kind() {
|
||||
case reflect.Invalid,
|
||||
reflect.Uintptr,
|
||||
reflect.Chan,
|
||||
reflect.Func,
|
||||
reflect.UnsafePointer:
|
||||
case reflect.Bool,
|
||||
reflect.Int,
|
||||
reflect.Int8,
|
||||
reflect.Int16,
|
||||
reflect.Int32,
|
||||
reflect.Int64,
|
||||
reflect.Uint,
|
||||
reflect.Uint8,
|
||||
reflect.Uint16,
|
||||
reflect.Uint32,
|
||||
reflect.Uint64,
|
||||
reflect.Float32,
|
||||
reflect.Float64,
|
||||
reflect.Complex64,
|
||||
reflect.Complex128,
|
||||
reflect.Interface:
|
||||
w.AppendRow(table.Row{strings.TrimPrefix(prefix, "."), rv.Interface()})
|
||||
case reflect.String:
|
||||
val := rv.String()
|
||||
if len(val) <= 160 {
|
||||
w.AppendRow(table.Row{strings.TrimPrefix(prefix, "."), val})
|
||||
return
|
||||
}
|
||||
|
||||
w.AppendRow(table.Row{strings.TrimPrefix(prefix, "."), val[0:64] + "..." + val[len(val)-64:]})
|
||||
case reflect.Array, reflect.Slice:
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
p := strings.Join([]string{prefix, fmt.Sprintf("[%d]", i)}, ".")
|
||||
structPrinter(w, p, rv.Index(i).Interface())
|
||||
}
|
||||
case reflect.Map:
|
||||
for _, k := range rv.MapKeys() {
|
||||
structPrinter(w, fmt.Sprintf("%s.{%v}", prefix, k), rv.MapIndex(k).Interface())
|
||||
}
|
||||
case reflect.Pointer:
|
||||
goto Start
|
||||
case reflect.Struct:
|
||||
for i := 0; i < rv.NumField(); i++ {
|
||||
p := fmt.Sprintf("%s.%s", prefix, rv.Type().Field(i).Name)
|
||||
field := rv.Field(i)
|
||||
|
||||
//log.Debug("TablePrinter: prefix: %s, field: %v", p, rv.Field(i))
|
||||
|
||||
if !field.CanInterface() {
|
||||
return
|
||||
}
|
||||
|
||||
structPrinter(w, p, field.Interface())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TableMapPrinter(data []byte) {
|
||||
m := make(map[string]any)
|
||||
if err := json.Unmarshal(data, &m); err != nil {
|
||||
logger.Warn(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
t := table.NewWriter()
|
||||
addRow(t, "", m)
|
||||
fmt.Println(t.Render())
|
||||
}
|
||||
|
||||
func addRow(w table.Writer, prefix string, m any) {
|
||||
rv := reflect.ValueOf(m)
|
||||
switch rv.Type().Kind() {
|
||||
case reflect.Map:
|
||||
for _, k := range rv.MapKeys() {
|
||||
key := k.String()
|
||||
if prefix != "" {
|
||||
key = strings.Join([]string{prefix, k.String()}, ".")
|
||||
}
|
||||
addRow(w, key, rv.MapIndex(k).Interface())
|
||||
}
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
addRow(w, fmt.Sprintf("%s[%d]", prefix, i), rv.Index(i).Interface())
|
||||
}
|
||||
default:
|
||||
w.AppendRow(table.Row{prefix, m})
|
||||
}
|
||||
}
|
73
pkg/tool/tools.go
Normal file
73
pkg/tool/tools.go
Normal file
@ -0,0 +1,73 @@
|
||||
package tool
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
)
|
||||
|
||||
func Min[T ~int | ~uint | ~int8 | ~uint8 | ~int16 | ~uint16 | ~int32 | ~uint32 | ~int64 | ~uint64 | ~float32 | ~float64](a, b T) T {
|
||||
if a <= b {
|
||||
return a
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
func Mins[T ~int | ~uint | ~int8 | ~uint8 | ~int16 | ~uint16 | ~int32 | ~uint32 | ~int64 | ~uint64 | ~float32 | ~float64](vals ...T) T {
|
||||
var val T
|
||||
|
||||
if len(vals) == 0 {
|
||||
return val
|
||||
}
|
||||
|
||||
val = vals[0]
|
||||
|
||||
for _, item := range vals[1:] {
|
||||
if item < val {
|
||||
val = item
|
||||
}
|
||||
}
|
||||
|
||||
return val
|
||||
}
|
||||
|
||||
func Max[T ~int | ~uint | ~int8 | ~uint8 | ~int16 | ~uint16 | ~int32 | ~uint32 | ~int64 | ~uint64 | ~float32 | ~float64](a, b T) T {
|
||||
if a >= b {
|
||||
return a
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
func Maxs[T ~int | ~uint | ~int8 | ~uint8 | ~int16 | ~uint16 | ~int32 | ~uint32 | ~int64 | ~uint64 | ~float32 | ~float64](vals ...T) T {
|
||||
var val T
|
||||
|
||||
if len(vals) == 0 {
|
||||
return val
|
||||
}
|
||||
|
||||
for _, item := range vals {
|
||||
if item > val {
|
||||
val = item
|
||||
}
|
||||
}
|
||||
|
||||
return val
|
||||
}
|
||||
|
||||
func Sum[T ~int | ~uint | ~int8 | ~uint8 | ~int16 | ~uint16 | ~int32 | ~uint32 | ~int64 | ~uint64 | ~float32 | ~float64](vals ...T) T {
|
||||
var sum T = 0
|
||||
for i := range vals {
|
||||
sum += vals[i]
|
||||
}
|
||||
return sum
|
||||
}
|
||||
|
||||
func Percent(val, minVal, maxVal, minPercent, maxPercent float64) string {
|
||||
return fmt.Sprintf(
|
||||
"%d%%",
|
||||
int(math.Round(
|
||||
((val-minVal)/(maxVal-minVal)*(maxPercent-minPercent)+minPercent)*100,
|
||||
)),
|
||||
)
|
||||
}
|
70
pkg/tool/tools_test.go
Normal file
70
pkg/tool/tools_test.go
Normal file
@ -0,0 +1,70 @@
|
||||
package tool
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestPercent(t *testing.T) {
|
||||
type args struct {
|
||||
val float64
|
||||
minVal float64
|
||||
maxVal float64
|
||||
minPercent float64
|
||||
maxPercent float64
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "case 1",
|
||||
args: args{
|
||||
val: 0.5,
|
||||
minVal: 0,
|
||||
maxVal: 1,
|
||||
minPercent: 0,
|
||||
maxPercent: 1,
|
||||
},
|
||||
want: "50%",
|
||||
},
|
||||
{
|
||||
name: "case 2",
|
||||
args: args{
|
||||
val: 0.3,
|
||||
minVal: 0.1,
|
||||
maxVal: 0.6,
|
||||
minPercent: 0,
|
||||
maxPercent: 1,
|
||||
},
|
||||
want: "40%",
|
||||
},
|
||||
{
|
||||
name: "case 3",
|
||||
args: args{
|
||||
val: 700,
|
||||
minVal: 700,
|
||||
maxVal: 766,
|
||||
minPercent: 0.1,
|
||||
maxPercent: 0.7,
|
||||
},
|
||||
want: "10%",
|
||||
},
|
||||
{
|
||||
name: "case 4",
|
||||
args: args{
|
||||
val: 766,
|
||||
minVal: 700,
|
||||
maxVal: 766,
|
||||
minPercent: 0.1,
|
||||
maxPercent: 0.7,
|
||||
},
|
||||
want: "70%",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := Percent(tt.args.val, tt.args.minVal, tt.args.maxVal, tt.args.minPercent, tt.args.maxPercent); got != tt.want {
|
||||
t.Errorf("Percent() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user