wip: 登录和认证

This commit is contained in:
loveuer
2025-07-13 22:57:57 +08:00
parent 48af538f98
commit b48fa05d9f
33 changed files with 1961 additions and 33 deletions

98
pkg/database/cache/cache.go vendored Normal file
View File

@ -0,0 +1,98 @@
package cache
import (
"context"
"encoding/json"
"errors"
"sync"
"time"
)
type encoded_value interface {
MarshalBinary() ([]byte, error)
}
type decoded_value interface {
UnmarshalBinary(bs []byte) error
}
type Scanner interface {
Scan(model any) error
}
type scan struct {
err error
bs []byte
}
func newScan(bs []byte, err error) *scan {
return &scan{bs: bs, err: err}
}
func (s *scan) Scan(model any) error {
if s.err != nil {
return s.err
}
return unmarshaler(s.bs, model)
}
type Cache interface {
Get(ctx context.Context, key string) ([]byte, error)
Gets(ctx context.Context, keys ...string) ([][]byte, error)
GetScan(ctx context.Context, key string) Scanner
GetEx(ctx context.Context, key string, duration time.Duration) ([]byte, error)
GetExScan(ctx context.Context, key string, duration time.Duration) Scanner
// Set value 会被序列化, 优先使用 MarshalBinary 方法, 没有则执行 json.Marshal
Set(ctx context.Context, key string, value any) error
Sets(ctx context.Context, vm map[string]any) error
// SetEx value 会被序列化, 优先使用 MarshalBinary 方法, 没有则执行 json.Marshal
SetEx(ctx context.Context, key string, value any, duration time.Duration) error
Del(ctx context.Context, keys ...string) error
GetDel(ctx context.Context, key string) ([]byte, error)
GetDelScan(ctx context.Context, key string) Scanner
Close()
Client() any
}
var (
lock = &sync.Mutex{}
marshaler func(data any) ([]byte, error) = json.Marshal
unmarshaler func(data []byte, model any) error = json.Unmarshal
ErrorKeyNotFound = errors.New("key not found")
ErrorStoreFailed = errors.New("store failed")
Default Cache
)
func handleValue(value any) ([]byte, error) {
var (
bs []byte
err error
)
switch val := value.(type) {
case []byte:
return val, nil
}
if imp, ok := value.(encoded_value); ok {
bs, err = imp.MarshalBinary()
} else {
bs, err = marshaler(value)
}
return bs, err
}
func SetMarshaler(fn func(data any) ([]byte, error)) {
lock.Lock()
defer lock.Unlock()
marshaler = fn
}
func SetUnmarshaler(fn func(data []byte, model any) error) {
lock.Lock()
defer lock.Unlock()
unmarshaler = fn
}

155
pkg/database/cache/memory.go vendored Normal file
View File

@ -0,0 +1,155 @@
package cache
import (
"context"
"errors"
"time"
"github.com/dgraph-io/ristretto/v2"
)
var _ Cache = (*_mem)(nil)
type _mem struct {
ctx context.Context
cache *ristretto.Cache[string, []byte]
}
func newMemory(ctx context.Context, ins *ristretto.Cache[string, []byte]) Cache {
return &_mem{
ctx: ctx,
cache: ins,
}
}
func (m *_mem) Client() any {
return m.cache
}
func (c *_mem) Close() {
c.cache.Close()
}
func (c *_mem) Del(ctx context.Context, keys ...string) error {
for _, key := range keys {
c.cache.Del(key)
}
return nil
}
func (c *_mem) Get(ctx context.Context, key string) ([]byte, error) {
val, ok := c.cache.Get(key)
if !ok {
return val, ErrorKeyNotFound
}
return val, nil
}
func (c *_mem) GetDel(ctx context.Context, key string) ([]byte, error) {
val, err := c.Get(ctx, key)
if err != nil {
return val, err
}
c.cache.Del(key)
return val, err
}
func (c *_mem) GetDelScan(ctx context.Context, key string) Scanner {
val, err := c.GetDel(ctx, key)
return newScan(val, err)
}
func (c *_mem) GetEx(ctx context.Context, key string, duration time.Duration) ([]byte, error) {
val, err := c.Get(ctx, key)
if err != nil {
return val, err
}
c.cache.SetWithTTL(key, val, 1, duration)
return val, err
}
func (m *_mem) GetExScan(ctx context.Context, key string, duration time.Duration) Scanner {
val, err := m.GetEx(ctx, key, duration)
return newScan(val, err)
}
func (m *_mem) GetScan(ctx context.Context, key string) Scanner {
val, err := m.Get(ctx, key)
return newScan(val, err)
}
func (m *_mem) Gets(ctx context.Context, keys ...string) ([][]byte, error) {
vals := make([][]byte, 0, len(keys))
for _, key := range keys {
val, err := m.Get(ctx, key)
if err != nil {
if errors.Is(err, ErrorKeyNotFound) {
continue
}
return vals, err
}
vals = append(vals, val)
}
if len(vals) != len(keys) {
return vals, ErrorKeyNotFound
}
return vals, nil
}
func (m *_mem) Set(ctx context.Context, key string, value any) error {
val, err := handleValue(value)
if err != nil {
return err
}
if ok := m.cache.Set(key, val, 1); !ok {
return ErrorStoreFailed
}
m.cache.Wait()
return nil
}
func (m *_mem) SetEx(ctx context.Context, key string, value any, duration time.Duration) error {
val, err := handleValue(value)
if err != nil {
return err
}
if ok := m.cache.SetWithTTL(key, val, 1, duration); !ok {
return ErrorStoreFailed
}
m.cache.Wait()
return nil
}
func (m *_mem) Sets(ctx context.Context, vm map[string]any) error {
for key, value := range vm {
val, err := handleValue(value)
if err != nil {
return err
}
if ok := m.cache.Set(key, val, 1); !ok {
return ErrorStoreFailed
}
}
m.cache.Wait()
return nil
}

