This commit is contained in:
loveuer
2026-01-28 10:28:13 +08:00
parent 507a67e455
commit 3ee0c9c098
29 changed files with 2852 additions and 0 deletions

112
database/cache/README.md vendored Normal file
View File

@@ -0,0 +1,112 @@
# Cache Package
简洁的 Redis 兼容缓存接口,包含最常用的缓存操作。
## 接口说明
### Cache 核心方法
- `Set/Get/GetBytes/GetScan` - 存取值
- `Del` - 删除键
- `Exists` - 检查键是否存在
- `Expire/TTL` - 设置/获取过期时间
- `Inc/IncBy/Dec` - 原子递增递减
- `SetNX` - 不存在时设置
- `Keys` - 模式匹配查找键
- `Close` - 关闭连接
### Hash 操作方法
- `HSet/HGet` - 设置/获取字段值
- `HGetAll` - 获取所有字段值
- `HDel` - 删除字段
- `HExists` - 检查字段是否存在
- `HKeys` - 获取所有字段名
- `HLen` - 获取字段数量
- `HIncrBy` - 字段值原子递增
### 配置选项
- `Driver` - 驱动类型 (redis/memory)
- `Addr` - 连接地址
- `MasterAddr` - 主节点地址
- `ReplicaAddrs` - 副本节点地址列表
- `Password` - 密码
- `DB` - 数据库编号
- `ReadOnly` - 只读模式
- `Reconnect` - 是否启用自动重连(默认 true
- `ReconnectInterval` - 重连检测间隔(默认 10 秒)
- 连接池和超时配置
## 使用示例
### 基础 Redis 连接
```go
config := NewConfig("redis", "localhost:6379")
cache, err := Open(config)
if err != nil {
log.Fatal(err)
}
```
### Master-Replica 模式(读写分离)
```go
config := NewConfig("redis", "localhost:6379")
config.MasterAddr = "redis-master:6379"
config.ReplicaAddrs = []string{"redis-replica-1:6379", "redis-replica-2:6379"}
cache, err := Open(config)
// 读操作会自动使用 replica写操作使用 master
```
### Kubernetes Headless Service 模式
```go
// 自动解析 headless service 并实现读写分离
cache, err := NewRedisFromHeadlessService(
"my-redis-headless.default.svc.cluster.local:6379",
"password",
)
```
### 基础操作
```go
// 设置值
err = cache.Set(ctx, "key", "value", time.Hour)
// 获取值
val, err := cache.Get(ctx, "key")
// 原子递增
count, err := cache.IncBy(ctx, "counter", 1)
// Hash 操作
err = cache.HSet(ctx, "user:1", "name", "张三")
name, err := cache.HGet(ctx, "user:1", "name")
all, err := cache.HGetAll(ctx, "user:1")
```
### 读写分离说明
- **写操作** (Set, Del, Inc, HSet 等) → Master 节点
- **读操作** (Get, Exists, HGet, Keys 等) → Replica 节点
- **Headless Service** → 自动解析 Kubernetes Pod 地址
### 自动重连
- **默认启用**:每 10 秒检测一次连接状态
- **断线重连**:自动重新初始化连接
- **优雅关闭**Close() 时停止重连检测
```go
// 禁用自动重连
config := NewConfig("redis", "localhost:6379")
config.Reconnect = false
// 自定义重连间隔
config.Reconnect = true
config.ReconnectInterval = 5 * time.Second
```

11
database/cache/errors.go vendored Normal file
View File

@@ -0,0 +1,11 @@
package cache
import (
"errors"
)
var (
ErrKeyNotFound = errors.New("key not found")
ErrKeyExists = errors.New("key already exists")
ErrInvalidType = errors.New("invalid type")
)

331
database/cache/interface.go vendored Normal file
View File

@@ -0,0 +1,331 @@
package cache
import (
"context"
"fmt"
"strconv"
"time"
)
type Cache interface {
Set(ctx context.Context, key string, value interface{}, expiration ...time.Duration) error
Get(ctx context.Context, key string) (string, error)
GetBytes(ctx context.Context, key string) ([]byte, error)
GetScan(ctx context.Context, key string, dest interface{}) error
Del(ctx context.Context, keys ...string) error
Exists(ctx context.Context, key string) (bool, error)
Expire(ctx context.Context, key string, expiration time.Duration) error
TTL(ctx context.Context, key string) (time.Duration, error)
Inc(ctx context.Context, key string) (int64, error)
IncBy(ctx context.Context, key string, value int64) (int64, error)
Dec(ctx context.Context, key string) (int64, error)
SetNX(ctx context.Context, key string, value interface{}, expiration ...time.Duration) (bool, error)
Keys(ctx context.Context, pattern string) ([]string, error)
Close() error
HSet(ctx context.Context, key string, field string, value interface{}) error
HGet(ctx context.Context, key string, field string) (string, error)
HGetAll(ctx context.Context, key string) (map[string]string, error)
HDel(ctx context.Context, key string, fields ...string) error
HExists(ctx context.Context, key string, field string) (bool, error)
HKeys(ctx context.Context, key string) ([]string, error)
HLen(ctx context.Context, key string) (int64, error)
HIncrBy(ctx context.Context, key string, field string, increment int64) (int64, error)
}
type Config struct {
Driver string `json:"driver"`
Addr string `json:"addr"`
MasterAddr string `json:"master_addr"`
ReplicaAddrs []string `json:"replica_addrs"`
Password string `json:"password"`
DB int `json:"db"`
DialTimeout time.Duration `json:"dial_timeout"`
ReadTimeout time.Duration `json:"read_timeout"`
WriteTimeout time.Duration `json:"write_timeout"`
PoolSize int `json:"pool_size"`
ReadOnly bool `json:"read_only"`
// 重连配置
Reconnect bool `json:"reconnect"` // 是否启用自动重连,默认 true
ReconnectInterval time.Duration `json:"reconnect_interval"` // 重连检测间隔,默认 10 秒
}
type Option func(*Config)
func WithAddr(addr string) Option {
return func(c *Config) {
c.Addr = addr
}
}
func WithPassword(password string) Option {
return func(c *Config) {
c.Password = password
}
}
func WithDB(db int) Option {
return func(c *Config) {
c.DB = db
}
}
func WithDialTimeout(timeout time.Duration) Option {
return func(c *Config) {
c.DialTimeout = timeout
}
}
func WithMasterAddr(addr string) Option {
return func(c *Config) {
c.MasterAddr = addr
}
}
func WithReplicaAddrs(addrs []string) Option {
return func(c *Config) {
c.ReplicaAddrs = addrs
}
}
func WithReadOnly(readOnly bool) Option {
return func(c *Config) {
c.ReadOnly = readOnly
}
}
func WithReconnect(reconnect bool) Option {
return func(c *Config) {
c.Reconnect = reconnect
}
}
func WithReconnectInterval(interval time.Duration) Option {
return func(c *Config) {
c.ReconnectInterval = interval
}
}
type Driver interface {
Cache(config *Config) (Cache, error)
}
func NewMemoryCache() Cache {
return &memoryCache{
data: make(map[string]string),
hash: make(map[string]map[string]string),
}
}
type memoryCache struct {
data map[string]string
hash map[string]map[string]string
}
func (m *memoryCache) Set(ctx context.Context, key string, value interface{}, expiration ...time.Duration) error {
m.data[key] = fmt.Sprintf("%v", value)
return nil
}
func (m *memoryCache) Get(ctx context.Context, key string) (string, error) {
val, exists := m.data[key]
if !exists {
return "", ErrKeyNotFound
}
return val, nil
}
func (m *memoryCache) GetBytes(ctx context.Context, key string) ([]byte, error) {
val, err := m.Get(ctx, key)
if err != nil {
return nil, err
}
return []byte(val), nil
}
func (m *memoryCache) GetScan(ctx context.Context, key string, dest interface{}) error {
// 简单实现,实际应该用 json.Unmarshal
return fmt.Errorf("GetScan not implemented in memory cache")
}
func (m *memoryCache) Del(ctx context.Context, keys ...string) error {
for _, key := range keys {
delete(m.data, key)
delete(m.hash, key)
}
return nil
}
func (m *memoryCache) Exists(ctx context.Context, key string) (bool, error) {
_, exists := m.data[key]
return exists, nil
}
func (m *memoryCache) Expire(ctx context.Context, key string, expiration time.Duration) error {
// Memory cache 简单实现,不支持过期
return nil
}
func (m *memoryCache) TTL(ctx context.Context, key string) (time.Duration, error) {
return -1, nil
}
func (m *memoryCache) Inc(ctx context.Context, key string) (int64, error) {
return m.IncBy(ctx, key, 1)
}
func (m *memoryCache) IncBy(ctx context.Context, key string, value int64) (int64, error) {
current, exists := m.data[key]
var currentInt int64
if exists {
var err error
currentInt, err = strconv.ParseInt(current, 10, 64)
if err != nil {
currentInt = 0
}
}
newVal := currentInt + value
m.data[key] = fmt.Sprintf("%d", newVal)
return newVal, nil
}
func (m *memoryCache) Dec(ctx context.Context, key string) (int64, error) {
return m.IncBy(ctx, key, -1)
}
func (m *memoryCache) SetNX(ctx context.Context, key string, value interface{}, expiration ...time.Duration) (bool, error) {
if _, exists := m.data[key]; exists {
return false, nil
}
m.data[key] = fmt.Sprintf("%v", value)
return true, nil
}
func (m *memoryCache) Keys(ctx context.Context, pattern string) ([]string, error) {
var keys []string
for key := range m.data {
if pattern == "*" || key == pattern {
keys = append(keys, key)
}
}
return keys, nil
}
func (m *memoryCache) HSet(ctx context.Context, key string, field string, value interface{}) error {
if _, exists := m.hash[key]; !exists {
m.hash[key] = make(map[string]string)
}
m.hash[key][field] = fmt.Sprintf("%v", value)
return nil
}
func (m *memoryCache) HGet(ctx context.Context, key string, field string) (string, error) {
hash, exists := m.hash[key]
if !exists {
return "", ErrKeyNotFound
}
val, exists := hash[field]
if !exists {
return "", ErrKeyNotFound
}
return val, nil
}
func (m *memoryCache) HGetAll(ctx context.Context, key string) (map[string]string, error) {
hash, exists := m.hash[key]
if !exists {
return nil, ErrKeyNotFound
}
// 返回副本
result := make(map[string]string)
for k, v := range hash {
result[k] = v
}
return result, nil
}
func (m *memoryCache) HDel(ctx context.Context, key string, fields ...string) error {
hash, exists := m.hash[key]
if !exists {
return nil
}
for _, field := range fields {
delete(hash, field)
}
return nil
}
func (m *memoryCache) HExists(ctx context.Context, key string, field string) (bool, error) {
hash, exists := m.hash[key]
if !exists {
return false, nil
}
_, exists = hash[field]
return exists, nil
}
func (m *memoryCache) HKeys(ctx context.Context, key string) ([]string, error) {
hash, exists := m.hash[key]
if !exists {
return []string{}, nil
}
var keys []string
for k := range hash {
keys = append(keys, k)
}
return keys, nil
}
func (m *memoryCache) HLen(ctx context.Context, key string) (int64, error) {
hash, exists := m.hash[key]
if !exists {
return 0, nil
}
return int64(len(hash)), nil
}
func (m *memoryCache) HIncrBy(ctx context.Context, key string, field string, increment int64) (int64, error) {
hash, exists := m.hash[key]
if !exists {
hash = make(map[string]string)
m.hash[key] = hash
}
current, exists := hash[field]
var currentInt int64
if exists {
var err error
currentInt, err = strconv.ParseInt(current, 10, 64)
if err != nil {
currentInt = 0
}
}
newVal := currentInt + increment
hash[field] = fmt.Sprintf("%d", newVal)
return newVal, nil
}
func (m *memoryCache) Close() error {
m.data = nil
m.hash = nil
return nil
}
func NewConfig(driver, addr string) *Config {
return &Config{
Driver: driver,
Addr: addr,
DB: 0,
DialTimeout: 5 * time.Second,
ReadTimeout: 3 * time.Second,
WriteTimeout: 3 * time.Second,
PoolSize: 10,
Reconnect: true,
ReconnectInterval: 10 * time.Second,
}
}

315
database/cache/redis.go vendored Normal file
View File

@@ -0,0 +1,315 @@
package cache
import (
"context"
"encoding/json"
"fmt"
"strings"
"sync"
"time"
"github.com/redis/go-redis/v9"
)
type RedisCache struct {
master *redis.Client
replica *redis.Client
config *Config
// 重连相关
mu sync.RWMutex
closed bool
ticker *time.Ticker
done chan struct{}
}
func NewRedis(config *Config) (Cache, error) {
rdb := &RedisCache{
config: config,
done: make(chan struct{}),
}
// 初始化主节点连接
if config.MasterAddr != "" {
rdb.master = rdb.createClient(config.MasterAddr, true)
} else {
// 如果没有指定 master使用 Addr
rdb.master = rdb.createClient(config.Addr, true)
}
// 初始化从节点连接(用于只读操作)
if len(config.ReplicaAddrs) > 0 {
// 如果有多个副本,使用第一个副本
rdb.replica = rdb.createClient(config.ReplicaAddrs[0], false)
} else {
// 如果没有指定副本,复用 master 连接
rdb.replica = rdb.master
}
// 启动自动重连
if config.Reconnect {
rdb.startReconnect()
}
return rdb, nil
}
func (r *RedisCache) createClient(addr string, isMaster bool) *redis.Client {
return redis.NewClient(&redis.Options{
Addr: addr,
Password: r.config.Password,
DB: r.config.DB,
DialTimeout: r.config.DialTimeout,
ReadTimeout: r.config.ReadTimeout,
WriteTimeout: r.config.WriteTimeout,
PoolSize: r.config.PoolSize,
})
}
func (r *RedisCache) getClient(readOnly bool) *redis.Client {
r.mu.RLock()
defer r.mu.RUnlock()
if readOnly && r.replica != nil && !r.config.ReadOnly {
return r.replica
}
return r.master
}
func (r *RedisCache) Set(ctx context.Context, key string, value interface{}, expiration ...time.Duration) error {
var exp time.Duration
if len(expiration) > 0 {
exp = expiration[0]
}
return r.getClient(false).Set(ctx, key, value, exp).Err()
}
func (r *RedisCache) Get(ctx context.Context, key string) (string, error) {
result, err := r.getClient(true).Get(ctx, key).Result()
if err == redis.Nil {
return "", ErrKeyNotFound
}
return result, err
}
func (r *RedisCache) GetBytes(ctx context.Context, key string) ([]byte, error) {
val, err := r.Get(ctx, key)
if err != nil {
return nil, err
}
return []byte(val), nil
}
func (r *RedisCache) GetScan(ctx context.Context, key string, dest interface{}) error {
val, err := r.Get(ctx, key)
if err != nil {
return err
}
return json.Unmarshal([]byte(val), dest)
}
func (r *RedisCache) Del(ctx context.Context, keys ...string) error {
return r.getClient(false).Del(ctx, keys...).Err()
}
func (r *RedisCache) Exists(ctx context.Context, key string) (bool, error) {
result, err := r.getClient(true).Exists(ctx, key).Result()
return result > 0, err
}
func (r *RedisCache) Expire(ctx context.Context, key string, expiration time.Duration) error {
return r.getClient(false).Expire(ctx, key, expiration).Err()
}
func (r *RedisCache) TTL(ctx context.Context, key string) (time.Duration, error) {
return r.getClient(true).TTL(ctx, key).Result()
}
func (r *RedisCache) Inc(ctx context.Context, key string) (int64, error) {
return r.getClient(false).Incr(ctx, key).Result()
}
func (r *RedisCache) IncBy(ctx context.Context, key string, value int64) (int64, error) {
return r.getClient(false).IncrBy(ctx, key, value).Result()
}
func (r *RedisCache) Dec(ctx context.Context, key string) (int64, error) {
return r.getClient(false).Decr(ctx, key).Result()
}
func (r *RedisCache) SetNX(ctx context.Context, key string, value interface{}, expiration ...time.Duration) (bool, error) {
var exp time.Duration
if len(expiration) > 0 {
exp = expiration[0]
}
return r.getClient(false).SetNX(ctx, key, value, exp).Result()
}
func (r *RedisCache) Keys(ctx context.Context, pattern string) ([]string, error) {
return r.getClient(true).Keys(ctx, pattern).Result()
}
func (r *RedisCache) HSet(ctx context.Context, key string, field string, value interface{}) error {
return r.getClient(false).HSet(ctx, key, field, value).Err()
}
func (r *RedisCache) HGet(ctx context.Context, key string, field string) (string, error) {
result, err := r.getClient(true).HGet(ctx, key, field).Result()
if err == redis.Nil {
return "", ErrKeyNotFound
}
return result, err
}
func (r *RedisCache) HGetAll(ctx context.Context, key string) (map[string]string, error) {
result, err := r.getClient(true).HGetAll(ctx, key).Result()
if err == redis.Nil {
return nil, ErrKeyNotFound
}
return result, err
}
func (r *RedisCache) HDel(ctx context.Context, key string, fields ...string) error {
return r.getClient(false).HDel(ctx, key, fields...).Err()
}
func (r *RedisCache) HExists(ctx context.Context, key string, field string) (bool, error) {
return r.getClient(true).HExists(ctx, key, field).Result()
}
func (r *RedisCache) HKeys(ctx context.Context, key string) ([]string, error) {
return r.getClient(true).HKeys(ctx, key).Result()
}
func (r *RedisCache) HLen(ctx context.Context, key string) (int64, error) {
return r.getClient(true).HLen(ctx, key).Result()
}
func (r *RedisCache) HIncrBy(ctx context.Context, key string, field string, increment int64) (int64, error) {
return r.getClient(false).HIncrBy(ctx, key, field, increment).Result()
}
func (r *RedisCache) Close() error {
r.mu.Lock()
defer r.mu.Unlock()
if r.closed {
return nil
}
r.closed = true
// 停止重连定时器
if r.ticker != nil {
r.ticker.Stop()
}
close(r.done)
// 关闭连接
if r.master != nil {
r.master.Close()
}
if r.replica != nil && r.replica != r.master {
r.replica.Close()
}
return nil
}
// 支持解析 Kubernetes Headless Service 地址的辅助函数
func ParseHeadlessServiceAddr(addr string) (master string, replicas []string, err error) {
// 格式: my-redis-headless.my-namespace.svc.cluster.local:6379
if !strings.Contains(addr, ".svc.") {
// 非集群地址,作为单节点处理
return addr, nil, nil
}
// 这里可以通过 DNS SRV 记录查询获取所有 pod 地址
// 简化实现:假设已知命名空间和服务名,返回 master 和多个副本
parts := strings.SplitN(addr, ".", 2)
if len(parts) < 2 {
return "", nil, fmt.Errorf("invalid headless service address")
}
serviceName := parts[0]
namespace := strings.SplitN(parts[1], ".", 2)[0]
// Kubernetes headless service 模式下,第一个 pod 作为 master
// 其余作为 replicas
master = fmt.Sprintf("%s-0.%s.%s.svc.cluster.local:6379", serviceName, serviceName, namespace)
for i := 1; i <= 2; i++ { // 假设有 2 个副本
replica := fmt.Sprintf("%s-%d.%s.%s.svc.cluster.local:6379", serviceName, i, serviceName, namespace)
replicas = append(replicas, replica)
}
return master, replicas, nil
}
func (r *RedisCache) startReconnect() {
r.ticker = time.NewTicker(r.config.ReconnectInterval)
go func() {
for {
select {
case <-r.done:
return
case <-r.ticker.C:
r.checkAndReconnect()
}
}
}()
}
func (r *RedisCache) checkAndReconnect() {
r.mu.Lock()
defer r.mu.Unlock()
if r.closed {
return
}
// 检查主节点连接
if r.master != nil {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
if err := r.master.Ping(ctx).Err(); err != nil {
fmt.Printf("Master connection lost: %v, attempting reconnect...\n", err)
if r.config.MasterAddr != "" {
r.master.Close()
r.master = r.createClient(r.config.MasterAddr, true)
} else {
r.master.Close()
r.master = r.createClient(r.config.Addr, true)
}
}
cancel()
}
// 检查副本节点连接如果与master不同
if r.replica != nil && r.replica != r.master {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
if err := r.replica.Ping(ctx).Err(); err != nil {
fmt.Printf("Replica connection lost: %v, attempting reconnect...\n", err)
if len(r.config.ReplicaAddrs) > 0 {
r.replica.Close()
r.replica = r.createClient(r.config.ReplicaAddrs[0], false)
}
}
cancel()
}
}
// 从 Headless Service 自动创建 Redis 连接
func NewRedisFromHeadlessService(headlessAddr string, password string) (Cache, error) {
master, replicas, err := ParseHeadlessServiceAddr(headlessAddr)
if err != nil {
return nil, err
}
config := NewConfig("redis", headlessAddr)
config.MasterAddr = master
config.ReplicaAddrs = replicas
config.Password = password
return NewRedis(config)
}

28
database/cache/registry.go vendored Normal file
View File

@@ -0,0 +1,28 @@
package cache
import "fmt"
var drivers = make(map[string]Driver)
func Register(name string, driver Driver) {
if driver == nil {
panic("cache: Register driver is nil")
}
if _, dup := drivers[name]; dup {
panic("cache: Register called twice for driver " + name)
}
drivers[name] = driver
}
func Open(config *Config) (Cache, error) {
if config.Driver == "redis" {
return NewRedis(config)
}
driver, ok := drivers[config.Driver]
if !ok {
return nil, fmt.Errorf("unknown driver %q (forgotten import?)", config.Driver)
}
return driver.Cache(config)
}