Files
upkg/database/cache/interface.go
2026-01-28 10:28:13 +08:00

332 lines
8.2 KiB
Go

package cache
import (
	"context"
	"encoding/json"
	"fmt"
	"strconv"
	"sync"
	"time"
)
// Cache is the storage abstraction implemented by the cache backends of this
// package. It covers plain string values, integer counters, key lifetime
// control and hash (field -> value) operations. Every data operation takes a
// context.Context as its first argument.
type Cache interface {
	// --- string values ---

	// Set stores value under key. How the optional expiration is applied is
	// up to the implementation.
	Set(ctx context.Context, key string, value interface{}, expiration ...time.Duration) error
	// SetNX stores value under key only when key is absent and reports
	// whether the value was stored.
	SetNX(ctx context.Context, key string, value interface{}, expiration ...time.Duration) (bool, error)
	// Get returns the string stored at key. The in-memory implementation
	// reports a missing key with ErrKeyNotFound.
	Get(ctx context.Context, key string) (string, error)
	// GetBytes is Get returning the value as a byte slice.
	GetBytes(ctx context.Context, key string) ([]byte, error)
	// GetScan loads the value stored at key into dest; the decoding rules are
	// implementation-specific.
	GetScan(ctx context.Context, key string, dest interface{}) error

	// --- key management ---

	// Exists reports whether key is present.
	Exists(ctx context.Context, key string) (bool, error)
	// Del removes every listed key.
	Del(ctx context.Context, keys ...string) error
	// Keys lists the keys that match pattern.
	Keys(ctx context.Context, pattern string) ([]string, error)
	// Expire sets the lifetime of key.
	Expire(ctx context.Context, key string, expiration time.Duration) error
	// TTL returns the remaining lifetime of key.
	TTL(ctx context.Context, key string) (time.Duration, error)

	// --- counters ---

	// Inc adds one to the integer at key and returns the result.
	Inc(ctx context.Context, key string) (int64, error)
	// IncBy adds value to the integer at key and returns the result.
	IncBy(ctx context.Context, key string, value int64) (int64, error)
	// Dec subtracts one from the integer at key and returns the result.
	Dec(ctx context.Context, key string) (int64, error)

	// --- hashes ---

	// HSet stores value under field in the hash at key.
	HSet(ctx context.Context, key string, field string, value interface{}) error
	// HGet returns the value of field in the hash at key.
	HGet(ctx context.Context, key string, field string) (string, error)
	// HGetAll returns every field/value pair of the hash at key.
	HGetAll(ctx context.Context, key string) (map[string]string, error)
	// HKeys returns the field names of the hash at key.
	HKeys(ctx context.Context, key string) ([]string, error)
	// HLen returns the number of fields in the hash at key.
	HLen(ctx context.Context, key string) (int64, error)
	// HExists reports whether field is present in the hash at key.
	HExists(ctx context.Context, key string, field string) (bool, error)
	// HDel removes the listed fields from the hash at key.
	HDel(ctx context.Context, key string, fields ...string) error
	// HIncrBy adds increment to the integer stored in field of the hash at
	// key and returns the result.
	HIncrBy(ctx context.Context, key string, field string, increment int64) (int64, error)

	// --- lifecycle ---

	// Close releases whatever the implementation holds.
	Close() error
}
// Config holds the connection, pool and reconnection settings handed to a
// Driver. NewConfig fills in defaults for the timeouts, pool size and
// reconnection fields.
//
// NOTE(review): the time.Duration fields have no custom JSON handling, so when
// a Config is decoded with encoding/json they are read as integer nanoseconds
// ("dial_timeout": 5 means 5ns, not 5s).
type Config struct {
	// Driver names the backend; presumably used to pick the Driver
	// implementation — confirm against callers.
	Driver string `json:"driver"`
	// Addr is the server address.
	Addr string `json:"addr"`
	// MasterAddr and ReplicaAddrs presumably describe a master/replica
	// deployment — confirm how the driver uses them.
	MasterAddr   string   `json:"master_addr"`
	ReplicaAddrs []string `json:"replica_addrs"`
	Password     string   `json:"password"`
	DB           int      `json:"db"`
	// Network timeouts (NewConfig: 5s dial, 3s read, 3s write).
	DialTimeout  time.Duration `json:"dial_timeout"`
	ReadTimeout  time.Duration `json:"read_timeout"`
	WriteTimeout time.Duration `json:"write_timeout"`
	// PoolSize is the connection pool size (NewConfig: 10).
	PoolSize int  `json:"pool_size"`
	ReadOnly bool `json:"read_only"`
	// Reconnect settings.
	Reconnect         bool          `json:"reconnect"`          // whether automatic reconnection is enabled; NewConfig sets true
	ReconnectInterval time.Duration `json:"reconnect_interval"` // interval between reconnect checks; NewConfig sets 10 seconds
}
// Option is a functional option: it modifies a Config in place.
type Option func(*Config)
// WithAddr returns an Option that sets Config.Addr to addr.
func WithAddr(addr string) Option {
	setAddr := func(cfg *Config) {
		cfg.Addr = addr
	}
	return setAddr
}
// WithPassword returns an Option that sets Config.Password to password.
func WithPassword(password string) Option {
	setPassword := func(cfg *Config) {
		cfg.Password = password
	}
	return setPassword
}
// WithDB returns an Option that sets Config.DB to db.
func WithDB(db int) Option {
	setDB := func(cfg *Config) {
		cfg.DB = db
	}
	return setDB
}
// WithDialTimeout returns an Option that sets Config.DialTimeout to timeout.
func WithDialTimeout(timeout time.Duration) Option {
	setDialTimeout := func(cfg *Config) {
		cfg.DialTimeout = timeout
	}
	return setDialTimeout
}
// WithMasterAddr returns an Option that sets Config.MasterAddr to addr.
func WithMasterAddr(addr string) Option {
	setMaster := func(cfg *Config) {
		cfg.MasterAddr = addr
	}
	return setMaster
}
// WithReplicaAddrs returns an Option that sets Config.ReplicaAddrs.
//
// The slice is copied when the Option is created and again each time it is
// applied. Later writes to the caller's slice are therefore not observed, and
// Configs built from the same Option never share a backing array. A nil input
// stays nil and an empty input stays empty (non-nil).
func WithReplicaAddrs(addrs []string) Option {
	// addrs[:0:0] has zero capacity, so append always allocates a fresh array
	// (and a nil slice stays nil).
	owned := append(addrs[:0:0], addrs...)
	return func(c *Config) {
		c.ReplicaAddrs = append(owned[:0:0], owned...)
	}
}
// WithReadOnly returns an Option that sets Config.ReadOnly to readOnly.
func WithReadOnly(readOnly bool) Option {
	setReadOnly := func(cfg *Config) {
		cfg.ReadOnly = readOnly
	}
	return setReadOnly
}
// WithReconnect returns an Option that sets Config.Reconnect to reconnect.
func WithReconnect(reconnect bool) Option {
	setReconnect := func(cfg *Config) {
		cfg.Reconnect = reconnect
	}
	return setReconnect
}
// WithReconnectInterval returns an Option that sets Config.ReconnectInterval
// to interval.
func WithReconnectInterval(interval time.Duration) Option {
	setInterval := func(cfg *Config) {
		cfg.ReconnectInterval = interval
	}
	return setInterval
}
// Driver builds Cache instances from a Config.
type Driver interface {
	// Cache constructs a Cache configured by cfg, or returns an error.
	Cache(cfg *Config) (Cache, error)
}
// NewMemoryCache returns a Cache that keeps everything in process memory.
//
// The cache is safe for concurrent use. Expirations given to Set, SetNX or
// Expire are enforced lazily: an expired key is dropped the next time an
// operation touches it, and Keys also sweeps the keys it scans. There is no
// background goroutine, so Close only discards the stored data.
func NewMemoryCache() Cache {
	return &memoryCache{
		data:    make(map[string]string),
		hash:    make(map[string]map[string]string),
		expires: make(map[string]time.Time),
	}
}