84
pkg/database/cache/new.go vendored Normal file
View File

@ -0,0 +1,84 @@
package cache
import (
"context"
"fmt"
"net/url"
"gitea.loveuer.com/yizhisec/packages/tool"
"github.com/dgraph-io/ristretto/v2"
"github.com/go-redis/redis/v8"
)
func New(opts ...Option) (Cache, error) {
var (
err error
cfg = &option{
ctx: context.Background(),
}
)
for _, opt := range opts {
opt(cfg)
}
if cfg.redis != nil {
var (
ins *url.URL
client *redis.Client
)
if ins, err = url.Parse(*cfg.redis); err != nil {
return nil, err
}
username := ins.User.Username()
password, _ := ins.User.Password()
client = redis.NewClient(&redis.Options{
Addr: ins.Host,
Username: username,
Password: password,
})
if err = client.Ping(tool.CtxTimeout(cfg.ctx, 5)).Err(); err != nil {
return nil, err
}
return newRedis(cfg.ctx, client), nil
}
if cfg.memory {
var (
ins *ristretto.Cache[string, []byte]
)
if ins, err = ristretto.NewCache(&ristretto.Config[string, []byte]{
NumCounters: 1e7, // number of keys to track frequency of (10M).
MaxCost: 1 << 30, // maximum cost of cache (1GB).
BufferItems: 64, // number of keys per Get buffer.
}); err != nil {
return nil, err
}
return newMemory(cfg.ctx, ins), nil
}
return nil, fmt.Errorf("invalid cache option")
}
func Init(opts ...Option) (err error) {
opt := &option{}
for _, optFn := range opts {
optFn(opt)
}
if opt.memory {
Default, err = New(opts...)
return err
}
Default, err = New(opts...)
return err
}

108
pkg/database/cache/new_test.go vendored Normal file
View File

@ -0,0 +1,108 @@
package cache
import (
"testing"
)
func TestNew(t *testing.T) {
/* if err := Init(WithRedis("127.0.0.1", 6379, "", "MyPassw0rd")); err != nil {
t.Fatal(err)
}
type User struct {
Name string `json:"name"`
Age int `json:"age"`
}
if err := Default.Set(t.Context(), "zyp:haha", &User{
Name: "cache",
Age: 18,
}); err != nil {
t.Fatal(err)
}
s := Default.GetDelScan(t.Context(), "zyp:haha")
u := new(User)
if err := s.Scan(u); err != nil {
t.Fatal(err)
}
t.Logf("%#v", *u)
if err := Default.SetEx(t.Context(), "zyp:haha", &User{
Name: "redis",
Age: 2,
}, time.Hour); err != nil {
t.Fatal(err)
}*/
}
func TestNoAuth(t *testing.T) {
//if err := Init(WithRedis("10.125.1.28", 6379, "", "")); err != nil {
// t.Fatal(err)
//}
//
//type User struct {
// Name string `json:"name"`
// Age int `json:"age"`
//}
//
//if err := Default.Set(t.Context(), "zyp:haha", &User{
// Name: "cache",
// Age: 18,
//}); err != nil {
// t.Fatal(err)
//}
//
//s := Default.GetDelScan(t.Context(), "zyp:haha")
//u := new(User)
//
//if err := s.Scan(u); err != nil {
// t.Fatal(err)
//}
//
//t.Logf("%#v", *u)
//
//if err := Default.SetEx(t.Context(), "zyp:haha", &User{
// Name: "redis",
// Age: 2,
//}, time.Hour); err != nil {
// t.Fatal(err)
//}
}
func TestMemoryDefault(t *testing.T) {
if err := Init(WithMemory()); err != nil {
t.Fatal("init err:", err)
}
if err := Default.Set(t.Context(), "123", "123"); err != nil {
t.Fatal("set err:", err)
}
val, err := Default.Get(t.Context(), "123")
if err != nil {
t.Fatal("get err:", err)
}
t.Logf("%s", val)
}
func TestMemoryNew(t *testing.T) {
client, err := New(WithMemory())
if err != nil {
t.Fatal("init err:", err)
}
if err := client.Set(t.Context(), "123", "123"); err != nil {
t.Fatal("set err:", err)
}
val, err := client.Get(t.Context(), "123")
if err != nil {
t.Fatal("get err:", err)
}
t.Logf("%s", val)
}

55
pkg/database/cache/option.go vendored Normal file
View File

