wip: jwt
This commit is contained in:
4
pkg/database/cache/new.go
vendored
4
pkg/database/cache/new.go
vendored
@ -3,9 +3,9 @@ package cache
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"loveuer/utodo/pkg/tool"
|
||||
"net/url"
|
||||
|
||||
"gitea.loveuer.com/yizhisec/packages/tool"
|
||||
"github.com/dgraph-io/ristretto/v2"
|
||||
"github.com/go-redis/redis/v8"
|
||||
)
|
||||
@ -41,7 +41,7 @@ func New(opts ...Option) (Cache, error) {
|
||||
Password: password,
|
||||
})
|
||||
|
||||
if err = client.Ping(tool.CtxTimeout(cfg.ctx, 5)).Err(); err != nil {
|
||||
if err = client.Ping(tool.TimeoutCtx(cfg.ctx, 5)).Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
2
pkg/database/cache/redis.go
vendored
2
pkg/database/cache/redis.go
vendored
@ -3,10 +3,10 @@ package cache
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"loveuer/utodo/pkg/tool"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gitea.loveuer.com/yizhisec/packages/tool"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
139
pkg/jwt/jwt.go
Normal file
139
pkg/jwt/jwt.go
Normal file
@ -0,0 +1,139 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"loveuer/utodo/pkg/logger"
|
||||
"reflect"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
type PayloadFn[T any] func(data T) map[string]any
|
||||
|
||||
type Option[T any] func(*option[T])
|
||||
|
||||
type option[T any] struct {
|
||||
secret string
|
||||
payloadFn PayloadFn[T]
|
||||
}
|
||||
|
||||
func WithSecret[T any](secret string) Option[T] {
|
||||
return func(o *option[T]) {
|
||||
if secret != "" {
|
||||
o.secret = secret
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func WithPayloadFn[T any](payloadFn PayloadFn[T]) Option[T] {
|
||||
return func(o *option[T]) {
|
||||
if payloadFn != nil {
|
||||
o.payloadFn = payloadFn
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type app[T any] struct {
|
||||
secret string
|
||||
payloadFn func(data T) map[string]any
|
||||
}
|
||||
|
||||
func (a *app[T]) Generate(data T) (string, error) {
|
||||
|
||||
type MyClaims struct {
|
||||
jwt.RegisteredClaims
|
||||
User T
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, MyClaims{User: data})
|
||||
|
||||
return token.SignedString([]byte(a.secret))
|
||||
}
|
||||
|
||||
func (a *app[T]) Parse(token string) (T, error) {
|
||||
type MyClaims struct {
|
||||
jwt.RegisteredClaims
|
||||
User T
|
||||
}
|
||||
|
||||
var (
|
||||
err error
|
||||
val T
|
||||
claims *jwt.Token
|
||||
)
|
||||
|
||||
if claims, err = jwt.ParseWithClaims(token, &MyClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte(a.secret), nil
|
||||
}); err != nil {
|
||||
return val, err
|
||||
}
|
||||
|
||||
if !claims.Valid {
|
||||
return val, ErrTokenInvalid
|
||||
}
|
||||
|
||||
cv, ok := claims.Claims.(*MyClaims)
|
||||
if !ok {
|
||||
return val, ErrTokenInvalid
|
||||
}
|
||||
|
||||
return cv.User, nil
|
||||
}
|
||||
|
||||
var (
|
||||
ErrTokenInvalid = errors.New("token is invalid")
|
||||
defaultPayloadFn PayloadFn[any] = func(data any) map[string]any {
|
||||
ref := reflect.ValueOf(data)
|
||||
|
||||
if ref.Kind() == reflect.Pointer {
|
||||
ref = ref.Elem()
|
||||
}
|
||||
|
||||
if ref.Kind() != reflect.Struct {
|
||||
logger.WarnCtx(context.TODO(), "data is not a struct")
|
||||
return nil
|
||||
}
|
||||
|
||||
typ := ref.Type()
|
||||
num := typ.NumField()
|
||||
|
||||
payload := make(map[string]any, num)
|
||||
|
||||
for i := 0; i < num; i++ {
|
||||
field := typ.Field(i)
|
||||
tag := field.Tag.Get("json")
|
||||
if tag == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
if tag == "" {
|
||||
tag = field.Name
|
||||
}
|
||||
|
||||
payload[tag] = ref.Field(i).Interface()
|
||||
}
|
||||
|
||||
return payload
|
||||
}
|
||||
)
|
||||
|
||||
const defaultSecret = "{<Z]h/J>[5-F?s#D;~HGpxuBWi=ezNmb"
|
||||
|
||||
func New[T any](opts ...Option[T]) *app[T] {
|
||||
opt := &option[T]{
|
||||
secret: defaultSecret,
|
||||
payloadFn: func(data T) map[string]any {
|
||||
return defaultPayloadFn(data)
|
||||
},
|
||||
}
|
||||
|
||||
for _, o := range opts {
|
||||
o(opt)
|
||||
}
|
||||
|
||||
return &app[T]{
|
||||
secret: opt.secret,
|
||||
payloadFn: opt.payloadFn,
|
||||
}
|
||||
}
|
@ -2,6 +2,7 @@ package resp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/gofiber/fiber/v3"
|
||||
)
|
||||
|
||||
@ -42,20 +43,25 @@ func RE(c fiber.Ctx, err error) error {
|
||||
|
||||
func _r(c fiber.Ctx, r *res, args ...any) error {
|
||||
length := len(args)
|
||||
switch length {
|
||||
case 0:
|
||||
break
|
||||
case 1:
|
||||
|
||||
if length == 0 {
|
||||
goto END
|
||||
}
|
||||
|
||||
if length >= 1 {
|
||||
if msg, ok := args[0].(string); ok {
|
||||
r.Msg = msg
|
||||
} else {
|
||||
r.Data = args[0]
|
||||
}
|
||||
case 2:
|
||||
}
|
||||
|
||||
if length >= 2 {
|
||||
r.Data = args[1]
|
||||
case 3:
|
||||
}
|
||||
|
||||
if length >= 3 {
|
||||
r.Err = args[2]
|
||||
}
|
||||
END:
|
||||
|
||||
if r.Msg == "" {
|
||||
r.Msg = Msg(r.Status)
|
||||
|
9
pkg/sqlType/err.go
Normal file
9
pkg/sqlType/err.go
Normal file
@ -0,0 +1,9 @@
|
||||
package sqlType
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrConvertScanVal = errors.New("convert scan val to str err")
|
||||
ErrInvalidScanVal = errors.New("scan val invalid")
|
||||
ErrConvertVal = errors.New("convert err")
|
||||
)
|
76
pkg/sqlType/jsonb.go
Normal file
76
pkg/sqlType/jsonb.go
Normal file
@ -0,0 +1,76 @@
|
||||
package sqlType
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgtype"
|
||||
)
|
||||
|
||||
type JSONB struct {
|
||||
Val pgtype.JSONB
|
||||
Valid bool
|
||||
}
|
||||
|
||||
func NewJSONB(v interface{}) JSONB {
|
||||
j := new(JSONB)
|
||||
j.Val = pgtype.JSONB{}
|
||||
if err := j.Val.Set(v); err == nil {
|
||||
j.Valid = true
|
||||
return *j
|
||||
}
|
||||
|
||||
return *j
|
||||
}
|
||||
|
||||
func (j *JSONB) Set(value interface{}) error {
|
||||
if err := j.Val.Set(value); err != nil {
|
||||
j.Valid = false
|
||||
return err
|
||||
}
|
||||
|
||||
j.Valid = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (j *JSONB) Bind(model interface{}) error {
|
||||
return j.Val.AssignTo(model)
|
||||
}
|
||||
|
||||
func (j *JSONB) Scan(value interface{}) error {
|
||||
j.Val = pgtype.JSONB{}
|
||||
if value == nil {
|
||||
j.Valid = false
|
||||
return nil
|
||||
}
|
||||
|
||||
j.Valid = true
|
||||
|
||||
return j.Val.Scan(value)
|
||||
}
|
||||
|
||||
func (j JSONB) Value() (driver.Value, error) {
|
||||
if j.Valid {
|
||||
return j.Val.Value()
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (j JSONB) MarshalJSON() ([]byte, error) {
|
||||
if j.Valid {
|
||||
return j.Val.MarshalJSON()
|
||||
}
|
||||
|
||||
return json.Marshal(nil)
|
||||
}
|
||||
|
||||
func (j *JSONB) UnmarshalJSON(b []byte) error {
|
||||
if string(b) == "null" {
|
||||
j.Valid = false
|
||||
return j.Val.UnmarshalJSON(b)
|
||||
}
|
||||
|
||||
return j.Val.UnmarshalJSON(b)
|
||||
}
|
71
pkg/sqlType/slice.num.go
Normal file
71
pkg/sqlType/slice.num.go
Normal file
@ -0,0 +1,71 @@
|
||||
package sqlType
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
type SliceNum[T ~int | ~int64 | ~uint | ~uint64 | ~uint32 | ~int32] []T
|
||||
|
||||
func (n *SliceNum[T]) Scan(val interface{}) error {
|
||||
str, ok := val.(string)
|
||||
if !ok {
|
||||
return ErrConvertScanVal
|
||||
}
|
||||
|
||||
length := len(str)
|
||||
|
||||
if length <= 0 {
|
||||
*n = make(SliceNum[T], 0)
|
||||
return nil
|
||||
}
|
||||
|
||||
if str[0] != '{' || str[length-1] != '}' {
|
||||
return ErrInvalidScanVal
|
||||
}
|
||||
|
||||
str = str[1 : length-1]
|
||||
if len(str) == 0 {
|
||||
*n = make(SliceNum[T], 0)
|
||||
return nil
|
||||
}
|
||||
|
||||
numStrs := strings.Split(str, ",")
|
||||
nums := make([]T, len(numStrs))
|
||||
|
||||
for idx := range numStrs {
|
||||
num, err := cast.ToInt64E(strings.TrimSpace(numStrs[idx]))
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: can't convert to %T", ErrConvertVal, T(0))
|
||||
}
|
||||
|
||||
nums[idx] = T(num)
|
||||
}
|
||||
|
||||
*n = nums
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n SliceNum[T]) Value() (driver.Value, error) {
|
||||
if n == nil {
|
||||
return "{}", nil
|
||||
}
|
||||
|
||||
if len(n) == 0 {
|
||||
return "{}", nil
|
||||
}
|
||||
|
||||
ss := make([]string, 0, len(n))
|
||||
for idx := range n {
|
||||
ss = append(ss, strconv.Itoa(int(n[idx])))
|
||||
}
|
||||
|
||||
s := strings.Join(ss, ", ")
|
||||
|
||||
return fmt.Sprintf("{%s}", s), nil
|
||||
}
|
110
pkg/sqlType/slice.str.go
Normal file
110
pkg/sqlType/slice.str.go
Normal file
@ -0,0 +1,110 @@
|
||||
package sqlType
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"loveuer/utodo/pkg/tool"
|
||||
)
|
||||
|
||||
type SliceStr[T ~string] []T
|
||||
|
||||
func (s *SliceStr[T]) Scan(val interface{}) error {
|
||||
|
||||
str, ok := val.(string)
|
||||
if !ok {
|
||||
return ErrConvertScanVal
|
||||
}
|
||||
|
||||
if len(str) < 2 {
|
||||
return nil
|
||||
}
|
||||
|
||||
bs := make([]byte, 0, 128)
|
||||
bss := make([]byte, 0, 2*len(str))
|
||||
|
||||
quoteCount := 0
|
||||
|
||||
for idx := 1; idx < len(str)-1; idx++ {
|
||||
// 44: , 92: \ 34: "
|
||||
quote := str[idx]
|
||||
switch quote {
|
||||
case 44:
|
||||
if quote == 44 && str[idx-1] != 92 && quoteCount == 0 {
|
||||
if len(bs) > 0 {
|
||||
if !(bs[0] == 34 && bs[len(bs)-1] == 34) {
|
||||
bs = append([]byte{34}, bs...)
|
||||
bs = append(bs, 34)
|
||||
}
|
||||
|
||||
bss = append(bss, bs...)
|
||||
bss = append(bss, 44)
|
||||
}
|
||||
bs = bs[:0]
|
||||
} else {
|
||||
bs = append(bs, quote)
|
||||
}
|
||||
case 34:
|
||||
if str[idx-1] != 92 {
|
||||
quoteCount = (quoteCount + 1) % 2
|
||||
}
|
||||
bs = append(bs, quote)
|
||||
default:
|
||||
bs = append(bs, quote)
|
||||
}
|
||||
|
||||
//bs = append(bs, str[idx])
|
||||
}
|
||||
|
||||
if len(bs) > 0 {
|
||||
if !(bs[0] == 34 && bs[len(bs)-1] == 34) {
|
||||
bs = append([]byte{34}, bs...)
|
||||
bs = append(bs, 34)
|
||||
}
|
||||
|
||||
bss = append(bss, bs...)
|
||||
} else {
|
||||
if len(bss) > 2 {
|
||||
bss = bss[:len(bss)-2]
|
||||
}
|
||||
}
|
||||
|
||||
bss = append([]byte{'['}, append(bss, ']')...)
|
||||
|
||||
if err := json.Unmarshal(bss, s); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s SliceStr[T]) Value() (driver.Value, error) {
|
||||
if s == nil {
|
||||
return "{}", nil
|
||||
}
|
||||
|
||||
if len(s) == 0 {
|
||||
return "{}", nil
|
||||
}
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
encoder := json.NewEncoder(buf)
|
||||
encoder.SetEscapeHTML(false)
|
||||
|
||||
if err := encoder.Encode(s); err != nil {
|
||||
return "{}", err
|
||||
}
|
||||
|
||||
bs := buf.Bytes()
|
||||
|
||||
bs[0] = '{'
|
||||
|
||||
if bs[len(bs)-1] == 10 {
|
||||
bs = bs[:len(bs)-1]
|
||||
}
|
||||
|
||||
bs[len(bs)-1] = '}'
|
||||
|
||||
return tool.BytesToString(bs), nil
|
||||
}
|
@ -2,8 +2,6 @@ package tool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"gitea.loveuer.com/yizhisec/packages/opt"
|
||||
"time"
|
||||
)
|
||||
|
||||
@ -23,7 +21,7 @@ func Timeout(seconds ...int) (ctx context.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
func CtxTimeout(ctx context.Context, seconds ...int) context.Context {
|
||||
func TimeoutCtx(ctx context.Context, seconds ...int) context.Context {
|
||||
var (
|
||||
duration time.Duration
|
||||
)
|
||||
@ -38,7 +36,3 @@ func CtxTimeout(ctx context.Context, seconds ...int) context.Context {
|
||||
|
||||
return nctx
|
||||
}
|
||||
|
||||
func CtxTrace(ctx context.Context, key string) context.Context {
|
||||
return context.WithValue(ctx, opt.TraceKey, fmt.Sprintf("%36s", key))
|
||||
}
|
||||
|
@ -1,12 +0,0 @@
|
||||
package tool
|
||||
|
||||
import "github.com/gin-gonic/gin"
|
||||
|
||||
func Local(c *gin.Context, key string) any {
|
||||
data, ok := c.Get(key)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return data
|
||||
}
|
@ -2,7 +2,7 @@ package tool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"gitea.loveuer.com/yizhisec/packages/logger"
|
||||
"loveuer/utodo/pkg/logger"
|
||||
"sync"
|
||||
)
|
||||
|
||||
|
@ -5,11 +5,12 @@ import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"gitea.loveuer.com/yizhisec/packages/logger"
|
||||
"golang.org/x/crypto/pbkdf2"
|
||||
"loveuer/utodo/pkg/logger"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/pbkdf2"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -3,12 +3,13 @@ package tool
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gitea.loveuer.com/yizhisec/packages/logger"
|
||||
"github.com/jedib0t/go-pretty/v6/table"
|
||||
"io"
|
||||
"loveuer/utodo/pkg/logger"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/jedib0t/go-pretty/v6/table"
|
||||
)
|
||||
|
||||
func TablePrinter(data any, writers ...io.Writer) {
|
||||
|
Reference in New Issue
Block a user