This commit is contained in:
loveuer
2026-01-28 10:28:13 +08:00
parent 507a67e455
commit 3ee0c9c098
29 changed files with 2852 additions and 0 deletions

315
database/cache/redis.go vendored Normal file
View File

@@ -0,0 +1,315 @@
package cache
import (
"context"
"encoding/json"
"fmt"
"strings"
"sync"
"time"
"github.com/redis/go-redis/v9"
)
type RedisCache struct {
master *redis.Client
replica *redis.Client
config *Config
// 重连相关
mu sync.RWMutex
closed bool
ticker *time.Ticker
done chan struct{}
}
func NewRedis(config *Config) (Cache, error) {
rdb := &RedisCache{
config: config,
done: make(chan struct{}),
}
// 初始化主节点连接
if config.MasterAddr != "" {
rdb.master = rdb.createClient(config.MasterAddr, true)
} else {
// 如果没有指定 master使用 Addr
rdb.master = rdb.createClient(config.Addr, true)
}
// 初始化从节点连接(用于只读操作)
if len(config.ReplicaAddrs) > 0 {
// 如果有多个副本,使用第一个副本
rdb.replica = rdb.createClient(config.ReplicaAddrs[0], false)
} else {
// 如果没有指定副本,复用 master 连接
rdb.replica = rdb.master
}
// 启动自动重连
if config.Reconnect {
rdb.startReconnect()
}
return rdb, nil
}
func (r *RedisCache) createClient(addr string, isMaster bool) *redis.Client {
return redis.NewClient(&redis.Options{
Addr: addr,
Password: r.config.Password,
DB: r.config.DB,
DialTimeout: r.config.DialTimeout,
ReadTimeout: r.config.ReadTimeout,
WriteTimeout: r.config.WriteTimeout,
PoolSize: r.config.PoolSize,
})
}
func (r *RedisCache) getClient(readOnly bool) *redis.Client {
r.mu.RLock()
defer r.mu.RUnlock()
if readOnly && r.replica != nil && !r.config.ReadOnly {
return r.replica
}
return r.master
}
func (r *RedisCache) Set(ctx context.Context, key string, value interface{}, expiration ...time.Duration) error {
var exp time.Duration
if len(expiration) > 0 {
exp = expiration[0]
}
return r.getClient(false).Set(ctx, key, value, exp).Err()
}
func (r *RedisCache) Get(ctx context.Context, key string) (string, error) {
result, err := r.getClient(true).Get(ctx, key).Result()
if err == redis.Nil {
return "", ErrKeyNotFound
}
return result, err
}
func (r *RedisCache) GetBytes(ctx context.Context, key string) ([]byte, error) {
val, err := r.Get(ctx, key)
if err != nil {
return nil, err
}
return []byte(val), nil
}
func (r *RedisCache) GetScan(ctx context.Context, key string, dest interface{}) error {
val, err := r.Get(ctx, key)
if err != nil {
return err
}
return json.Unmarshal([]byte(val), dest)
}
func (r *RedisCache) Del(ctx context.Context, keys ...string) error {
return r.getClient(false).Del(ctx, keys...).Err()
}
func (r *RedisCache) Exists(ctx context.Context, key string) (bool, error) {
result, err := r.getClient(true).Exists(ctx, key).Result()
return result > 0, err
}
func (r *RedisCache) Expire(ctx context.Context, key string, expiration time.Duration) error {
return r.getClient(false).Expire(ctx, key, expiration).Err()
}
func (r *RedisCache) TTL(ctx context.Context, key string) (time.Duration, error) {
return r.getClient(true).TTL(ctx, key).Result()
}
func (r *RedisCache) Inc(ctx context.Context, key string) (int64, error) {
return r.getClient(false).Incr(ctx, key).Result()
}
func (r *RedisCache) IncBy(ctx context.Context, key string, value int64) (int64, error) {
return r.getClient(false).IncrBy(ctx, key, value).Result()
}
func (r *RedisCache) Dec(ctx context.Context, key string) (int64, error) {
return r.getClient(false).Decr(ctx, key).Result()
}
func (r *RedisCache) SetNX(ctx context.Context, key string, value interface{}, expiration ...time.Duration) (bool, error) {
var exp time.Duration
if len(expiration) > 0 {
exp = expiration[0]
}
return r.getClient(false).SetNX(ctx, key, value, exp).Result()
}
func (r *RedisCache) Keys(ctx context.Context, pattern string) ([]string, error) {
return r.getClient(true).Keys(ctx, pattern).Result()
}
func (r *RedisCache) HSet(ctx context.Context, key string, field string, value interface{}) error {
return r.getClient(false).HSet(ctx, key, field, value).Err()
}
func (r *RedisCache) HGet(ctx context.Context, key string, field string) (string, error) {
result, err := r.getClient(true).HGet(ctx, key, field).Result()
if err == redis.Nil {
return "", ErrKeyNotFound
}
return result, err
}
func (r *RedisCache) HGetAll(ctx context.Context, key string) (map[string]string, error) {
result, err := r.getClient(true).HGetAll(ctx, key).Result()
if err == redis.Nil {
return nil, ErrKeyNotFound
}
return result, err
}
func (r *RedisCache) HDel(ctx context.Context, key string, fields ...string) error {
return r.getClient(false).HDel(ctx, key, fields...).Err()
}
func (r *RedisCache) HExists(ctx context.Context, key string, field string) (bool, error) {
return r.getClient(true).HExists(ctx, key, field).Result()
}
func (r *RedisCache) HKeys(ctx context.Context, key string) ([]string, error) {
return r.getClient(true).HKeys(ctx, key).Result()
}
func (r *RedisCache) HLen(ctx context.Context, key string) (int64, error) {
return r.getClient(true).HLen(ctx, key).Result()
}
func (r *RedisCache) HIncrBy(ctx context.Context, key string, field string, increment int64) (int64, error) {
return r.getClient(false).HIncrBy(ctx, key, field, increment).Result()
}
func (r *RedisCache) Close() error {
r.mu.Lock()
defer r.mu.Unlock()
if r.closed {
return nil
}
r.closed = true
// 停止重连定时器
if r.ticker != nil {
r.ticker.Stop()
}
close(r.done)
// 关闭连接
if r.master != nil {
r.master.Close()
}
if r.replica != nil && r.replica != r.master {
r.replica.Close()
}
return nil
}
// 支持解析 Kubernetes Headless Service 地址的辅助函数
func ParseHeadlessServiceAddr(addr string) (master string, replicas []string, err error) {
// 格式: my-redis-headless.my-namespace.svc.cluster.local:6379
if !strings.Contains(addr, ".svc.") {
// 非集群地址,作为单节点处理
return addr, nil, nil
}
// 这里可以通过 DNS SRV 记录查询获取所有 pod 地址
// 简化实现:假设已知命名空间和服务名,返回 master 和多个副本
parts := strings.SplitN(addr, ".", 2)
if len(parts) < 2 {
return "", nil, fmt.Errorf("invalid headless service address")
}
serviceName := parts[0]
namespace := strings.SplitN(parts[1], ".", 2)[0]
// Kubernetes headless service 模式下,第一个 pod 作为 master
// 其余作为 replicas
master = fmt.Sprintf("%s-0.%s.%s.svc.cluster.local:6379", serviceName, serviceName, namespace)
for i := 1; i <= 2; i++ { // 假设有 2 个副本
replica := fmt.Sprintf("%s-%d.%s.%s.svc.cluster.local:6379", serviceName, i, serviceName, namespace)
replicas = append(replicas, replica)
}
return master, replicas, nil
}
func (r *RedisCache) startReconnect() {
r.ticker = time.NewTicker(r.config.ReconnectInterval)
go func() {
for {
select {
case <-r.done:
return
case <-r.ticker.C:
r.checkAndReconnect()
}
}
}()
}
func (r *RedisCache) checkAndReconnect() {
r.mu.Lock()
defer r.mu.Unlock()
if r.closed {
return
}
// 检查主节点连接
if r.master != nil {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
if err := r.master.Ping(ctx).Err(); err != nil {
fmt.Printf("Master connection lost: %v, attempting reconnect...\n", err)
if r.config.MasterAddr != "" {
r.master.Close()
r.master = r.createClient(r.config.MasterAddr, true)
} else {
r.master.Close()
r.master = r.createClient(r.config.Addr, true)
}
}
cancel()
}
// 检查副本节点连接如果与master不同
if r.replica != nil && r.replica != r.master {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
if err := r.replica.Ping(ctx).Err(); err != nil {
fmt.Printf("Replica connection lost: %v, attempting reconnect...\n", err)
if len(r.config.ReplicaAddrs) > 0 {
r.replica.Close()
r.replica = r.createClient(r.config.ReplicaAddrs[0], false)
}
}
cancel()
}
}
// 从 Headless Service 自动创建 Redis 连接
func NewRedisFromHeadlessService(headlessAddr string, password string) (Cache, error) {
master, replicas, err := ParseHeadlessServiceAddr(headlessAddr)
if err != nil {
return nil, err
}
config := NewConfig("redis", headlessAddr)
config.MasterAddr = master
config.ReplicaAddrs = replicas
config.Password = password
return NewRedis(config)
}