@ -0,0 +1,55 @@
package cache
import (
"context"
"fmt"
"net/url"
)
type option struct {
ctx context.Context
redis *string
memory bool
}
type Option func(*option)
func WithCtx(ctx context.Context) Option {
return func(c *option) {
if ctx != nil {
c.ctx = ctx
}
}
}
func WithRedis(host string, port int, username, password string) Option {
return func(c *option) {
uri := fmt.Sprintf("redis://%s:%d", host, port)
if username != "" || password != "" {
uri = fmt.Sprintf("redis://%s:%s@%s:%d", username, password, host, port)
}
c.redis = &uri
}
}
func WithRedisURI(uri string) Option {
return func(c *option) {
ins, err := url.Parse(uri)
if err != nil {
return
}
if ins.Scheme != "redis" {
return
}
c.redis = &uri
}
}
func WithMemory() Option {
return func(c *option) {
c.memory = true
}
}

153
pkg/database/cache/redis.go vendored Normal file
View File

@ -0,0 +1,153 @@
package cache
import (
"context"
"errors"
"sync"
"time"
"gitea.loveuer.com/yizhisec/packages/tool"
"github.com/go-redis/redis/v8"
"github.com/spf13/cast"
)
var _ Cache = (*_redis)(nil)
type _redis struct {
sync.Mutex
ctx context.Context
client *redis.Client
}
func (r *_redis) Client() any {
return r.client
}
func newRedis(ctx context.Context, client *redis.Client) Cache {
r := &_redis{ctx: ctx, client: client}
go func() {
<-r.ctx.Done()
if client != nil {
r.Close()
}
}()
return r
}
func (r *_redis) GetDel(ctx context.Context, key string) ([]byte, error) {
s, err := r.client.GetDel(ctx, key).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return nil, ErrorKeyNotFound
}
return nil, err
}
return tool.StringToBytes(s), nil
}
func (r *_redis) GetDelScan(ctx context.Context, key string) Scanner {
bs, err := r.GetDel(ctx, key)
return newScan(bs, err)
}
func (r *_redis) Get(ctx context.Context, key string) ([]byte, error) {
result, err := r.client.Get(ctx, key).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return nil, ErrorKeyNotFound
}
return nil, err
}
return tool.StringToBytes(result), nil
}
func (r *_redis) Gets(ctx context.Context, keys ...string) ([][]byte, error) {
result, err := r.client.MGet(ctx, keys...).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return nil, ErrorKeyNotFound
}
return nil, err
}
return tool.Map(
result,
func(item any, index int) []byte {
return tool.StringToBytes(cast.ToString(item))
},
), nil
}
func (r *_redis) GetScan(ctx context.Context, key string) Scanner {
return newScan(r.Get(ctx, key))
}
func (r *_redis) GetEx(ctx context.Context, key string, duration time.Duration) ([]byte, error) {
result, err := r.client.GetEx(ctx, key, duration).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return nil, ErrorKeyNotFound
}
return nil, err
}
return tool.StringToBytes(result), nil
}
func (r *_redis) GetExScan(ctx context.Context, key string, duration time.Duration) Scanner {
return newScan(r.GetEx(ctx, key, duration))
}
func (r *_redis) Set(ctx context.Context, key string, value any) error {
bs, err := handleValue(value)
if err != nil {
return err
}
_, err = r.client.Set(ctx, key, bs, redis.KeepTTL).Result()
return err
}
func (r *_redis) Sets(ctx context.Context, values map[string]any) error {
vm := make(map[string]any)
for k, v := range values {
bs, err := handleValue(v)
if err != nil {
return err
}
vm[k] = bs
}
return r.client.MSet(ctx, vm).Err()
}
func (r *_redis) SetEx(ctx context.Context, key string, value any, duration time.Duration) error {
bs, err := handleValue(value)
if err != nil {
return err
}
_, err = r.client.SetEX(ctx, key, bs, duration).Result()
return err
}
func (r *_redis) Del(ctx context.Context, keys ...string) error {
return r.client.Del(ctx, keys...).Err()
}
func (r *_redis) Close() {
r.Lock()
defer r.Unlock()
_ = r.client.Close()
r.client = nil
}

49
pkg/database/db/db.go Normal file
View File

@ -0,0 +1,49 @@
package db
import (
"context"
"gorm.io/gorm"
)
type Config struct {
Debug bool
DryRun bool
}
type DB interface {
Session(ctx context.Context, configs ...Config) *gorm.DB
}
type db struct {
tx *gorm.DB
}
var (
Default DB
)
func (db *db) Session(ctx context.Context, configs ...Config) *gorm.DB {
var (
sc = &gorm.Session{Context: ctx}
session *gorm.DB
)
if len(configs) == 0 {
session = db.tx.Session(sc)
return session
}
cfg := configs[0]
if cfg.DryRun {
sc.DryRun = true
}
session = db.tx.Session(sc)
if cfg.Debug {
session = session.Debug()
}
return session
}

48
pkg/database/db/new.go Normal file
View File

