package cache import ( "context" "encoding/json" "fmt" "strings" "sync" "time" "github.com/redis/go-redis/v9" ) type RedisCache struct { master *redis.Client replica *redis.Client config *Config // 重连相关 mu sync.RWMutex closed bool ticker *time.Ticker done chan struct{} } func NewRedis(config *Config) (Cache, error) { rdb := &RedisCache{ config: config, done: make(chan struct{}), } // 初始化主节点连接 if config.MasterAddr != "" { rdb.master = rdb.createClient(config.MasterAddr, true) } else { // 如果没有指定 master,使用 Addr rdb.master = rdb.createClient(config.Addr, true) } // 初始化从节点连接(用于只读操作) if len(config.ReplicaAddrs) > 0 { // 如果有多个副本,使用第一个副本 rdb.replica = rdb.createClient(config.ReplicaAddrs[0], false) } else { // 如果没有指定副本,复用 master 连接 rdb.replica = rdb.master } // 启动自动重连 if config.Reconnect { rdb.startReconnect() } return rdb, nil } func (r *RedisCache) createClient(addr string, isMaster bool) *redis.Client { return redis.NewClient(&redis.Options{ Addr: addr, Password: r.config.Password, DB: r.config.DB, DialTimeout: r.config.DialTimeout, ReadTimeout: r.config.ReadTimeout, WriteTimeout: r.config.WriteTimeout, PoolSize: r.config.PoolSize, }) } func (r *RedisCache) getClient(readOnly bool) *redis.Client { r.mu.RLock() defer r.mu.RUnlock() if readOnly && r.replica != nil && !r.config.ReadOnly { return r.replica } return r.master } func (r *RedisCache) Set(ctx context.Context, key string, value interface{}, expiration ...time.Duration) error { var exp time.Duration if len(expiration) > 0 { exp = expiration[0] } return r.getClient(false).Set(ctx, key, value, exp).Err() } func (r *RedisCache) Get(ctx context.Context, key string) (string, error) { result, err := r.getClient(true).Get(ctx, key).Result() if err == redis.Nil { return "", ErrKeyNotFound } return result, err } func (r *RedisCache) GetBytes(ctx context.Context, key string) ([]byte, error) { val, err := r.Get(ctx, key) if err != nil { return nil, err } return []byte(val), nil } func (r *RedisCache) GetScan(ctx context.Context, key string, dest interface{}) error { val, err := r.Get(ctx, key) if err != nil { return err } return json.Unmarshal([]byte(val), dest) } func (r *RedisCache) Del(ctx context.Context, keys ...string) error { return r.getClient(false).Del(ctx, keys...).Err() } func (r *RedisCache) Exists(ctx context.Context, key string) (bool, error) { result, err := r.getClient(true).Exists(ctx, key).Result() return result > 0, err } func (r *RedisCache) Expire(ctx context.Context, key string, expiration time.Duration) error { return r.getClient(false).Expire(ctx, key, expiration).Err() } func (r *RedisCache) TTL(ctx context.Context, key string) (time.Duration, error) { return r.getClient(true).TTL(ctx, key).Result() } func (r *RedisCache) Inc(ctx context.Context, key string) (int64, error) { return r.getClient(false).Incr(ctx, key).Result() } func (r *RedisCache) IncBy(ctx context.Context, key string, value int64) (int64, error) { return r.getClient(false).IncrBy(ctx, key, value).Result() } func (r *RedisCache) Dec(ctx context.Context, key string) (int64, error) { return r.getClient(false).Decr(ctx, key).Result() } func (r *RedisCache) SetNX(ctx context.Context, key string, value interface{}, expiration ...time.Duration) (bool, error) { var exp time.Duration if len(expiration) > 0 { exp = expiration[0] } return r.getClient(false).SetNX(ctx, key, value, exp).Result() } func (r *RedisCache) Keys(ctx context.Context, pattern string) ([]string, error) { return r.getClient(true).Keys(ctx, pattern).Result() } func (r *RedisCache) HSet(ctx context.Context, key string, field string, value interface{}) error { return r.getClient(false).HSet(ctx, key, field, value).Err() } func (r *RedisCache) HGet(ctx context.Context, key string, field string) (string, error) { result, err := r.getClient(true).HGet(ctx, key, field).Result() if err == redis.Nil { return "", ErrKeyNotFound } return result, err } func (r *RedisCache) HGetAll(ctx context.Context, key string) (map[string]string, error) { result, err := r.getClient(true).HGetAll(ctx, key).Result() if err == redis.Nil { return nil, ErrKeyNotFound } return result, err } func (r *RedisCache) HDel(ctx context.Context, key string, fields ...string) error { return r.getClient(false).HDel(ctx, key, fields...).Err() } func (r *RedisCache) HExists(ctx context.Context, key string, field string) (bool, error) { return r.getClient(true).HExists(ctx, key, field).Result() } func (r *RedisCache) HKeys(ctx context.Context, key string) ([]string, error) { return r.getClient(true).HKeys(ctx, key).Result() } func (r *RedisCache) HLen(ctx context.Context, key string) (int64, error) { return r.getClient(true).HLen(ctx, key).Result() } func (r *RedisCache) HIncrBy(ctx context.Context, key string, field string, increment int64) (int64, error) { return r.getClient(false).HIncrBy(ctx, key, field, increment).Result() } func (r *RedisCache) Close() error { r.mu.Lock() defer r.mu.Unlock() if r.closed { return nil } r.closed = true // 停止重连定时器 if r.ticker != nil { r.ticker.Stop() } close(r.done) // 关闭连接 if r.master != nil { r.master.Close() } if r.replica != nil && r.replica != r.master { r.replica.Close() } return nil } // 支持解析 Kubernetes Headless Service 地址的辅助函数 func ParseHeadlessServiceAddr(addr string) (master string, replicas []string, err error) { // 格式: my-redis-headless.my-namespace.svc.cluster.local:6379 if !strings.Contains(addr, ".svc.") { // 非集群地址,作为单节点处理 return addr, nil, nil } // 这里可以通过 DNS SRV 记录查询获取所有 pod 地址 // 简化实现:假设已知命名空间和服务名,返回 master 和多个副本 parts := strings.SplitN(addr, ".", 2) if len(parts) < 2 { return "", nil, fmt.Errorf("invalid headless service address") } serviceName := parts[0] namespace := strings.SplitN(parts[1], ".", 2)[0] // Kubernetes headless service 模式下,第一个 pod 作为 master // 其余作为 replicas master = fmt.Sprintf("%s-0.%s.%s.svc.cluster.local:6379", serviceName, serviceName, namespace) for i := 1; i <= 2; i++ { // 假设有 2 个副本 replica := fmt.Sprintf("%s-%d.%s.%s.svc.cluster.local:6379", serviceName, i, serviceName, namespace) replicas = append(replicas, replica) } return master, replicas, nil } func (r *RedisCache) startReconnect() { r.ticker = time.NewTicker(r.config.ReconnectInterval) go func() { for { select { case <-r.done: return case <-r.ticker.C: r.checkAndReconnect() } } }() } func (r *RedisCache) checkAndReconnect() { r.mu.Lock() defer r.mu.Unlock() if r.closed { return } // 检查主节点连接 if r.master != nil { ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) if err := r.master.Ping(ctx).Err(); err != nil { fmt.Printf("Master connection lost: %v, attempting reconnect...\n", err) if r.config.MasterAddr != "" { r.master.Close() r.master = r.createClient(r.config.MasterAddr, true) } else { r.master.Close() r.master = r.createClient(r.config.Addr, true) } } cancel() } // 检查副本节点连接(如果与master不同) if r.replica != nil && r.replica != r.master { ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) if err := r.replica.Ping(ctx).Err(); err != nil { fmt.Printf("Replica connection lost: %v, attempting reconnect...\n", err) if len(r.config.ReplicaAddrs) > 0 { r.replica.Close() r.replica = r.createClient(r.config.ReplicaAddrs[0], false) } } cancel() } } // 从 Headless Service 自动创建 Redis 连接 func NewRedisFromHeadlessService(headlessAddr string, password string) (Cache, error) { master, replicas, err := ParseHeadlessServiceAddr(headlessAddr) if err != nil { return nil, err } config := NewConfig("redis", headlessAddr) config.MasterAddr = master config.ReplicaAddrs = replicas config.Password = password return NewRedis(config) }