// memoryCache is the map-backed Cache returned by NewMemoryCache.
//
// String values and hash values live in separate maps but share a single
// keyspace for Exists, Keys, SetNX, Del, Expire and TTL. Every field is
// guarded by mu.
type memoryCache struct {
	mu      sync.Mutex
	data    map[string]string            // string values
	hash    map[string]map[string]string // hash values: key -> field -> value
	expires map[string]time.Time         // absolute deadline of every key that has a TTL
}

// Set stores value under key. Strings and []byte are stored verbatim; any
// other value is formatted with fmt's %v verb. A positive first expiration
// gives the key a TTL; otherwise any previous TTL is cleared. A hash
// previously stored under key is replaced.
func (m *memoryCache) Set(ctx context.Context, key string, value interface{}, expiration ...time.Duration) error {
	m.mu.Lock()
	defer m.mu.Unlock()
	delete(m.hash, key) // the key now holds a string, whatever it held before
	m.data[key] = memoryStringify(value)
	m.setTTLLocked(key, memoryExpirationArg(expiration))
	return nil
}

// Get returns the string stored at key, or ErrKeyNotFound when the key is
// missing, expired, or holds only a hash.
func (m *memoryCache) Get(ctx context.Context, key string) (string, error) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.purgeLocked(key, time.Now())
	val, ok := m.data[key]
	if !ok {
		return "", ErrKeyNotFound
	}
	return val, nil
}