@ -0,0 +1,48 @@
package db
import (
"github.com/glebarez/sqlite"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)
var defaultSqlite = "data.db"
func New(opts ...OptionFn) (DB, error) {
var (
err error
conf = &config{
sqlite: &defaultSqlite,
}
tx *gorm.DB
)
for _, opt := range opts {
opt(conf)
}
if conf.mysql != nil {
tx, err = gorm.Open(mysql.Open(*conf.mysql))
goto CHECK
}
if conf.pg != nil {
tx, err = gorm.Open(postgres.Open(*conf.pg))
goto CHECK
}
tx, err = gorm.Open(sqlite.Open(*conf.sqlite))
CHECK:
if err != nil {
return nil, err
}
return &db{tx: tx}, nil
}
func Init(opts ...OptionFn) (err error) {
Default, err = New(opts...)
return err
}

View File

@ -0,0 +1,25 @@
package db
import (
"testing"
)
func TestNew(t *testing.T) {
//mdb, err := New(WithMysql("127.0.0.1", 3306, "root", "MyPassw0rd", "mydb"))
//if err != nil {
// t.Fatal(err)
//}
//
//type User struct {
// Id uint64 `gorm:"primaryKey"`
// Username string `gorm:"unique"`
//}
//
//if err = mdb.Session(t.Context()).AutoMigrate(&User{}); err != nil {
// t.Fatal(err)
//}
//
//if err = mdb.Session(t.Context()).Create(&User{Username: "zyp"}).Error; err != nil {
// t.Fatal(err)
//}
}

45
pkg/database/db/option.go Normal file
View File

@ -0,0 +1,45 @@
package db
import (
"context"
"fmt"
)
type config struct {
ctx context.Context
mysql *string
pg *string
sqlite *string
}
type OptionFn func(*config)
func WithCtx(ctx context.Context) OptionFn {
return func(c *config) {
if ctx != nil {
c.ctx = ctx
}
}
}
func WithMysql(host string, port int, user string, password string, database string) OptionFn {
return func(c *config) {
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local", user, password, host, port, database)
c.mysql = &dsn
}
}
func WithPg(host string, port int, user string, password string, database string) OptionFn {
return func(c *config) {
dsn := fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%d sslmode=disable TimeZone=Asia/Shanghai", host, user, password, database, port)
c.pg = &dsn
}
}
func WithSqlite(path string) OptionFn {
return func(c *config) {
if path != "" {
c.sqlite = &path
}
}
}

View File

@ -1,9 +1,53 @@
package logger
import "github.com/gofiber/fiber/v3"
import (
"fmt"
"github.com/gofiber/fiber/v3"
"github.com/spf13/cast"
"loveuer/utodo/pkg/logger"
"strconv"
"strings"
"sync"
"time"
)
func New() fiber.Handler {
pool := sync.Pool{
New: func() any {
return &strings.Builder{}
},
}
return func(c fiber.Ctx) error {
return c.Next()
start := time.Now()
err := c.Next()
duration := time.Since(start)
method := c.Method()
path := c.Path()
status := c.Response().StatusCode()
traceId := c.Context().Value(logger.CtxKey)
buf := pool.Get().(*strings.Builder)
defer pool.Put(buf)
buf.Reset()
buf.WriteString("API | ")
buf.WriteString(start.Format("2006-01-02T15:04:05"))
buf.WriteString(" | ")
buf.WriteString(method)
buf.WriteString(" | ")
buf.WriteString(path)
buf.WriteString(" | ")
buf.WriteString(duration.String())
buf.WriteString(" | ")
buf.WriteString(strconv.Itoa(status))
buf.WriteString(" | ")
buf.WriteString(cast.ToString(traceId))
fmt.Println(buf.String())
return err
}
}

View File

@ -1 +1,56 @@
package resp
import "net/http"
type Error struct {
Status int `json:"status"`
Msg string `json:"msg"`
Err error `json:"err"`
Data any `json:"data"`
}
func (e *Error) Error() string {
return e.Err.Error()
}
func (e *Error) _r() *res {
data := &res{
Status: e.Status,
Msg: e.Msg,
Data: e.Data,
Err: e.Err,
}
if data.Status < 0 || data.Status > 999 {
data.Status = 500
}
return data
}
func NewError(err error, args ...any) *Error {
e := &Error{
Status: http.StatusInternalServerError,
Err: err,
}
if len(args) > 0 {
if status, ok := args[0].(int); ok {
e.Status = status
}
}
e.Msg = Msg(e.Status)
if len(args) > 1 {
if msg, ok := args[1].(string); ok {
e.Msg = msg
}
}
if len(args) > 2 {
e.Data = args[2]
}
return e
}

View File

@ -1 +1,34 @@
package resp
const (
Msg200 = "操作成功"
Msg400 = "参数错误"
Msg401 = "该账号登录已失效, 请重新登录"
Msg401NoMulti = "用户已在其他地方登录"
Msg403 = "权限不足"
Msg404 = "资源不存在"
Msg500 = "服务器开小差了"
Msg501 = "服务不可用"
Msg503 = "服务不可用或正在升级, 请联系管理员"
)
func Msg(status int) string {
switch status {
case 400:
return Msg400
case 401:
return Msg401
case 403:
return Msg403
case 404:
return Msg404
case 500:
return Msg500
case 501:
return Msg501
case 503:
return Msg503
}
return "未知错误"
}

View File

