feat: update cache; fix: user logout

This commit is contained in:
loveuer
2025-01-14 00:42:18 -08:00
parent 534b9586f2
commit 32b30ae183
16 changed files with 281 additions and 127 deletions

View File

@ -2,6 +2,7 @@ package cmd
import (
"context"
"ultone/internal/api"
"ultone/internal/controller"
"ultone/internal/database/cache"
@ -12,19 +13,17 @@ import (
"ultone/internal/tool"
)
var (
filename string
)
var filename string
func execute(ctx context.Context) error {
tool.Must(opt.Init(filename))
tool.Must(db.Init(ctx, opt.Cfg.DB.Uri))
tool.Must(cache.Init())
tool.Must(cache.Init(ctx, opt.Cfg.Cache.Uri))
// todo: if elastic search required
//tool.Must(es.Init(ctx, opt.Cfg.ES.Uri))
// tool.Must(es.Init(ctx, opt.Cfg.ES.Uri))
// 或者使用 https://github.com/olivere/elastic
//tool.Must(elastic.Init(ctx, opt.Cfg.ES.Uri))
// tool.Must(elastic.Init(ctx, opt.Cfg.ES.Uri))
// todo: if nebula required
// tool.Must(nebula.Init(ctx, opt.Cfg.Nebula))
@ -34,7 +33,7 @@ func execute(ctx context.Context) error {
tool.Must(api.Start(ctx))
// todo: if need some cli operation, should start local unix rpc svc
//tool.Must(unix.Start(ctx))
// tool.Must(unix.Start(ctx))
<-ctx.Done()

63
internal/database/cache/cache.go vendored Normal file
View File

@ -0,0 +1,63 @@
package cache
import (
"context"
"encoding/json"
"errors"
"time"
)
type Cache interface {
Get(ctx context.Context, key string) ([]byte, error)
Gets(ctx context.Context, keys ...string) ([][]byte, error)
GetScan(ctx context.Context, key string) Scanner
GetEx(ctx context.Context, key string, duration time.Duration) ([]byte, error)
GetExScan(ctx context.Context, key string, duration time.Duration) Scanner
// Set value 会被序列化, 优先使用 MarshalBinary 方法, 没有则执行 json.Marshal
Set(ctx context.Context, key string, value any) error
Sets(ctx context.Context, vm map[string]any) error
// SetEx value 会被序列化, 优先使用 MarshalBinary 方法, 没有则执行 json.Marshal
SetEx(ctx context.Context, key string, value any, duration time.Duration) error
Del(ctx context.Context, keys ...string) error
Close() error
}
var Client Cache
type Scanner interface {
Scan(model any) error
}
type encoded_value interface {
MarshalBinary() ([]byte, error)
}
type decoded_value interface {
UnmarshalBinary(bs []byte) error
}
const (
Prefix = "upp:"
)
var ErrorKeyNotFound = errors.New("key not found")
func handleValue(value any) ([]byte, error) {
var (
bs []byte
err error
)
switch value.(type) {
case []byte:
return value.([]byte), nil
}
if imp, ok := value.(encoded_value); ok {
bs, err = imp.MarshalBinary()
} else {
bs, err = json.Marshal(value)
}
return bs, err
}

View File

@ -2,13 +2,13 @@ package cache
import (
"context"
"time"
"github.com/hashicorp/golang-lru/v2/expirable"
_ "github.com/hashicorp/golang-lru/v2/expirable"
"time"
"ultone/internal/interfaces"
)
var _ interfaces.Cacher = (*_lru)(nil)
var _ Cache = (*_lru)(nil)
type _lru struct {
client *expirable.LRU[string, *_lru_value]
@ -38,6 +38,24 @@ func (l *_lru) Get(ctx context.Context, key string) ([]byte, error) {
return v.bs, nil
}
func (l *_lru) Gets(ctx context.Context, keys ...string) ([][]byte, error) {
bss := make([][]byte, 0, len(keys))
for _, key := range keys {
bs, err := l.Get(ctx, key)
if err != nil {
return nil, err
}
bss = append(bss, bs)
}
return bss, nil
}
func (l *_lru) GetScan(ctx context.Context, key string) Scanner {
return newScanner(l.Get(ctx, key))
}
func (l *_lru) GetEx(ctx context.Context, key string, duration time.Duration) ([]byte, error) {
v, ok := l.client.Get(key)
if !ok {
@ -64,6 +82,10 @@ func (l *_lru) GetEx(ctx context.Context, key string, duration time.Duration) ([
return v.bs, nil
}
func (l *_lru) GetExScan(ctx context.Context, key string, duration time.Duration) Scanner {
return newScanner(l.GetEx(ctx, key, duration))
}
func (l *_lru) Set(ctx context.Context, key string, value any) error {
bs, err := handleValue(value)
if err != nil {
@ -94,6 +116,16 @@ func (l *_lru) SetEx(ctx context.Context, key string, value any, duration time.D
return nil
}
func (l *_lru) Sets(ctx context.Context, m map[string]any) error {
for k, v := range m {
if err := l.Set(ctx, k, v); err != nil {
return err
}
}
return nil
}
func (l *_lru) Del(ctx context.Context, keys ...string) error {
for _, key := range keys {
l.client.Remove(key)
@ -102,7 +134,12 @@ func (l *_lru) Del(ctx context.Context, keys ...string) error {
return nil
}
func newLRUCache() (interfaces.Cacher, error) {
func (l *_lru) Close() error {
l.client = nil
return nil
}
func newLRUCache() (Cache, error) {
client := expirable.NewLRU[string, *_lru_value](1024*1024, nil, 0)
return &_lru{client: client}, nil

View File

@ -5,17 +5,24 @@ import (
"errors"
"fmt"
"time"
"ultone/internal/interfaces"
"gitea.com/taozitaozi/gredis"
)
var _ interfaces.Cacher = (*_mem)(nil)
var _ Cache = (*_mem)(nil)
type _mem struct {
client *gredis.Gredis
}
func (m *_mem) GetScan(ctx context.Context, key string) Scanner {
return newScanner(m.Get(ctx, key))
}
func (m *_mem) GetExScan(ctx context.Context, key string, duration time.Duration) Scanner {
return newScanner(m.GetEx(ctx, key, duration))
}
func (m *_mem) Get(ctx context.Context, key string) ([]byte, error) {
v, err := m.client.Get(key)
if err != nil {
@ -34,6 +41,20 @@ func (m *_mem) Get(ctx context.Context, key string) ([]byte, error) {
return bs, nil
}
func (m *_mem) Gets(ctx context.Context, keys ...string) ([][]byte, error) {
bss := make([][]byte, 0, len(keys))
for _, key := range keys {
bs, err := m.Get(ctx, key)
if err != nil {
return nil, err
}
bss = append(bss, bs)
}
return bss, nil
}
func (m *_mem) GetEx(ctx context.Context, key string, duration time.Duration) ([]byte, error) {
v, err := m.client.GetEx(key, duration)
if err != nil {
@ -60,6 +81,16 @@ func (m *_mem) Set(ctx context.Context, key string, value any) error {
return m.client.Set(key, bs)
}
func (m *_mem) Sets(ctx context.Context, vm map[string]any) error {
for k, v := range vm {
if err := m.Set(ctx, k, v); err != nil {
return err
}
}
return nil
}
func (m *_mem) SetEx(ctx context.Context, key string, value any, duration time.Duration) error {
bs, err := handleValue(value)
if err != nil {
@ -72,3 +103,9 @@ func (m *_mem) Del(ctx context.Context, keys ...string) error {
m.client.Delete(keys...)
return nil
}
func (m *_mem) Close() error {
m.client = nil
return nil
}

View File

@ -3,8 +3,11 @@ package cache
import (
"context"
"errors"
"github.com/go-redis/redis/v8"
"time"
"github.com/go-redis/redis/v8"
"github.com/samber/lo"
"github.com/spf13/cast"
)
type _redis struct {
@ -24,6 +27,28 @@ func (r *_redis) Get(ctx context.Context, key string) ([]byte, error) {
return []byte(result), nil
}
func (r *_redis) Gets(ctx context.Context, keys ...string) ([][]byte, error) {
result, err := r.client.MGet(ctx, keys...).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return nil, ErrorKeyNotFound
}
return nil, err
}
return lo.Map(
result,
func(item any, index int) []byte {
return []byte(cast.ToString(item))
},
), nil
}
func (r *_redis) GetScan(ctx context.Context, key string) Scanner {
return newScanner(r.Get(ctx, key))
}
func (r *_redis) GetEx(ctx context.Context, key string, duration time.Duration) ([]byte, error) {
result, err := r.client.GetEx(ctx, key, duration).Result()
if err != nil {
@ -37,6 +62,10 @@ func (r *_redis) GetEx(ctx context.Context, key string, duration time.Duration)
return []byte(result), nil
}
func (r *_redis) GetExScan(ctx context.Context, key string, duration time.Duration) Scanner {
return newScanner(r.GetEx(ctx, key, duration))
}
func (r *_redis) Set(ctx context.Context, key string, value any) error {
bs, err := handleValue(value)
if err != nil {
@ -47,6 +76,20 @@ func (r *_redis) Set(ctx context.Context, key string, value any) error {
return err
}
func (r *_redis) Sets(ctx context.Context, values map[string]any) error {
vm := make(map[string]any)
for k, v := range values {
bs, err := handleValue(v)
if err != nil {
return err
}
vm[k] = bs
}
return r.client.MSet(ctx, vm).Err()
}
func (r *_redis) SetEx(ctx context.Context, key string, value any, duration time.Duration) error {
bs, err := handleValue(value)
if err != nil {
@ -61,3 +104,7 @@ func (r *_redis) SetEx(ctx context.Context, key string, value any, duration time
func (r *_redis) Del(ctx context.Context, keys ...string) error {
return r.client.Del(ctx, keys...).Err()
}
func (r *_redis) Close() error {
return r.client.Close()
}

View File

@ -1,38 +0,0 @@
package cache
import (
"encoding/json"
"ultone/internal/interfaces"
)
var (
Client interfaces.Cacher
)
type encoded_value interface {
MarshalBinary() ([]byte, error)
}
type decoded_value interface {
UnmarshalBinary(bs []byte) error
}
func handleValue(value any) ([]byte, error) {
var (
bs []byte
err error
)
switch value.(type) {
case []byte:
return value.([]byte), nil
}
if imp, ok := value.(encoded_value); ok {
bs, err = imp.MarshalBinary()
} else {
bs, err = json.Marshal(value)
}
return bs, err
}

View File

@ -1,7 +0,0 @@
package cache
import "errors"
var (
ErrorKeyNotFound = errors.New("key not found")
)

View File

@ -1,30 +1,31 @@
package cache
import (
"context"
"fmt"
"gitea.com/taozitaozi/gredis"
"github.com/go-redis/redis/v8"
"net/url"
"strings"
"ultone/internal/opt"
"ultone/internal/tool"
"gitea.com/taozitaozi/gredis"
"github.com/go-redis/redis/v8"
)
func Init() error {
func New(ctx context.Context, uri string) (Cache, error) {
var (
err error
err error
newClient Cache
strs = strings.Split(uri, "::")
)
strs := strings.Split(opt.Cfg.Cache.Uri, "::")
switch strs[0] {
case "memory":
gc := gredis.NewGredis(1024 * 1024)
Client = &_mem{client: gc}
newClient = &_mem{client: gc}
case "lru":
if Client, err = newLRUCache(); err != nil {
return err
if newClient, err = newLRUCache(); err != nil {
return nil, err
}
case "redis":
var (
@ -33,7 +34,7 @@ func Init() error {
)
if len(strs) != 2 {
return fmt.Errorf("cache.Init: invalid cache uri: %s", opt.Cfg.Cache.Uri)
return nil, fmt.Errorf("cache.Init: invalid cache uri: %s", uri)
}
uri := strs[1]
@ -43,7 +44,7 @@ func Init() error {
}
if ins, err = url.Parse(uri); err != nil {
return fmt.Errorf("cache.Init: url parse cache uri: %s, err: %s", opt.Cfg.Cache.Uri, err.Error())
return nil, fmt.Errorf("cache.Init: url parse cache uri: %s, err: %s", uri, err.Error())
}
addr := ins.Host
@ -58,13 +59,18 @@ func Init() error {
})
if err = rc.Ping(tool.Timeout(5)).Err(); err != nil {
return fmt.Errorf("cache.Init: redis ping err: %s", err.Error())
return nil, fmt.Errorf("cache.Init: redis ping err: %s", err.Error())
}
Client = &_redis{client: rc}
newClient = &_redis{client: rc}
default:
return fmt.Errorf("cache type %s not support", strs[0])
return nil, fmt.Errorf("cache type %s not support", strs[0])
}
return nil
return newClient, nil
}
func Init(ctx context.Context, uri string) (err error) {
Client, err = New(ctx, uri)
return
}

20
internal/database/cache/scan.go vendored Normal file
View File

@ -0,0 +1,20 @@
package cache
import "encoding/json"
type scanner struct {
err error
bs []byte
}
func (s *scanner) Scan(model any) error {
if s.err != nil {
return s.err
}
return json.Unmarshal(s.bs, model)
}
func newScanner(bs []byte, err error) *scanner {
return &scanner{bs: bs, err: err}
}

View File

@ -9,6 +9,7 @@ import (
"ultone/internal/controller"
"ultone/internal/database/cache"
"ultone/internal/database/db"
"ultone/internal/log"
"ultone/internal/middleware/oplog"
"ultone/internal/model"
"ultone/internal/opt"
@ -75,26 +76,27 @@ func AuthLogin(c *nf.Ctx) error {
if !opt.MultiLogin {
var (
last = fmt.Sprintf("%s:user:last_token:%d", opt.CachePrefix, target.Id)
bs []byte
lastKey = fmt.Sprintf("%s:user:last_token:%d", opt.CachePrefix, target.Id)
lastToken string
)
// 获取之前的 token
if bs, err = cache.Client.Get(tool.Timeout(3), last); err == nil {
key := fmt.Sprintf("%s:user:token:%s", opt.CachePrefix, string(bs))
_ = cache.Client.Del(tool.Timeout(3), key)
if err = cache.Client.GetScan(tool.Timeout(3), lastKey).Scan(&lastToken); err != nil {
if !errors.Is(err, cache.ErrorKeyNotFound) {
log.Warn(c.Context(), "handler.AuthLogin: get last token err = %v", err)
goto HandleMultiEnd
}
}
// 删掉之前的 token
if len(bs) > 0 {
_ = controller.UserController.RmToken(c.Context(), string(bs))
}
controller.UserController.RmToken(c.Context(), lastToken)
// 将当前的 token 存入 last_token
if err = cache.Client.Set(tool.Timeout(3), last, token); err != nil {
if err = cache.Client.Set(tool.Timeout(3), lastKey, token); err != nil {
return resp.Resp500(c, err.Error())
}
}
HandleMultiEnd:
c.Set("Set-Cookie", fmt.Sprintf("%s=%s; Path=/", opt.CookieName, token))
c.Locals("user", target)
@ -121,12 +123,28 @@ func AuthVerify(c *nf.Ctx) error {
}
func AuthLogout(c *nf.Ctx) error {
defer func() {
c.Set("Set-Cookie", fmt.Sprintf("%s=; Path=/; Max-Age=0", opt.CookieName))
}()
op, ok := c.Locals("user").(*model.User)
if !ok {
return resp.Resp401(c, nil)
}
_ = controller.UserController.RmUserCache(c.Context(), op.Id)
token, ok := c.Locals("token").(string)
if !ok {
return resp.Resp401(c, nil)
}
if !opt.MultiLogin {
_ = controller.UserController.RmUserCache(c.Context(), op.Id)
lastKey := fmt.Sprintf("%s:user:last_token:%d", opt.CachePrefix, op.Id)
cache.Client.Del(c.Context(), lastKey)
}
// 删掉之前的 token
controller.UserController.RmToken(c.Context(), token)
c.Locals(opt.OpLogLocalKey, &oplog.OpLog{
Type: model.OpLogTypeLogout,

View File

@ -1,16 +0,0 @@
package interfaces
import (
"context"
"time"
)
type Cacher interface {
Get(ctx context.Context, key string) ([]byte, error)
GetEx(ctx context.Context, key string, duration time.Duration) ([]byte, error)
// Set value 会被序列化, 优先使用 MarshalBinary 方法, 没有则执行 json.Marshal
Set(ctx context.Context, key string, value any) error
// SetEx value 会被序列化, 优先使用 MarshalBinary 方法, 没有则执行 json.Marshal
SetEx(ctx context.Context, key string, value any, duration time.Duration) error
Del(ctx context.Context, keys ...string) error
}

View File

@ -1,8 +1,6 @@
package tool
import "cmp"
func Min[T cmp.Ordered](a, b T) T {
func Min[T ~int | ~uint | ~int8 | ~uint8 | ~int16 | ~uint16 | ~int32 | ~uint32 | ~int64 | ~uint64 | ~float32 | ~float64](a, b T) T {
if a <= b {
return a
}
@ -10,7 +8,7 @@ func Min[T cmp.Ordered](a, b T) T {
return b
}
func Max[T cmp.Ordered](a, b T) T {
func Max[T ~int | ~uint | ~int8 | ~uint8 | ~int16 | ~uint16 | ~int32 | ~uint32 | ~int64 | ~uint64 | ~float32 | ~float64](a, b T) T {
if a >= b {
return a
}