// GetBytes is Get returning the value as a byte slice.
func (m *memoryCache) GetBytes(ctx context.Context, key string) ([]byte, error) {
	val, err := m.Get(ctx, key)
	if err != nil {
		return nil, err
	}
	return []byte(val), nil
}

// GetScan JSON-decodes the string stored at key into dest. It returns
// ErrKeyNotFound for a missing key, or a wrapped json error when the stored
// value is not valid JSON for dest.
func (m *memoryCache) GetScan(ctx context.Context, key string, dest interface{}) error {
	val, err := m.Get(ctx, key)
	if err != nil {
		return err
	}
	if err := json.Unmarshal([]byte(val), dest); err != nil {
		return fmt.Errorf("cache: decode %q: %w", key, err)
	}
	return nil
}

// Del removes every listed key (string value, hash and TTL). Missing keys
// are ignored.
func (m *memoryCache) Del(ctx context.Context, keys ...string) error {
	m.mu.Lock()
	defer m.mu.Unlock()
	for _, key := range keys {
		m.deleteLocked(key)
	}
	return nil
}

// Exists reports whether key currently holds a string or a hash.
func (m *memoryCache) Exists(ctx context.Context, key string) (bool, error) {
	m.mu.Lock()
	defer m.mu.Unlock()
	return m.purgeLocked(key, time.Now()), nil
}

// Expire gives an existing key a lifetime of expiration. A non-positive
// expiration deletes the key immediately, as Redis EXPIRE does. Missing keys
// are left alone.
func (m *memoryCache) Expire(ctx context.Context, key string, expiration time.Duration) error {
	m.mu.Lock()
	defer m.mu.Unlock()
	if !m.purgeLocked(key, time.Now()) {
		return nil
	}
	if expiration <= 0 {
		m.deleteLocked(key)
		return nil
	}
	m.setTTLLocked(key, expiration)
	return nil
}

// TTL returns the remaining lifetime of key. It returns -1 when the key
// exists without a TTL and -2 when it does not exist; these are Redis TTL
// sentinels, returned as raw Duration values.
// NOTE(review): confirm the Redis driver reports the same sentinels.
func (m *memoryCache) TTL(ctx context.Context, key string) (time.Duration, error) {
	m.mu.Lock()
	defer m.mu.Unlock()
	now := time.Now()
	if !m.purgeLocked(key, now) {
		return -2, nil
	}
	deadline, ok := m.expires[key]
	if !ok {
		return -1, nil
	}
	return deadline.Sub(now), nil // positive: purgeLocked saw now before deadline
}

// Inc adds one to the integer at key; see IncBy.
func (m *memoryCache) Inc(ctx context.Context, key string) (int64, error) {
	return m.IncBy(ctx, key, 1)
}