@ -1 +1,105 @@
package resp
import (
"errors"
"github.com/gofiber/fiber/v3"
)
type res struct {
Status int `json:"status"`
Msg string `json:"msg"`
Data any `json:"data"`
Err any `json:"err"`
}
func R200(c fiber.Ctx, data any, msgs ...string) error {
r := &res{
Status: 200,
Msg: Msg200,
Data: data,
}
if len(msgs) > 0 && msgs[0] != "" {
r.Msg = msgs[0]
}
return c.JSON(r)
}
func RC(c fiber.Ctx, status int, args ...any) error {
return _r(c, &res{Status: status}, args...)
}
func RE(c fiber.Ctx, err error) error {
var re *Error
if errors.As(err, &re) {
return _r(c, re._r())
}
return R500(c, "", nil, err)
}
func _r(c fiber.Ctx, r *res, args ...any) error {
length := len(args)
switch length {
case 0:
break
case 1:
if msg, ok := args[0].(string); ok {
r.Msg = msg
} else {
r.Data = args[0]
}
case 2:
r.Data = args[1]
case 3:
r.Err = args[2]
}
if r.Msg == "" {
r.Msg = Msg(r.Status)
}
return c.Status(r.Status).JSON(r)
}
func R400(c fiber.Ctx, args ...any) error {
r := &res{
Status: 400,
}
return _r(c, r, args...)
}
func R401(c fiber.Ctx, args ...any) error {
r := &res{
Status: 401,
}
return _r(c, r, args...)
}
func R403(c fiber.Ctx, args ...any) error {
r := &res{
Status: 403,
}
return _r(c, r, args...)
}
func R500(c fiber.Ctx, args ...any) error {
r := &res{
Status: 500,
}
return _r(c, r, args...)
}
func R501(c fiber.Ctx, args ...any) error {
r := &res{
Status: 501,
}
return _r(c, r, args...)
}

44
pkg/tool/ctx.go Normal file
View File

@ -0,0 +1,44 @@
package tool
import (
"context"
"fmt"
"gitea.loveuer.com/yizhisec/packages/opt"
"time"
)
func Timeout(seconds ...int) (ctx context.Context) {
var (
duration time.Duration
)
if len(seconds) > 0 && seconds[0] > 0 {
duration = time.Duration(seconds[0]) * time.Second
} else {
duration = time.Duration(30) * time.Second
}
ctx, _ = context.WithTimeout(context.Background(), duration)
return
}
func CtxTimeout(ctx context.Context, seconds ...int) context.Context {
var (
duration time.Duration
)
if len(seconds) > 0 && seconds[0] > 0 {
duration = time.Duration(seconds[0]) * time.Second
} else {
duration = time.Duration(30) * time.Second
}
nctx, _ := context.WithTimeout(ctx, duration)
return nctx
}
func CtxTrace(ctx context.Context, key string) context.Context {
return context.WithValue(ctx, opt.TraceKey, fmt.Sprintf("%36s", key))
}

12
pkg/tool/gin.go Normal file
View File

@ -0,0 +1,12 @@
package tool
import "github.com/gin-gonic/gin"
func Local(c *gin.Context, key string) any {
data, ok := c.Get(key)
if !ok {
return nil
}
return data
}

50
pkg/tool/human.go Normal file
View File

@ -0,0 +1,50 @@
package tool
import "fmt"
func HumanDuration(nano int64) string {
duration := float64(nano)
unit := "ns"
if duration >= 1000 {
duration /= 1000
unit = "us"
}
if duration >= 1000 {
duration /= 1000
unit = "ms"
}
if duration >= 1000 {
duration /= 1000
unit = " s"
}
return fmt.Sprintf("%6.2f%s", duration, unit)
}
func HumanSize(size int64) string {
const (
_ = iota
KB = 1 << (10 * iota) // 1 KB = 1024 bytes
MB // 1 MB = 1024 KB
GB // 1 GB = 1024 MB
TB // 1 TB = 1024 GB
PB // 1 PB = 1024 TB
)
switch {
case size >= PB:
return fmt.Sprintf("%.2f PB", float64(size)/PB)
case size >= TB:
return fmt.Sprintf("%.2f TB", float64(size)/TB)
case size >= GB:
return fmt.Sprintf("%.2f GB", float64(size)/GB)
case size >= MB:
return fmt.Sprintf("%.2f MB", float64(size)/MB)
case size >= KB:
return fmt.Sprintf("%.2f KB", float64(size)/KB)
default:
return fmt.Sprintf("%d bytes", size)
}
}

59
pkg/tool/ip.go Normal file
View File

