feat: add memory cache
This commit is contained in:
1
database/cache/cache.go
vendored
1
database/cache/cache.go
vendored
@ -62,6 +62,7 @@ var (
|
||||
marshaler func(data any) ([]byte, error) = json.Marshal
|
||||
unmarshaler func(data []byte, model any) error = json.Unmarshal
|
||||
ErrorKeyNotFound = errors.New("key not found")
|
||||
ErrorStoreFailed = errors.New("store failed")
|
||||
Default Cache
|
||||
)
|
||||
|
||||
|
155
database/cache/memory.go
vendored
Normal file
155
database/cache/memory.go
vendored
Normal file
@ -0,0 +1,155 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/dgraph-io/ristretto/v2"
|
||||
)
|
||||
|
||||
var _ Cache = (*_mem)(nil)
|
||||
|
||||
type _mem struct {
|
||||
ctx context.Context
|
||||
cache *ristretto.Cache[string, []byte]
|
||||
}
|
||||
|
||||
func newMemory(ctx context.Context, ins *ristretto.Cache[string, []byte]) Cache {
|
||||
return &_mem{
|
||||
ctx: ctx,
|
||||
cache: ins,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *_mem) Client() any {
|
||||
return m.cache
|
||||
}
|
||||
|
||||
func (c *_mem) Close() {
|
||||
c.cache.Close()
|
||||
}
|
||||
|
||||
func (c *_mem) Del(ctx context.Context, keys ...string) error {
|
||||
for _, key := range keys {
|
||||
c.cache.Del(key)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *_mem) Get(ctx context.Context, key string) ([]byte, error) {
|
||||
val, ok := c.cache.Get(key)
|
||||
if !ok {
|
||||
return val, ErrorKeyNotFound
|
||||
}
|
||||
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func (c *_mem) GetDel(ctx context.Context, key string) ([]byte, error) {
|
||||
val, err := c.Get(ctx, key)
|
||||
if err != nil {
|
||||
return val, err
|
||||
}
|
||||
|
||||
c.cache.Del(key)
|
||||
|
||||
return val, err
|
||||
}
|
||||
|
||||
func (c *_mem) GetDelScan(ctx context.Context, key string) Scanner {
|
||||
val, err := c.GetDel(ctx, key)
|
||||
return newScan(val, err)
|
||||
}
|
||||
|
||||
func (c *_mem) GetEx(ctx context.Context, key string, duration time.Duration) ([]byte, error) {
|
||||
val, err := c.Get(ctx, key)
|
||||
if err != nil {
|
||||
return val, err
|
||||
}
|
||||
|
||||
c.cache.SetWithTTL(key, val, 1, duration)
|
||||
|
||||
return val, err
|
||||
}
|
||||
|
||||
func (m *_mem) GetExScan(ctx context.Context, key string, duration time.Duration) Scanner {
|
||||
val, err := m.GetEx(ctx, key, duration)
|
||||
return newScan(val, err)
|
||||
}
|
||||
|
||||
func (m *_mem) GetScan(ctx context.Context, key string) Scanner {
|
||||
val, err := m.Get(ctx, key)
|
||||
return newScan(val, err)
|
||||
}
|
||||
|
||||
func (m *_mem) Gets(ctx context.Context, keys ...string) ([][]byte, error) {
|
||||
vals := make([][]byte, 0, len(keys))
|
||||
|
||||
for _, key := range keys {
|
||||
val, err := m.Get(ctx, key)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrorKeyNotFound) {
|
||||
continue
|
||||
}
|
||||
|
||||
return vals, err
|
||||
}
|
||||
|
||||
vals = append(vals, val)
|
||||
}
|
||||
|
||||
if len(vals) != len(keys) {
|
||||
return vals, ErrorKeyNotFound
|
||||
}
|
||||
|
||||
return vals, nil
|
||||
}
|
||||
|
||||
func (m *_mem) Set(ctx context.Context, key string, value any) error {
|
||||
val, err := handleValue(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if ok := m.cache.Set(key, val, 1); !ok {
|
||||
return ErrorStoreFailed
|
||||
}
|
||||
|
||||
m.cache.Wait()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *_mem) SetEx(ctx context.Context, key string, value any, duration time.Duration) error {
|
||||
val, err := handleValue(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if ok := m.cache.SetWithTTL(key, val, 1, duration); !ok {
|
||||
return ErrorStoreFailed
|
||||
}
|
||||
|
||||
m.cache.Wait()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *_mem) Sets(ctx context.Context, vm map[string]any) error {
|
||||
for key, value := range vm {
|
||||
val, err := handleValue(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if ok := m.cache.Set(key, val, 1); !ok {
|
||||
return ErrorStoreFailed
|
||||
}
|
||||
}
|
||||
|
||||
m.cache.Wait()
|
||||
|
||||
return nil
|
||||
}
|
42
database/cache/new.go
vendored
42
database/cache/new.go
vendored
@ -4,35 +4,31 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"gitea.loveuer.com/yizhisec/packages/tool"
|
||||
"github.com/dgraph-io/ristretto/v2"
|
||||
"github.com/go-redis/redis/v8"
|
||||
_ "github.com/go-redis/redis/v8"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
var (
|
||||
defaultRedis = "redis://127.0.0.1:6379"
|
||||
)
|
||||
|
||||
func New(opts ...OptionFn) (Cache, error) {
|
||||
func New(opts ...Option) (Cache, error) {
|
||||
var (
|
||||
err error
|
||||
cfg = &config{
|
||||
ctx: context.Background(),
|
||||
redis: &defaultRedis,
|
||||
opt = &option{
|
||||
ctx: context.Background(),
|
||||
}
|
||||
)
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(cfg)
|
||||
for _, fn := range opts {
|
||||
fn(opt)
|
||||
}
|
||||
|
||||
if cfg.redis != nil {
|
||||
if opt.redis != nil {
|
||||
var (
|
||||
ins *url.URL
|
||||
client *redis.Client
|
||||
)
|
||||
|
||||
if ins, err = url.Parse(*cfg.redis); err != nil {
|
||||
if ins, err = url.Parse(*opt.redis); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -45,17 +41,33 @@ func New(opts ...OptionFn) (Cache, error) {
|
||||
Password: password,
|
||||
})
|
||||
|
||||
if err = client.Ping(tool.TimeoutCtx(cfg.ctx, 5)).Err(); err != nil {
|
||||
if err = client.Ping(tool.TimeoutCtx(opt.ctx, 5)).Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newRedis(cfg.ctx, client), nil
|
||||
return newRedis(opt.ctx, client), nil
|
||||
}
|
||||
|
||||
if opt.memory {
|
||||
var (
|
||||
ins *ristretto.Cache[string, []byte]
|
||||
)
|
||||
|
||||
if ins, err = ristretto.NewCache(&ristretto.Config[string, []byte]{
|
||||
NumCounters: 1e7,
|
||||
MaxCost: 1 << 30,
|
||||
BufferItems: 64,
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newMemory(opt.ctx, ins), nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("invalid cache config")
|
||||
}
|
||||
|
||||
func Init(opts ...OptionFn) (err error) {
|
||||
func Init(opts ...Option) (err error) {
|
||||
Default, err = New(opts...)
|
||||
return err
|
||||
}
|
||||
|
29
database/cache/new_test.go
vendored
29
database/cache/new_test.go
vendored
@ -1,6 +1,7 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@ -71,3 +72,31 @@ func TestNoAuth(t *testing.T) {
|
||||
// t.Fatal(err)
|
||||
//}
|
||||
}
|
||||
|
||||
func TestMemory(t *testing.T) {
|
||||
c, err := New(WithMemory())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
bs, err := c.Get(t.Context(), "haha")
|
||||
if err != nil {
|
||||
if !errors.Is(err, ErrorKeyNotFound) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Logf("key not found")
|
||||
}
|
||||
|
||||
t.Logf("haha = %s", string(bs))
|
||||
|
||||
if err = c.Set(t.Context(), "haha", "haha"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if bs, err = c.Get(t.Context(), "haha"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Logf("haha = %s", string(bs))
|
||||
}
|
||||
|
23
database/cache/option.go
vendored
23
database/cache/option.go
vendored
@ -5,23 +5,24 @@ import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type config struct {
|
||||
ctx context.Context
|
||||
redis *string
|
||||
type option struct {
|
||||
ctx context.Context
|
||||
redis *string
|
||||
memory bool
|
||||
}
|
||||
|
||||
type OptionFn func(*config)
|
||||
type Option func(*option)
|
||||
|
||||
func WithCtx(ctx context.Context) OptionFn {
|
||||
return func(c *config) {
|
||||
func WithCtx(ctx context.Context) Option {
|
||||
return func(c *option) {
|
||||
if ctx != nil {
|
||||
c.ctx = ctx
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func WithRedis(host string, port int, username, password string) OptionFn {
|
||||
return func(c *config) {
|
||||
func WithRedis(host string, port int, username, password string) Option {
|
||||
return func(c *option) {
|
||||
uri := fmt.Sprintf("redis://%s:%d", host, port)
|
||||
if username != "" || password != "" {
|
||||
uri = fmt.Sprintf("redis://%s:%s@%s:%d", username, password, host, port)
|
||||
@ -30,3 +31,9 @@ func WithRedis(host string, port int, username, password string) OptionFn {
|
||||
c.redis = &uri
|
||||
}
|
||||
}
|
||||
|
||||
func WithMemory() Option {
|
||||
return func(c *option) {
|
||||
c.memory = true
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user