// IncBy adds value to the integer stored at key (a missing key counts as 0)
// and returns the result. An existing TTL is kept. It fails, leaving the
// stored value untouched, when that value is not a base-10 int64 or when the
// addition would overflow.
func (m *memoryCache) IncBy(ctx context.Context, key string, value int64) (int64, error) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.purgeLocked(key, time.Now())
	current, exists := m.data[key]
	n, err := memoryIncr(current, exists, value)
	if err != nil {
		return 0, fmt.Errorf("cache: incrby %q: %w", key, err)
	}
	m.data[key] = strconv.FormatInt(n, 10)
	return n, nil
}

// Dec subtracts one from the integer at key; see IncBy.
func (m *memoryCache) Dec(ctx context.Context, key string) (int64, error) {
	return m.IncBy(ctx, key, -1)
}

// SetNX stores value under key only if key holds neither a string nor a hash,
// and reports whether it stored it. A positive first expiration sets a TTL.
func (m *memoryCache) SetNX(ctx context.Context, key string, value interface{}, expiration ...time.Duration) (bool, error) {
	m.mu.Lock()
	defer m.mu.Unlock()
	if m.purgeLocked(key, time.Now()) {
		return false, nil
	}
	m.data[key] = memoryStringify(value)
	m.setTTLLocked(key, memoryExpirationArg(expiration))
	return true, nil
}

// Keys returns every live key (string or hash) that matches the Redis-style
// glob pattern (see memoryGlobMatch). Expired keys met during the scan are
// removed. The order of the result is unspecified.
func (m *memoryCache) Keys(ctx context.Context, pattern string) ([]string, error) {
	m.mu.Lock()
	defer m.mu.Unlock()
	now := time.Now()
	var keys []string
	for key := range m.data {
		if m.purgeLocked(key, now) && memoryGlobMatch(pattern, key) {
			keys = append(keys, key)
		}
	}
	for key := range m.hash {
		if _, dup := m.data[key]; dup {
			continue // already reported by the string pass
		}
		if m.purgeLocked(key, now) && memoryGlobMatch(pattern, key) {
			keys = append(keys, key)
		}
	}
	return keys, nil
}

// HSet stores value (formatted like Set) under field of the hash at key,
// creating the hash if needed. An existing TTL is kept.
func (m *memoryCache) HSet(ctx context.Context, key string, field string, value interface{}) error {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.purgeLocked(key, time.Now())
	h, ok := m.hash[key]
	if !ok {
		h = make(map[string]string)
		m.hash[key] = h
	}
	h[field] = memoryStringify(value)
	return nil
}

// HGet returns field of the hash at key, or ErrKeyNotFound when either the
// hash or the field is missing.
func (m *memoryCache) HGet(ctx context.Context, key string, field string) (string, error) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.purgeLocked(key, time.Now())
	h, ok := m.hash[key]
	if !ok {
		return "", ErrKeyNotFound
	}
	val, ok := h[field]
	if !ok {
		return "", ErrKeyNotFound
	}
	return val, nil
}

// HGetAll returns a copy of the hash at key, or ErrKeyNotFound when the hash
// does not exist.
func (m *memoryCache) HGetAll(ctx context.Context, key string) (map[string]string, error) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.purgeLocked(key, time.Now())
	h, ok := m.hash[key]
	if !ok {
		return nil, ErrKeyNotFound
	}
	// Return a copy so callers cannot mutate the cache's storage.
	result := make(map[string]string, len(h))
	for k, v := range h {
		result[k] = v
	}
	return result, nil
}

// HDel removes the listed fields from the hash at key. A hash left with no
// fields is removed entirely, as in Redis. A missing hash is ignored.
func (m *memoryCache) HDel(ctx context.Context, key string, fields ...string) error {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.purgeLocked(key, time.Now())
	h, ok := m.hash[key]
	if !ok {
		return nil
	}
	for _, field := range fields {
		delete(h, field)
	}
	if len(h) == 0 {
		delete(m.hash, key)
		if _, isString := m.data[key]; !isString {
			delete(m.expires, key) // the key no longer exists at all
		}
	}
	return nil
}