@ -0,0 +1,59 @@
package tool
import (
"net"
)
var (
privateIPv4Blocks []*net.IPNet
privateIPv6Blocks []*net.IPNet
)
func init() {
// IPv4私有地址段
for _, cidr := range []string{
"10.0.0.0/8", // A类私有地址
"172.16.0.0/12", // B类私有地址
"192.168.0.0/16", // C类私有地址
"169.254.0.0/16", // 链路本地地址
"127.0.0.0/8", // 环回地址
} {
_, block, _ := net.ParseCIDR(cidr)
privateIPv4Blocks = append(privateIPv4Blocks, block)
}
// IPv6私有地址段
for _, cidr := range []string{
"fc00::/7", // 唯一本地地址
"fe80::/10", // 链路本地地址
"::1/128", // 环回地址
} {
_, block, _ := net.ParseCIDR(cidr)
privateIPv6Blocks = append(privateIPv6Blocks, block)
}
}
func IsPrivateIP(ipStr string) bool {
ip := net.ParseIP(ipStr)
if ip == nil {
return false
}
// 处理IPv4和IPv4映射的IPv6地址
if ip4 := ip.To4(); ip4 != nil {
for _, block := range privateIPv4Blocks {
if block.Contains(ip4) {
return true
}
}
return false
}
// 处理IPv6地址
for _, block := range privateIPv6Blocks {
if block.Contains(ip) {
return true
}
}
return false
}

76
pkg/tool/loadash.go Normal file
View File

@ -0,0 +1,76 @@
package tool
import "math"
func Map[T, R any](vals []T, fn func(item T, index int) R) []R {
var result = make([]R, len(vals))
for idx, v := range vals {
result[idx] = fn(v, idx)
}
return result
}
func Chunk[T any](vals []T, size int) [][]T {
if size <= 0 {
panic("Second parameter must be greater than 0")
}
chunksNum := len(vals) / size
if len(vals)%size != 0 {
chunksNum += 1
}
result := make([][]T, 0, chunksNum)
for i := 0; i < chunksNum; i++ {
last := (i + 1) * size
if last > len(vals) {
last = len(vals)
}
result = append(result, vals[i*size:last:last])
}
return result
}
// 对 vals 取样 x 个
func Sample[T any](vals []T, x int) []T {
if x < 0 {
panic("Second parameter can't be negative")
}
n := len(vals)
if n == 0 {
return []T{}
}
if x >= n {
return vals
}
// 处理x=1的特殊情况
if x == 1 {
return []T{vals[(n-1)/2]}
}
// 计算采样步长并生成结果数组
step := float64(n-1) / float64(x-1)
result := make([]T, x)
for i := 0; i < x; i++ {
// 计算采样位置并四舍五入
pos := float64(i) * step
index := int(math.Round(pos))
result[i] = vals[index]
}
return result
}
func If[T any](cond bool, trueVal, falseVal T) T {
if cond {
return trueVal
}
return falseVal
}

53
pkg/tool/must.go Normal file
View File

@ -0,0 +1,53 @@
package tool
import (
"context"
"gitea.loveuer.com/yizhisec/packages/logger"
"sync"
)
func Must(errs ...error) {
for _, err := range errs {
if err != nil {
logger.Panic(err.Error())
}
}
}
func MustWithData[T any](data T, err error) T {
Must(err)
return data
}
func MustStop(ctx context.Context, stopFns ...func(ctx context.Context) error) {
if len(stopFns) == 0 {
return
}
ok := make(chan struct{})
wg := &sync.WaitGroup{}
wg.Add(len(stopFns))
for _, fn := range stopFns {
go func() {
defer wg.Done()
if err := fn(ctx); err != nil {
logger.ErrorCtx(ctx, "stop function failed, err = %s", err.Error())
}
}()
}
go func() {
select {
case <-ctx.Done():
logger.FatalCtx(ctx, "stop function timeout, force down")
case _, _ = <-ok:
return
}
}()
wg.Wait()
close(ok)
}

84
pkg/tool/password.go Normal file
View File

@ -0,0 +1,84 @@
package tool
import (
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"gitea.loveuer.com/yizhisec/packages/logger"
"golang.org/x/crypto/pbkdf2"
"regexp"
"strconv"
"strings"
)
const (
EncryptHeader string = "pbkdf2:sha256" // 用户密码加密
)
func NewPassword(password string) string {
return EncryptPassword(password, RandomString(8), int(RandomInt(50000)+100000))
}
func ComparePassword(in, db string) bool {
strs := strings.Split(db, "$")
if len(strs) != 3 {
logger.Error("password in db invalid: %s", db)
return false
}
encs := strings.Split(strs[0], ":")
if len(encs) != 3 {
logger.Error("password in db invalid: %s", db)
return false
}
encIteration, err := strconv.Atoi(encs[2])
if err != nil {
logger.Error("password in db invalid: %s, convert iter err: %s", db, err)
return false
}
return EncryptPassword(in, strs[1], encIteration) == db
}
func EncryptPassword(password, salt string, iter int) string {
hash := pbkdf2.Key([]byte(password), []byte(salt), iter, 32, sha256.New)
encrypted := hex.EncodeToString(hash)
return fmt.Sprintf("%s:%d$%s$%s", EncryptHeader, iter, salt, encrypted)
}
func CheckPassword(password string) error {
if len(password) < 8 || len(password) > 32 {
return errors.New("密码长度不符合")
}
var (
err error
match bool
patternList = []string{`[0-9]+`, `[a-z]+`, `[A-Z]+`, `[!@#%]+`} //, `[~!@#$%^&*?_-]+`}
matchAccount = 0
tips = []string{"缺少数字", "缺少小写字母", "缺少大写字母", "缺少'!@#%'"}
locktips = make([]string, 0)
)
for idx, pattern := range patternList {
match, err = regexp.MatchString(pattern, password)
if err != nil {
logger.Warn("regex match string err, reg_str: %s, err: %v", pattern, err)
return errors.New("密码强度不够")
}
if match {
matchAccount++
} else {
locktips = append(locktips, tips[idx])
}
}
if matchAccount < 3 {
return fmt.Errorf("密码强度不够, 可能 %s", strings.Join(locktips, ", "))
}
return nil
}

