feat: 0.1.2
1. login page && auth required 2. file dir cleanup
This commit is contained in:
@ -18,9 +18,10 @@ func Start(ctx context.Context) <-chan struct{} {
|
||||
return c.SendStatus(http.StatusOK)
|
||||
})
|
||||
|
||||
app.Get("/api/share/:code", handler.Fetch())
|
||||
app.Put("/api/share/:filename", handler.ShareNew()) // 获取上传 code, 分片大小
|
||||
app.Post("/api/share/:code", handler.ShareUpload()) // 分片上传接口
|
||||
app.Get("/ushare/:code", handler.Fetch())
|
||||
app.Put("/api/ushare/:filename", handler.AuthVerify(), handler.ShareNew()) // 获取上传 code, 分片大小
|
||||
app.Post("/api/ushare/:code", handler.ShareUpload()) // 分片上传接口
|
||||
app.Post("/api/uauth/login", handler.AuthLogin())
|
||||
|
||||
ready := make(chan struct{})
|
||||
ln, err := net.Listen("tcp", opt.Cfg.Address)
|
||||
|
@ -4,10 +4,14 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/loveuer/nf/nft/log"
|
||||
"github.com/loveuer/ushare/internal/model"
|
||||
"github.com/loveuer/ushare/internal/opt"
|
||||
gonanoid "github.com/matoous/go-nanoid/v2"
|
||||
"github.com/spf13/viper"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@ -19,12 +23,12 @@ type metaInfo struct {
|
||||
last time.Time
|
||||
size int64
|
||||
cursor int64
|
||||
ip string
|
||||
user string
|
||||
}
|
||||
|
||||
func (m *metaInfo) generateMeta(code string) error {
|
||||
content := fmt.Sprintf("filename=%s\ncreated_at=%d\nsize=%d\nuploader_ip=%s",
|
||||
m.name, m.create.UnixMilli(), m.size, m.ip,
|
||||
content := fmt.Sprintf("filename=%s\ncreated_at=%d\nsize=%d\nuploader=%s",
|
||||
m.name, m.create.UnixMilli(), m.size, m.user,
|
||||
)
|
||||
|
||||
return os.WriteFile(opt.MetaPath(code), []byte(content), 0644)
|
||||
@ -62,7 +66,7 @@ func (m *meta) New(size int64, filename, ip string) (string, error) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
m.m[code] = &metaInfo{f: f, name: filename, last: now, size: size, cursor: 0, create: now, ip: ip}
|
||||
m.m[code] = &metaInfo{f: f, name: filename, last: now, size: size, cursor: 0, create: now, user: ip}
|
||||
|
||||
return code, nil
|
||||
}
|
||||
@ -100,6 +104,7 @@ func (m *meta) Start(ctx context.Context) {
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
m.ctx = ctx
|
||||
|
||||
// 清理 2 分钟内没有继续上传的 part
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
@ -107,7 +112,7 @@ func (m *meta) Start(ctx context.Context) {
|
||||
return
|
||||
case now := <-ticker.C:
|
||||
for code, info := range m.m {
|
||||
if now.Sub(info.last) > 1*time.Minute {
|
||||
if now.Sub(info.last) > 2*time.Minute {
|
||||
m.Lock()
|
||||
if err := info.f.Close(); err != nil {
|
||||
log.Warn("handler.Meta: [timer] close file failed, file = %s, err = %s", opt.FilePath(code), err.Error())
|
||||
@ -123,4 +128,57 @@ func (m *meta) Start(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// 清理一天前的文件
|
||||
go func() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case now := <-ticker.C:
|
||||
_ = filepath.Walk(opt.Cfg.DataPath, func(path string, info os.FileInfo, err error) error {
|
||||
if info.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
name := filepath.Base(info.Name())
|
||||
if !strings.HasPrefix(name, ".meta.") {
|
||||
return nil
|
||||
}
|
||||
|
||||
viper.SetConfigFile(path)
|
||||
viper.SetConfigType("env")
|
||||
if err = viper.ReadInConfig(); err != nil {
|
||||
// todo log
|
||||
return nil
|
||||
}
|
||||
|
||||
mi := new(model.Meta)
|
||||
|
||||
if err = viper.Unmarshal(mi); err != nil {
|
||||
// todo log
|
||||
return nil
|
||||
}
|
||||
|
||||
code := strings.TrimPrefix(name, ".meta.")
|
||||
|
||||
if now.Sub(time.UnixMilli(mi.CreatedAt)) > 24*time.Hour {
|
||||
|
||||
log.Debug("controller.meta: file out of date, code = %s, user_key = %s", code, mi.Uploader)
|
||||
|
||||
os.RemoveAll(opt.FilePath(code))
|
||||
os.RemoveAll(path)
|
||||
|
||||
m.Lock()
|
||||
delete(m.m, code)
|
||||
m.Unlock()
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
86
internal/controller/user.go
Normal file
86
internal/controller/user.go
Normal file
@ -0,0 +1,86 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/loveuer/ushare/internal/model"
|
||||
"github.com/loveuer/ushare/internal/opt"
|
||||
"github.com/loveuer/ushare/internal/pkg/tool"
|
||||
"github.com/pkg/errors"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type userManager struct {
|
||||
sync.Mutex
|
||||
ctx context.Context
|
||||
um map[string]*model.User
|
||||
}
|
||||
|
||||
func (um *userManager) Login(username string, password string) (*model.User, error) {
|
||||
var (
|
||||
now = time.Now()
|
||||
)
|
||||
|
||||
if username != "admin" {
|
||||
return nil, errors.New("账号或密码错误")
|
||||
}
|
||||
|
||||
if !tool.ComparePassword(password, opt.Cfg.Auth) {
|
||||
return nil, errors.New("账号或密码错误")
|
||||
}
|
||||
|
||||
op := &model.User{
|
||||
Id: 1,
|
||||
Username: username,
|
||||
LoginAt: now.Unix(),
|
||||
Token: tool.RandomString(32),
|
||||
}
|
||||
|
||||
um.Lock()
|
||||
defer um.Unlock()
|
||||
um.um[op.Token] = op
|
||||
|
||||
return op, nil
|
||||
}
|
||||
|
||||
func (um *userManager) Verify(token string) (*model.User, error) {
|
||||
um.Lock()
|
||||
defer um.Unlock()
|
||||
|
||||
op, ok := um.um[token]
|
||||
if !ok {
|
||||
return nil, errors.New("未登录或凭证已失效, 请重新登录")
|
||||
}
|
||||
|
||||
return op, nil
|
||||
}
|
||||
|
||||
func (um *userManager) Start(ctx context.Context) {
|
||||
um.ctx = ctx
|
||||
|
||||
go func() {
|
||||
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-um.ctx.Done():
|
||||
return
|
||||
case now := <-ticker.C:
|
||||
um.Lock()
|
||||
for _, op := range um.um {
|
||||
if now.Sub(time.UnixMilli(op.LoginAt)) > 8*time.Hour {
|
||||
delete(um.um, op.Token)
|
||||
}
|
||||
}
|
||||
um.Unlock()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
var (
|
||||
UserManager = &userManager{
|
||||
um: make(map[string]*model.User),
|
||||
}
|
||||
)
|
70
internal/handler/auth.go
Normal file
70
internal/handler/auth.go
Normal file
@ -0,0 +1,70 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/loveuer/nf"
|
||||
"github.com/loveuer/ushare/internal/controller"
|
||||
"github.com/loveuer/ushare/internal/model"
|
||||
"github.com/loveuer/ushare/internal/opt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func AuthVerify() nf.HandlerFunc {
|
||||
tokenFn := func(c *nf.Ctx) (token string) {
|
||||
if token = c.Get("Authorization"); token != "" {
|
||||
return
|
||||
}
|
||||
|
||||
token = c.Cookies("ushare")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
return func(c *nf.Ctx) error {
|
||||
if opt.Cfg.Auth == "" {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
token := tokenFn(c)
|
||||
if token == "" {
|
||||
return c.Status(http.StatusUnauthorized).JSON(map[string]string{"error": "unauthorized"})
|
||||
}
|
||||
|
||||
op, err := controller.UserManager.Verify(token)
|
||||
if err != nil {
|
||||
return c.Status(http.StatusUnauthorized).JSON(map[string]string{"error": "unauthorized", "msg": err.Error()})
|
||||
}
|
||||
|
||||
c.Locals("user", op)
|
||||
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func AuthLogin() nf.HandlerFunc {
|
||||
return func(c *nf.Ctx) error {
|
||||
type Req struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
var (
|
||||
err error
|
||||
req Req
|
||||
op *model.User
|
||||
)
|
||||
|
||||
if err = c.BodyParser(&req); err != nil {
|
||||
return c.Status(http.StatusBadRequest).JSON(map[string]string{"msg": "错误的用户名或密码<1>"})
|
||||
}
|
||||
|
||||
if op, err = controller.UserManager.Login(req.Username, req.Password); err != nil {
|
||||
return c.Status(http.StatusBadRequest).JSON(map[string]string{"msg": err.Error()})
|
||||
}
|
||||
|
||||
header := fmt.Sprintf("ushare=%s; Path=/; Max-Age=%d", op.Token, 8*3600)
|
||||
c.SetHeader("Set-Cookie", header)
|
||||
|
||||
return c.Status(http.StatusOK).JSON(map[string]any{"data": op})
|
||||
}
|
||||
}
|
@ -49,11 +49,6 @@ func Fetch() nf.HandlerFunc {
|
||||
|
||||
func ShareNew() nf.HandlerFunc {
|
||||
return func(c *nf.Ctx) error {
|
||||
|
||||
if opt.Cfg.Auth {
|
||||
return c.SendStatus(http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
filename := strings.TrimSpace(c.Param("filename"))
|
||||
if filename == "" {
|
||||
return c.Status(http.StatusBadRequest).JSON(map[string]string{"msg": "filename required"})
|
||||
|
@ -1,8 +1,8 @@
|
||||
package model
|
||||
|
||||
type Meta struct {
|
||||
Filename string `json:"filename" mapstructure:"filename"`
|
||||
CreatedAt int64 `json:"created_at" mapstructure:"created_at"`
|
||||
Size int64 `json:"size" mapstructure:"size"`
|
||||
UploaderIp string `json:"uploader_ip" mapstructure:"uploader_ip"`
|
||||
Filename string `json:"filename" mapstructure:"filename"`
|
||||
CreatedAt int64 `json:"created_at" mapstructure:"created_at"`
|
||||
Size int64 `json:"size" mapstructure:"size"`
|
||||
Uploader string `json:"uploader" mapstructure:"uploader"`
|
||||
}
|
||||
|
10
internal/model/user.go
Normal file
10
internal/model/user.go
Normal file
@ -0,0 +1,10 @@
|
||||
package model
|
||||
|
||||
type User struct {
|
||||
Id int `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Key string `json:"key"`
|
||||
Password string `json:"-"`
|
||||
LoginAt int64 `json:"login_at"`
|
||||
Token string `json:"token"`
|
||||
}
|
@ -1,12 +1,25 @@
|
||||
package opt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/loveuer/nf/nft/log"
|
||||
"github.com/loveuer/ushare/internal/pkg/tool"
|
||||
)
|
||||
|
||||
type config struct {
|
||||
Debug bool
|
||||
Address string
|
||||
DataPath string
|
||||
Auth bool
|
||||
Auth string
|
||||
}
|
||||
|
||||
var (
|
||||
Cfg = &config{}
|
||||
)
|
||||
|
||||
func Init(_ context.Context) {
|
||||
if Cfg.Auth != "" {
|
||||
Cfg.Auth = tool.NewPassword(Cfg.Auth)
|
||||
log.Debug("opt.Init: encrypted password = %s", Cfg.Auth)
|
||||
}
|
||||
}
|
||||
|
51
internal/pkg/db/client.go
Normal file
51
internal/pkg/db/client.go
Normal file
@ -0,0 +1,51 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"au99999/internal/opt"
|
||||
"au99999/pkg/tool"
|
||||
"context"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var Default *Client
|
||||
|
||||
type DBType string
|
||||
|
||||
const (
|
||||
DBTypeSqlite = "sqlite"
|
||||
DBTypeMysql = "mysql"
|
||||
DBTypePostgres = "postgres"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
ctx context.Context
|
||||
cli *gorm.DB
|
||||
dbType DBType
|
||||
}
|
||||
|
||||
func (c *Client) Type() DBType {
|
||||
return c.dbType
|
||||
}
|
||||
|
||||
func (c *Client) Session(ctxs ...context.Context) *gorm.DB {
|
||||
var ctx context.Context
|
||||
if len(ctxs) > 0 && ctxs[0] != nil {
|
||||
ctx = ctxs[0]
|
||||
} else {
|
||||
ctx = tool.Timeout(30)
|
||||
}
|
||||
|
||||
session := c.cli.Session(&gorm.Session{Context: ctx})
|
||||
|
||||
if opt.Cfg.Debug {
|
||||
session = session.Debug()
|
||||
}
|
||||
|
||||
return session
|
||||
}
|
||||
|
||||
func (c *Client) Close() {
|
||||
d, _ := c.cli.DB()
|
||||
d.Close()
|
||||
}
|
65
internal/pkg/db/new.go
Normal file
65
internal/pkg/db/new.go
Normal file
@ -0,0 +1,65 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const tip = `example:
|
||||
for sqlite -> sqlite::<filepath>
|
||||
sqlite::data.sqlite
|
||||
sqlite::/data/data.db
|
||||
for mysql -> mysql::<gorm_dsn>
|
||||
mysql::user:pass@tcp(127.0.0.1:3306)/dbname?charset=utf8mb4&parseTime=True&loc=Local
|
||||
for postgres -> postgres::<gorm_dsn>
|
||||
postgres::host=localhost user=gorm password=gorm dbname=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai
|
||||
`
|
||||
|
||||
func New(ctx context.Context, uri string) (*Client, error) {
|
||||
parts := strings.SplitN(uri, "::", 2)
|
||||
|
||||
if len(parts) != 2 {
|
||||
return nil, fmt.Errorf("db.Init: db uri invalid\n%s", tip)
|
||||
}
|
||||
|
||||
c := &Client{}
|
||||
|
||||
var (
|
||||
err error
|
||||
dsn = parts[1]
|
||||
)
|
||||
|
||||
switch parts[0] {
|
||||
case "sqlite":
|
||||
c.dbType = DBTypeSqlite
|
||||
c.cli, err = gorm.Open(sqlite.Open(dsn))
|
||||
case "mysql":
|
||||
c.dbType = DBTypeMysql
|
||||
c.cli, err = gorm.Open(mysql.Open(dsn))
|
||||
case "postgres":
|
||||
c.dbType = DBTypePostgres
|
||||
c.cli, err = gorm.Open(postgres.Open(dsn))
|
||||
default:
|
||||
return nil, fmt.Errorf("db type only support: [sqlite, mysql, postgres], unsupported db type: %s", parts[0])
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db.Init: open %s with dsn:%s, err: %w", parts[0], dsn, err)
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func Init(ctx context.Context, uri string) (err error) {
|
||||
if Default, err = New(ctx, uri); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
38
internal/pkg/tool/ctx.go
Normal file
38
internal/pkg/tool/ctx.go
Normal file
@ -0,0 +1,38 @@
|
||||
package tool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Timeout(seconds ...int) (ctx context.Context) {
|
||||
var (
|
||||
duration time.Duration
|
||||
)
|
||||
|
||||
if len(seconds) > 0 && seconds[0] > 0 {
|
||||
duration = time.Duration(seconds[0]) * time.Second
|
||||
} else {
|
||||
duration = time.Duration(30) * time.Second
|
||||
}
|
||||
|
||||
ctx, _ = context.WithTimeout(context.Background(), duration)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func TimeoutCtx(ctx context.Context, seconds ...int) context.Context {
|
||||
var (
|
||||
duration time.Duration
|
||||
)
|
||||
|
||||
if len(seconds) > 0 && seconds[0] > 0 {
|
||||
duration = time.Duration(seconds[0]) * time.Second
|
||||
} else {
|
||||
duration = time.Duration(30) * time.Second
|
||||
}
|
||||
|
||||
nctx, _ := context.WithTimeout(ctx, duration)
|
||||
|
||||
return nctx
|
||||
}
|
50
internal/pkg/tool/human.go
Normal file
50
internal/pkg/tool/human.go
Normal file
@ -0,0 +1,50 @@
|
||||
package tool
|
||||
|
||||
import "fmt"
|
||||
|
||||
func HumanDuration(nano int64) string {
|
||||
duration := float64(nano)
|
||||
unit := "ns"
|
||||
if duration >= 1000 {
|
||||
duration /= 1000
|
||||
unit = "us"
|
||||
}
|
||||
|
||||
if duration >= 1000 {
|
||||
duration /= 1000
|
||||
unit = "ms"
|
||||
}
|
||||
|
||||
if duration >= 1000 {
|
||||
duration /= 1000
|
||||
unit = " s"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%6.2f%s", duration, unit)
|
||||
}
|
||||
|
||||
func HumanSize(size int64) string {
|
||||
const (
|
||||
_ = iota
|
||||
KB = 1 << (10 * iota) // 1 KB = 1024 bytes
|
||||
MB // 1 MB = 1024 KB
|
||||
GB // 1 GB = 1024 MB
|
||||
TB // 1 TB = 1024 GB
|
||||
PB // 1 PB = 1024 TB
|
||||
)
|
||||
|
||||
switch {
|
||||
case size >= PB:
|
||||
return fmt.Sprintf("%.2f PB", float64(size)/PB)
|
||||
case size >= TB:
|
||||
return fmt.Sprintf("%.2f TB", float64(size)/TB)
|
||||
case size >= GB:
|
||||
return fmt.Sprintf("%.2f GB", float64(size)/GB)
|
||||
case size >= MB:
|
||||
return fmt.Sprintf("%.2f MB", float64(size)/MB)
|
||||
case size >= KB:
|
||||
return fmt.Sprintf("%.2f KB", float64(size)/KB)
|
||||
default:
|
||||
return fmt.Sprintf("%d bytes", size)
|
||||
}
|
||||
}
|
68
internal/pkg/tool/loadash.go
Normal file
68
internal/pkg/tool/loadash.go
Normal file
@ -0,0 +1,68 @@
|
||||
package tool
|
||||
|
||||
import "math"
|
||||
|
||||
func Map[T, R any](vals []T, fn func(item T, index int) R) []R {
|
||||
var result = make([]R, len(vals))
|
||||
for idx, v := range vals {
|
||||
result[idx] = fn(v, idx)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func Chunk[T any](vals []T, size int) [][]T {
|
||||
if size <= 0 {
|
||||
panic("Second parameter must be greater than 0")
|
||||
}
|
||||
|
||||
chunksNum := len(vals) / size
|
||||
if len(vals)%size != 0 {
|
||||
chunksNum += 1
|
||||
}
|
||||
|
||||
result := make([][]T, 0, chunksNum)
|
||||
|
||||
for i := 0; i < chunksNum; i++ {
|
||||
last := (i + 1) * size
|
||||
if last > len(vals) {
|
||||
last = len(vals)
|
||||
}
|
||||
result = append(result, vals[i*size:last:last])
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// 对 vals 取样 x 个
|
||||
func Sample[T any](vals []T, x int) []T {
|
||||
if x < 0 {
|
||||
panic("Second parameter can't be negative")
|
||||
}
|
||||
|
||||
n := len(vals)
|
||||
if n == 0 {
|
||||
return []T{}
|
||||
}
|
||||
|
||||
if x >= n {
|
||||
return vals
|
||||
}
|
||||
|
||||
// 处理x=1的特殊情况
|
||||
if x == 1 {
|
||||
return []T{vals[(n-1)/2]}
|
||||
}
|
||||
|
||||
// 计算采样步长并生成结果数组
|
||||
step := float64(n-1) / float64(x-1)
|
||||
result := make([]T, x)
|
||||
|
||||
for i := 0; i < x; i++ {
|
||||
// 计算采样位置并四舍五入
|
||||
pos := float64(i) * step
|
||||
index := int(math.Round(pos))
|
||||
result[i] = vals[index]
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
11
internal/pkg/tool/must.go
Normal file
11
internal/pkg/tool/must.go
Normal file
@ -0,0 +1,11 @@
|
||||
package tool
|
||||
|
||||
import "github.com/loveuer/nf/nft/log"
|
||||
|
||||
func Must(errs ...error) {
|
||||
for _, err := range errs {
|
||||
if err != nil {
|
||||
log.Panic(err.Error())
|
||||
}
|
||||
}
|
||||
}
|
84
internal/pkg/tool/password.go
Normal file
84
internal/pkg/tool/password.go
Normal file
@ -0,0 +1,84 @@
|
||||
package tool
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/loveuer/nf/nft/log"
|
||||
"golang.org/x/crypto/pbkdf2"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
EncryptHeader string = "pbkdf2:sha256" // 用户密码加密
|
||||
)
|
||||
|
||||
func NewPassword(password string) string {
|
||||
return EncryptPassword(password, RandomString(8), int(RandomInt(50000)+100000))
|
||||
}
|
||||
|
||||
func ComparePassword(in, db string) bool {
|
||||
strs := strings.Split(db, "$")
|
||||
if len(strs) != 3 {
|
||||
log.Error("password in db invalid: %s", db)
|
||||
return false
|
||||
}
|
||||
|
||||
encs := strings.Split(strs[0], ":")
|
||||
if len(encs) != 3 {
|
||||
log.Error("password in db invalid: %s", db)
|
||||
return false
|
||||
}
|
||||
|
||||
encIteration, err := strconv.Atoi(encs[2])
|
||||
if err != nil {
|
||||
log.Error("password in db invalid: %s, convert iter err: %s", db, err)
|
||||
return false
|
||||
}
|
||||
|
||||
return EncryptPassword(in, strs[1], encIteration) == db
|
||||
}
|
||||
|
||||
func EncryptPassword(password, salt string, iter int) string {
|
||||
hash := pbkdf2.Key([]byte(password), []byte(salt), iter, 32, sha256.New)
|
||||
encrypted := hex.EncodeToString(hash)
|
||||
return fmt.Sprintf("%s:%d$%s$%s", EncryptHeader, iter, salt, encrypted)
|
||||
}
|
||||
|
||||
func CheckPassword(password string) error {
|
||||
if len(password) < 8 || len(password) > 32 {
|
||||
return errors.New("密码长度不符合")
|
||||
}
|
||||
|
||||
var (
|
||||
err error
|
||||
match bool
|
||||
patternList = []string{`[0-9]+`, `[a-z]+`, `[A-Z]+`, `[!@#%]+`} //, `[~!@#$%^&*?_-]+`}
|
||||
matchAccount = 0
|
||||
tips = []string{"缺少数字", "缺少小写字母", "缺少大写字母", "缺少'!@#%'"}
|
||||
locktips = make([]string, 0)
|
||||
)
|
||||
|
||||
for idx, pattern := range patternList {
|
||||
match, err = regexp.MatchString(pattern, password)
|
||||
if err != nil {
|
||||
log.Warn("regex match string err, reg_str: %s, err: %v", pattern, err)
|
||||
return errors.New("密码强度不够")
|
||||
}
|
||||
|
||||
if match {
|
||||
matchAccount++
|
||||
} else {
|
||||
locktips = append(locktips, tips[idx])
|
||||
}
|
||||
}
|
||||
|
||||
if matchAccount < 3 {
|
||||
return fmt.Errorf("密码强度不够, 可能 %s", strings.Join(locktips, ", "))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
20
internal/pkg/tool/password_test.go
Normal file
20
internal/pkg/tool/password_test.go
Normal file
@ -0,0 +1,20 @@
|
||||
package tool
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestEncPassword(t *testing.T) {
|
||||
password := "123456"
|
||||
|
||||
result := EncryptPassword(password, RandomString(8), 50000)
|
||||
|
||||
t.Logf("sum => %s", result)
|
||||
}
|
||||
|
||||
func TestPassword(t *testing.T) {
|
||||
p := "wahaha@123"
|
||||
p = NewPassword(p)
|
||||
t.Logf("password => %s", p)
|
||||
|
||||
result := ComparePassword("wahaha@123", p)
|
||||
t.Logf("compare result => %v", result)
|
||||
}
|
54
internal/pkg/tool/random.go
Normal file
54
internal/pkg/tool/random.go
Normal file
@ -0,0 +1,54 @@
|
||||
package tool
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"math/big"
|
||||
)
|
||||
|
||||
var (
|
||||
letters = []byte("0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
|
||||
letterNum = []byte("0123456789")
|
||||
letterLow = []byte("abcdefghijklmnopqrstuvwxyz")
|
||||
letterCap = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
|
||||
letterSyb = []byte("!@#$%^&*()_+-=")
|
||||
)
|
||||
|
||||
func RandomInt(max int64) int64 {
|
||||
num, _ := rand.Int(rand.Reader, big.NewInt(max))
|
||||
return num.Int64()
|
||||
}
|
||||
|
||||
func RandomString(length int) string {
|
||||
result := make([]byte, length)
|
||||
for i := 0; i < length; i++ {
|
||||
num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
|
||||
result[i] = letters[num.Int64()]
|
||||
}
|
||||
return string(result)
|
||||
}
|
||||
|
||||
func RandomPassword(length int, withSymbol bool) string {
|
||||
result := make([]byte, length)
|
||||
kind := 3
|
||||
if withSymbol {
|
||||
kind++
|
||||
}
|
||||
|
||||
for i := 0; i < length; i++ {
|
||||
switch i % kind {
|
||||
case 0:
|
||||
num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letterNum))))
|
||||
result[i] = letterNum[num.Int64()]
|
||||
case 1:
|
||||
num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letterLow))))
|
||||
result[i] = letterLow[num.Int64()]
|
||||
case 2:
|
||||
num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letterCap))))
|
||||
result[i] = letterCap[num.Int64()]
|
||||
case 3:
|
||||
num, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letterSyb))))
|
||||
result[i] = letterSyb[num.Int64()]
|
||||
}
|
||||
}
|
||||
return string(result)
|
||||
}
|
124
internal/pkg/tool/table.go
Normal file
124
internal/pkg/tool/table.go
Normal file
@ -0,0 +1,124 @@
|
||||
package tool
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/jedib0t/go-pretty/v6/table"
|
||||
"github.com/loveuer/nf/nft/log"
|
||||
"io"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func TablePrinter(data any, writers ...io.Writer) {
|
||||
var w io.Writer = os.Stdout
|
||||
if len(writers) > 0 && writers[0] != nil {
|
||||
w = writers[0]
|
||||
}
|
||||
|
||||
t := table.NewWriter()
|
||||
structPrinter(t, "", data)
|
||||
_, _ = fmt.Fprintln(w, t.Render())
|
||||
}
|
||||
|
||||
func structPrinter(w table.Writer, prefix string, item any) {
|
||||
Start:
|
||||
rv := reflect.ValueOf(item)
|
||||
if rv.IsZero() {
|
||||
return
|
||||
}
|
||||
|
||||
for rv.Type().Kind() == reflect.Pointer {
|
||||
rv = rv.Elem()
|
||||
}
|
||||
|
||||
switch rv.Type().Kind() {
|
||||
case reflect.Invalid,
|
||||
reflect.Uintptr,
|
||||
reflect.Chan,
|
||||
reflect.Func,
|
||||
reflect.UnsafePointer:
|
||||
case reflect.Bool,
|
||||
reflect.Int,
|
||||
reflect.Int8,
|
||||
reflect.Int16,
|
||||
reflect.Int32,
|
||||
reflect.Int64,
|
||||
reflect.Uint,
|
||||
reflect.Uint8,
|
||||
reflect.Uint16,
|
||||
reflect.Uint32,
|
||||
reflect.Uint64,
|
||||
reflect.Float32,
|
||||
reflect.Float64,
|
||||
reflect.Complex64,
|
||||
reflect.Complex128,
|
||||
reflect.Interface:
|
||||
w.AppendRow(table.Row{strings.TrimPrefix(prefix, "."), rv.Interface()})
|
||||
case reflect.String:
|
||||
val := rv.String()
|
||||
if len(val) <= 160 {
|
||||
w.AppendRow(table.Row{strings.TrimPrefix(prefix, "."), val})
|
||||
return
|
||||
}
|
||||
|
||||
w.AppendRow(table.Row{strings.TrimPrefix(prefix, "."), val[0:64] + "..." + val[len(val)-64:]})
|
||||
case reflect.Array, reflect.Slice:
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
p := strings.Join([]string{prefix, fmt.Sprintf("[%d]", i)}, ".")
|
||||
structPrinter(w, p, rv.Index(i).Interface())
|
||||
}
|
||||
case reflect.Map:
|
||||
for _, k := range rv.MapKeys() {
|
||||
structPrinter(w, fmt.Sprintf("%s.{%v}", prefix, k), rv.MapIndex(k).Interface())
|
||||
}
|
||||
case reflect.Pointer:
|
||||
goto Start
|
||||
case reflect.Struct:
|
||||
for i := 0; i < rv.NumField(); i++ {
|
||||
p := fmt.Sprintf("%s.%s", prefix, rv.Type().Field(i).Name)
|
||||
field := rv.Field(i)
|
||||
|
||||
//log.Debug("TablePrinter: prefix: %s, field: %v", p, rv.Field(i))
|
||||
|
||||
if !field.CanInterface() {
|
||||
return
|
||||
}
|
||||
|
||||
structPrinter(w, p, field.Interface())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TableMapPrinter(data []byte) {
|
||||
m := make(map[string]any)
|
||||
if err := json.Unmarshal(data, &m); err != nil {
|
||||
log.Warn(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
t := table.NewWriter()
|
||||
addRow(t, "", m)
|
||||
fmt.Println(t.Render())
|
||||
}
|
||||
|
||||
func addRow(w table.Writer, prefix string, m any) {
|
||||
rv := reflect.ValueOf(m)
|
||||
switch rv.Type().Kind() {
|
||||
case reflect.Map:
|
||||
for _, k := range rv.MapKeys() {
|
||||
key := k.String()
|
||||
if prefix != "" {
|
||||
key = strings.Join([]string{prefix, k.String()}, ".")
|
||||
}
|
||||
addRow(w, key, rv.MapIndex(k).Interface())
|
||||
}
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
addRow(w, fmt.Sprintf("%s[%d]", prefix, i), rv.Index(i).Interface())
|
||||
}
|
||||
default:
|
||||
w.AppendRow(table.Row{prefix, m})
|
||||
}
|
||||
}
|
73
internal/pkg/tool/tools.go
Normal file
73
internal/pkg/tool/tools.go
Normal file
@ -0,0 +1,73 @@
|
||||
package tool
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
)
|
||||
|
||||
func Min[T ~int | ~uint | ~int8 | ~uint8 | ~int16 | ~uint16 | ~int32 | ~uint32 | ~int64 | ~uint64 | ~float32 | ~float64](a, b T) T {
|
||||
if a <= b {
|
||||
return a
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
func Mins[T ~int | ~uint | ~int8 | ~uint8 | ~int16 | ~uint16 | ~int32 | ~uint32 | ~int64 | ~uint64 | ~float32 | ~float64](vals ...T) T {
|
||||
var val T
|
||||
|
||||
if len(vals) == 0 {
|
||||
return val
|
||||
}
|
||||
|
||||
val = vals[0]
|
||||
|
||||
for _, item := range vals[1:] {
|
||||
if item < val {
|
||||
val = item
|
||||
}
|
||||
}
|
||||
|
||||
return val
|
||||
}
|
||||
|
||||
func Max[T ~int | ~uint | ~int8 | ~uint8 | ~int16 | ~uint16 | ~int32 | ~uint32 | ~int64 | ~uint64 | ~float32 | ~float64](a, b T) T {
|
||||
if a >= b {
|
||||
return a
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
func Maxs[T ~int | ~uint | ~int8 | ~uint8 | ~int16 | ~uint16 | ~int32 | ~uint32 | ~int64 | ~uint64 | ~float32 | ~float64](vals ...T) T {
|
||||
var val T
|
||||
|
||||
if len(vals) == 0 {
|
||||
return val
|
||||
}
|
||||
|
||||
for _, item := range vals {
|
||||
if item > val {
|
||||
val = item
|
||||
}
|
||||
}
|
||||
|
||||
return val
|
||||
}
|
||||
|
||||
func Sum[T ~int | ~uint | ~int8 | ~uint8 | ~int16 | ~uint16 | ~int32 | ~uint32 | ~int64 | ~uint64 | ~float32 | ~float64](vals ...T) T {
|
||||
var sum T = 0
|
||||
for i := range vals {
|
||||
sum += vals[i]
|
||||
}
|
||||
return sum
|
||||
}
|
||||
|
||||
func Percent(val, minVal, maxVal, minPercent, maxPercent float64) string {
|
||||
return fmt.Sprintf(
|
||||
"%d%%",
|
||||
int(math.Round(
|
||||
((val-minVal)/(maxVal-minVal)*(maxPercent-minPercent)+minPercent)*100,
|
||||
)),
|
||||
)
|
||||
}
|
70
internal/pkg/tool/tools_test.go
Normal file
70
internal/pkg/tool/tools_test.go
Normal file
@ -0,0 +1,70 @@
|
||||
package tool
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestPercent(t *testing.T) {
|
||||
type args struct {
|
||||
val float64
|
||||
minVal float64
|
||||
maxVal float64
|
||||
minPercent float64
|
||||
maxPercent float64
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "case 1",
|
||||
args: args{
|
||||
val: 0.5,
|
||||
minVal: 0,
|
||||
maxVal: 1,
|
||||
minPercent: 0,
|
||||
maxPercent: 1,
|
||||
},
|
||||
want: "50%",
|
||||
},
|
||||
{
|
||||
name: "case 2",
|
||||
args: args{
|
||||
val: 0.3,
|
||||
minVal: 0.1,
|
||||
maxVal: 0.6,
|
||||
minPercent: 0,
|
||||
maxPercent: 1,
|
||||
},
|
||||
want: "40%",
|
||||
},
|
||||
{
|
||||
name: "case 3",
|
||||
args: args{
|
||||
val: 700,
|
||||
minVal: 700,
|
||||
maxVal: 766,
|
||||
minPercent: 0.1,
|
||||
maxPercent: 0.7,
|
||||
},
|
||||
want: "10%",
|
||||
},
|
||||
{
|
||||
name: "case 4",
|
||||
args: args{
|
||||
val: 766,
|
||||
minVal: 700,
|
||||
maxVal: 766,
|
||||
minPercent: 0.1,
|
||||
maxPercent: 0.7,
|
||||
},
|
||||
want: "70%",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := Percent(tt.args.val, tt.args.minVal, tt.args.maxVal, tt.args.minPercent, tt.args.maxPercent); got != tt.want {
|
||||
t.Errorf("Percent() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user