// HExists reports whether field is present in the hash at key.
func (m *memoryCache) HExists(ctx context.Context, key string, field string) (bool, error) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.purgeLocked(key, time.Now())
	_, ok := m.hash[key][field] // indexing a nil inner map is safe and yields false
	return ok, nil
}

// HKeys returns the field names of the hash at key in unspecified order, or
// an empty (non-nil) slice when the hash does not exist.
func (m *memoryCache) HKeys(ctx context.Context, key string) ([]string, error) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.purgeLocked(key, time.Now())
	h, ok := m.hash[key]
	if !ok {
		return []string{}, nil
	}
	keys := make([]string, 0, len(h))
	for k := range h {
		keys = append(keys, k)
	}
	return keys, nil
}

// HLen returns the number of fields in the hash at key (0 when missing).
func (m *memoryCache) HLen(ctx context.Context, key string) (int64, error) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.purgeLocked(key, time.Now())
	return int64(len(m.hash[key])), nil
}

// HIncrBy adds increment to the integer stored in field of the hash at key
// (a missing field counts as 0) and returns the result. Like IncBy it fails,
// without modifying anything, on non-integer values or overflow.
func (m *memoryCache) HIncrBy(ctx context.Context, key string, field string, increment int64) (int64, error) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.purgeLocked(key, time.Now())
	h := m.hash[key]
	current, exists := h[field]
	n, err := memoryIncr(current, exists, increment)
	if err != nil {
		return 0, fmt.Errorf("cache: hincrby %q field %q: %w", key, field, err)
	}
	if h == nil {
		h = make(map[string]string)
		m.hash[key] = h
	}
	h[field] = strconv.FormatInt(n, 10)
	return n, nil
}

// Close discards every entry. The maps are replaced with empty ones rather
// than nil so that a call made after Close cannot panic on a nil-map write.
func (m *memoryCache) Close() error {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.data = make(map[string]string)
	m.hash = make(map[string]map[string]string)
	m.expires = make(map[string]time.Time)
	return nil
}

// deleteLocked removes key from every map. The caller must hold m.mu.
func (m *memoryCache) deleteLocked(key string) {
	delete(m.data, key)
	delete(m.hash, key)
	delete(m.expires, key)
}

// purgeLocked deletes key if its deadline is at or before now, then reports
// whether key still holds a string or a hash. The caller must hold m.mu.
func (m *memoryCache) purgeLocked(key string, now time.Time) bool {
	if deadline, ok := m.expires[key]; ok && !now.Before(deadline) {
		m.deleteLocked(key)
		return false
	}
	if _, ok := m.data[key]; ok {
		return true
	}
	_, ok := m.hash[key]
	return ok
}

// setTTLLocked gives key a deadline ttl from now when ttl is positive and
// clears any deadline otherwise. The caller must hold m.mu.
func (m *memoryCache) setTTLLocked(key string, ttl time.Duration) {
	if ttl <= 0 {
		delete(m.expires, key)
		return
	}
	if m.expires == nil { // tolerate a memoryCache built without NewMemoryCache
		m.expires = make(map[string]time.Time)
	}
	m.expires[key] = time.Now().Add(ttl)
}

// memoryExpirationArg returns the first optional expiration, or 0 (none).
func memoryExpirationArg(expiration []time.Duration) time.Duration {
	if len(expiration) == 0 {
		return 0
	}
	return expiration[0]
}

// memoryStringify converts a value to its stored string form: strings and
// byte slices verbatim, everything else through fmt's %v verb.
func memoryStringify(value interface{}) string {
	switch v := value.(type) {
	case string:
		return v
	case []byte:
		return string(v)
	default:
		return fmt.Sprintf("%v", v)
	}
}