20
pkg/tool/password_test.go Normal file
View File

@ -0,0 +1,20 @@
package tool
import "testing"
func TestEncPassword(t *testing.T) {
password := "123456"
result := EncryptPassword(password, RandomString(8), 50000)
t.Logf("sum => %s", result)
}
func TestPassword(t *testing.T) {
p := "wahaha@123"
p = NewPassword(p)
t.Logf("password => %s", p)
result := ComparePassword("wahaha@123", p)
t.Logf("compare result => %v", result)
}

75
pkg/tool/random.go Normal file
View File

@ -0,0 +1,75 @@
package tool
import (
"crypto/rand"
"math/big"
mrand "math/rand"
)
var (
letters = []byte("0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
letterNum = []byte("0123456789")
letterLow = []byte("abcdefghijklmnopqrstuvwxyz")
letterCap = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
letterSyb = []byte("!@#$%^&*()_+-=")
adjectives = []string{
"开心的", "灿烂的", "温暖的", "阳光的", "活泼的",
"聪明的", "优雅的", "幸运的", "甜蜜的", "勇敢的",
"宁静的", "热情的", "温柔的", "幽默的", "坚强的",
"迷人的", "神奇的", "快乐的", "健康的", "自由的",
"梦幻的", "勤劳的", "真诚的", "浪漫的", "自信的",
}
plants = []string{
"苹果", "香蕉", "橘子", "葡萄", "草莓",
"西瓜", "樱桃", "菠萝", "柠檬", "蜜桃",
"蓝莓", "芒果", "石榴", "甜瓜", "雪梨",
"番茄", "南瓜", "土豆", "青椒", "洋葱",
"黄瓜", "萝卜", "豌豆", "玉米", "蘑菇",
"菠菜", "茄子", "芹菜", "莲藕", "西兰花",
}
)
func RandomInt(max int64) int64 {
num, _ := rand.Int(rand.Reader, big.NewInt(max))
return num.Int64()
}
func RandomString(length int) string {
result := make([]byte, length)
for i := 0; i < length; i++ {
num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
result[i] = letters[num.Int64()]
}
return string(result)
}
func RandomPassword(length int, withSymbol bool) string {
result := make([]byte, length)
kind := 3
if withSymbol {
kind++
}
for i := 0; i < length; i++ {
switch i % kind {
case 0:
num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letterNum))))
result[i] = letterNum[num.Int64()]
case 1:
num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letterLow))))
result[i] = letterLow[num.Int64()]
case 2:
num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letterCap))))
result[i] = letterCap[num.Int64()]
case 3:
num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letterSyb))))
result[i] = letterSyb[num.Int64()]
}
}
return string(result)
}
func RandomName() string {
return adjectives[mrand.Intn(len(adjectives))] + plants[mrand.Intn(len(plants))]
}

11
pkg/tool/string.go Normal file
View File

@ -0,0 +1,11 @@
package tool
import "unsafe"
func BytesToString(b []byte) string {
return unsafe.String(unsafe.SliceData(b), len(b))
}
func StringToBytes(s string) []byte {
return unsafe.Slice(unsafe.StringData(s), len(s))
}

124
pkg/tool/table.go Normal file
View File

@ -0,0 +1,124 @@
package tool
import (
"encoding/json"
"fmt"
"gitea.loveuer.com/yizhisec/packages/logger"
"github.com/jedib0t/go-pretty/v6/table"
"io"
"os"
"reflect"
"strings"
)
func TablePrinter(data any, writers ...io.Writer) {
var w io.Writer = os.Stdout
if len(writers) > 0 && writers[0] != nil {
w = writers[0]
}
t := table.NewWriter()
structPrinter(t, "", data)
_, _ = fmt.Fprintln(w, t.Render())
}
func structPrinter(w table.Writer, prefix string, item any) {
Start:
rv := reflect.ValueOf(item)
if rv.IsZero() {
return
}
for rv.Type().Kind() == reflect.Pointer {
rv = rv.Elem()
}
switch rv.Type().Kind() {
case reflect.Invalid,
reflect.Uintptr,
reflect.Chan,
reflect.Func,
reflect.UnsafePointer:
case reflect.Bool,
reflect.Int,
reflect.Int8,
reflect.Int16,
reflect.Int32,
reflect.Int64,
reflect.Uint,
reflect.Uint8,
reflect.Uint16,
reflect.Uint32,
reflect.Uint64,
reflect.Float32,
reflect.Float64,
reflect.Complex64,
reflect.Complex128,
reflect.Interface:
w.AppendRow(table.Row{strings.TrimPrefix(prefix, "."), rv.Interface()})
case reflect.String:
val := rv.String()
if len(val) <= 160 {
w.AppendRow(table.Row{strings.TrimPrefix(prefix, "."), val})
return
}
w.AppendRow(table.Row{strings.TrimPrefix(prefix, "."), val[0:64] + "..." + val[len(val)-64:]})
case reflect.Array, reflect.Slice:
for i := 0; i < rv.Len(); i++ {
p := strings.Join([]string{prefix, fmt.Sprintf("[%d]", i)}, ".")
structPrinter(w, p, rv.Index(i).Interface())
}
case reflect.Map:
for _, k := range rv.MapKeys() {
structPrinter(w, fmt.Sprintf("%s.{%v}", prefix, k), rv.MapIndex(k).Interface())
}
case reflect.Pointer:
goto Start
case reflect.Struct:
for i := 0; i < rv.NumField(); i++ {
p := fmt.Sprintf("%s.%s", prefix, rv.Type().Field(i).Name)
field := rv.Field(i)
//log.Debug("TablePrinter: prefix: %s, field: %v", p, rv.Field(i))
if !field.CanInterface() {
return
}
structPrinter(w, p, field.Interface())
}
}
}
func TableMapPrinter(data []byte) {
m := make(map[string]any)
if err := json.Unmarshal(data, &m); err != nil {
logger.Warn(err.Error())
return
}
t := table.NewWriter()
addRow(t, "", m)
fmt.Println(t.Render())
}
func addRow(w table.Writer, prefix string, m any) {
rv := reflect.ValueOf(m)
switch rv.Type().Kind() {
case reflect.Map:
for _, k := range rv.MapKeys() {
key := k.String()
if prefix != "" {
key = strings.Join([]string{prefix, k.String()}, ".")
}
addRow(w, key, rv.MapIndex(k).Interface())
}
case reflect.Slice, reflect.Array:
for i := 0; i < rv.Len(); i++ {
addRow(w, fmt.Sprintf("%s[%d]", prefix, i), rv.Index(i).Interface())
}
default:
w.AppendRow(table.Row{prefix, m})
}
}

73
pkg/tool/tools.go Normal file
View File

@ -0,0 +1,73 @@
package tool
import (
"fmt"
"math"
)
func Min[T ~int | ~uint | ~int8 | ~uint8 | ~int16 | ~uint16 | ~int32 | ~uint32 | ~int64 | ~uint64 | ~float32 | ~float64](a, b T) T {
if a <= b {
return a
}
return b
}
func Mins[T ~int | ~uint | ~int8 | ~uint8 | ~int16 | ~uint16 | ~int32 | ~uint32 | ~int64 | ~uint64 | ~float32 | ~float64](vals ...T) T {
var val T
if len(vals) == 0 {
return val
}
val = vals[0]
for _, item := range vals[1:] {
if item < val {
val = item
}
}
return val
}
func Max[T ~int | ~uint | ~int8 | ~uint8 | ~int16 | ~uint16 | ~int32 | ~uint32 | ~int64 | ~uint64 | ~float32 | ~float64](a, b T) T {
if a >= b {
return a
}
return b
}
func Maxs[T ~int | ~uint | ~int8 | ~uint8 | ~int16 | ~uint16 | ~int32 | ~uint32 | ~int64 | ~uint64 | ~float32 | ~float64](vals ...T) T {
var val T
if len(vals) == 0 {
return val
}
for _, item := range vals {
if item > val {
val = item
}
}
return val
}
func Sum[T ~int | ~uint | ~int8 | ~uint8 | ~int16 | ~uint16 | ~int32 | ~uint32 | ~int64 | ~uint64 | ~float32 | ~float64](vals ...T) T {
var sum T = 0
for i := range vals {
sum += vals[i]
}
return sum
}
func Percent(val, minVal, maxVal, minPercent, maxPercent float64) string {
return fmt.Sprintf(
"%d%%",
int(math.Round(
((val-minVal)/(maxVal-minVal)*(maxPercent-minPercent)+minPercent)*100,
)),
)
}

70
pkg/tool/tools_test.go Normal file
View File

@ -0,0 +1,70 @@
package tool
import "testing"
func TestPercent(t *testing.T) {
type args struct {
val float64
minVal float64
maxVal float64
minPercent float64
maxPercent float64
}
tests := []struct {
name string
args args
want string
}{
{
name: "case 1",
args: args{
val: 0.5,
minVal: 0,
maxVal: 1,
minPercent: 0,
maxPercent: 1,
},
want: "50%",
},
{
name: "case 2",
args: args{
val: 0.3,
minVal: 0.1,
maxVal: 0.6,
minPercent: 0,
maxPercent: 1,
},
want: "40%",
},
{
name: "case 3",
args: args{
val: 700,
minVal: 700,
maxVal: 766,
minPercent: 0.1,
maxPercent: 0.7,
},
want: "10%",
},
{
name: "case 4",
args: args{
val: 766,
minVal: 700,
maxVal: 766,
minPercent: 0.1,
maxPercent: 0.7,
},
want: "70%",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := Percent(tt.args.val, tt.args.minVal, tt.args.maxVal, tt.args.minPercent, tt.args.maxPercent); got != tt.want {
t.Errorf("Percent() = %v, want %v", got, tt.want)
}
})
}
}