// memoryIncr parses current as a base-10 int64 (a missing value counts as 0)
// and adds delta. It reports an error instead of clobbering a non-integer
// value or silently wrapping around on overflow.
func memoryIncr(current string, exists bool, delta int64) (int64, error) {
	var base int64
	if exists {
		n, err := strconv.ParseInt(current, 10, 64)
		if err != nil {
			return 0, fmt.Errorf("value is not an integer: %w", err)
		}
		base = n
	}
	sum := base + delta
	if (delta > 0 && sum < base) || (delta < 0 && sum > base) {
		return 0, fmt.Errorf("increment would overflow")
	}
	return sum, nil
}

// memoryGlobMatch reports whether s matches a Redis KEYS-style glob pattern.
// '*' matches any run of bytes (including none), '?' matches exactly one
// byte, '[abc]', '[a-z]' and '[^abc]' match one byte in (or not in) a set,
// and '\' makes the next byte literal. An unterminated '[' is treated as a
// literal. Matching is byte-wise.
func memoryGlobMatch(pattern, s string) bool {
	p, i := 0, 0
	// resumeP/resumeI remember the most recent '*': the pattern index just
	// past it and the first byte of s it has not yet absorbed.
	resumeP, resumeI := -1, 0
	for i < len(s) {
		if p < len(pattern) && pattern[p] == '*' {
			p++
			resumeP, resumeI = p, i
			continue
		}
		if p < len(pattern) {
			if next, ok := memoryGlobOne(pattern, p, s[i]); ok {
				p, i = next, i+1
				continue
			}
		}
		if resumeP < 0 {
			return false
		}
		// Mismatch: let the last '*' swallow one more byte and retry.
		resumeI++
		p, i = resumeP, resumeI
	}
	for p < len(pattern) && pattern[p] == '*' {
		p++
	}
	return p == len(pattern)
}

// memoryGlobOne matches the single-byte token starting at pattern[p] against
// c. It returns the index just past that token and whether c matched.
func memoryGlobOne(pattern string, p int, c byte) (int, bool) {
	switch pattern[p] {
	case '?':
		return p + 1, true
	case '\\':
		if p+1 < len(pattern) {
			return p + 2, pattern[p+1] == c
		}
	case '[':
		if next, ok, closed := memoryGlobClass(pattern, p, c); closed {
			return next, ok
		}
	}
	// Literal byte (also a trailing '\' or an unterminated '[').
	return p + 1, pattern[p] == c
}

// memoryGlobClass evaluates the bracket class that opens at pattern[p]
// against c. closed is false when the class has no terminating ']'.
func memoryGlobClass(pattern string, p int, c byte) (next int, matched, closed bool) {
	j := p + 1
	negate := j < len(pattern) && pattern[j] == '^'
	if negate {
		j++
	}
	hit := false
	for j < len(pattern) && pattern[j] != ']' {
		switch {
		case pattern[j] == '\\' && j+1 < len(pattern):
			hit = hit || pattern[j+1] == c
			j += 2
		case j+2 < len(pattern) && pattern[j+1] == '-' && pattern[j+2] != ']':
			lo, hi := pattern[j], pattern[j+2]
			if lo > hi {
				lo, hi = hi, lo
			}
			hit = hit || (lo <= c && c <= hi)
			j += 3
		default:
			hit = hit || pattern[j] == c
			j++
		}
	}
	if j >= len(pattern) {
		return 0, false, false
	}
	return j + 1, hit != negate, true
}
// NewConfig returns a Config for driver and addr with defaults applied:
// database 0, a 5s dial timeout, 3s read and write timeouts, a pool of 10
// connections, and automatic reconnection checked every 10s. All remaining
// fields keep their zero values.
func NewConfig(driver, addr string) *Config {
	const (
		defaultDialTimeout       = 5 * time.Second
		defaultIOTimeout         = 3 * time.Second
		defaultPoolSize          = 10
		defaultReconnectInterval = 10 * time.Second
	)

	cfg := new(Config)
	cfg.Driver, cfg.Addr = driver, addr
	cfg.DialTimeout = defaultDialTimeout
	cfg.ReadTimeout, cfg.WriteTimeout = defaultIOTimeout, defaultIOTimeout
	cfg.PoolSize = defaultPoolSize
	cfg.Reconnect = true
	cfg.ReconnectInterval = defaultReconnectInterval
	return cfg
}