316 lines
8.2 KiB
Go
316 lines
8.2 KiB
Go
package cache
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"github.com/redis/go-redis/v9"
|
||
)
|
||
|
||
type RedisCache struct {
|
||
master *redis.Client
|
||
replica *redis.Client
|
||
config *Config
|
||
|
||
// 重连相关
|
||
mu sync.RWMutex
|
||
closed bool
|
||
ticker *time.Ticker
|
||
done chan struct{}
|
||
}
|
||
|
||
func NewRedis(config *Config) (Cache, error) {
|
||
rdb := &RedisCache{
|
||
config: config,
|
||
done: make(chan struct{}),
|
||
}
|
||
|
||
// 初始化主节点连接
|
||
if config.MasterAddr != "" {
|
||
rdb.master = rdb.createClient(config.MasterAddr, true)
|
||
} else {
|
||
// 如果没有指定 master,使用 Addr
|
||
rdb.master = rdb.createClient(config.Addr, true)
|
||
}
|
||
|
||
// 初始化从节点连接(用于只读操作)
|
||
if len(config.ReplicaAddrs) > 0 {
|
||
// 如果有多个副本,使用第一个副本
|
||
rdb.replica = rdb.createClient(config.ReplicaAddrs[0], false)
|
||
} else {
|
||
// 如果没有指定副本,复用 master 连接
|
||
rdb.replica = rdb.master
|
||
}
|
||
|
||
// 启动自动重连
|
||
if config.Reconnect {
|
||
rdb.startReconnect()
|
||
}
|
||
|
||
return rdb, nil
|
||
}
|
||
|
||
func (r *RedisCache) createClient(addr string, isMaster bool) *redis.Client {
|
||
return redis.NewClient(&redis.Options{
|
||
Addr: addr,
|
||
Password: r.config.Password,
|
||
DB: r.config.DB,
|
||
DialTimeout: r.config.DialTimeout,
|
||
ReadTimeout: r.config.ReadTimeout,
|
||
WriteTimeout: r.config.WriteTimeout,
|
||
PoolSize: r.config.PoolSize,
|
||
})
|
||
}
|
||
|
||
func (r *RedisCache) getClient(readOnly bool) *redis.Client {
|
||
r.mu.RLock()
|
||
defer r.mu.RUnlock()
|
||
|
||
if readOnly && r.replica != nil && !r.config.ReadOnly {
|
||
return r.replica
|
||
}
|
||
return r.master
|
||
}
|
||
|
||
func (r *RedisCache) Set(ctx context.Context, key string, value interface{}, expiration ...time.Duration) error {
|
||
var exp time.Duration
|
||
if len(expiration) > 0 {
|
||
exp = expiration[0]
|
||
}
|
||
return r.getClient(false).Set(ctx, key, value, exp).Err()
|
||
}
|
||
|
||
func (r *RedisCache) Get(ctx context.Context, key string) (string, error) {
|
||
result, err := r.getClient(true).Get(ctx, key).Result()
|
||
if err == redis.Nil {
|
||
return "", ErrKeyNotFound
|
||
}
|
||
return result, err
|
||
}
|
||
|
||
func (r *RedisCache) GetBytes(ctx context.Context, key string) ([]byte, error) {
|
||
val, err := r.Get(ctx, key)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return []byte(val), nil
|
||
}
|
||
|
||
func (r *RedisCache) GetScan(ctx context.Context, key string, dest interface{}) error {
|
||
val, err := r.Get(ctx, key)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
return json.Unmarshal([]byte(val), dest)
|
||
}
|
||
|
||
func (r *RedisCache) Del(ctx context.Context, keys ...string) error {
|
||
return r.getClient(false).Del(ctx, keys...).Err()
|
||
}
|
||
|
||
func (r *RedisCache) Exists(ctx context.Context, key string) (bool, error) {
|
||
result, err := r.getClient(true).Exists(ctx, key).Result()
|
||
return result > 0, err
|
||
}
|
||
|
||
func (r *RedisCache) Expire(ctx context.Context, key string, expiration time.Duration) error {
|
||
return r.getClient(false).Expire(ctx, key, expiration).Err()
|
||
}
|
||
|
||
func (r *RedisCache) TTL(ctx context.Context, key string) (time.Duration, error) {
|
||
return r.getClient(true).TTL(ctx, key).Result()
|
||
}
|
||
|
||
func (r *RedisCache) Inc(ctx context.Context, key string) (int64, error) {
|
||
return r.getClient(false).Incr(ctx, key).Result()
|
||
}
|
||
|
||
func (r *RedisCache) IncBy(ctx context.Context, key string, value int64) (int64, error) {
|
||
return r.getClient(false).IncrBy(ctx, key, value).Result()
|
||
}
|
||
|
||
func (r *RedisCache) Dec(ctx context.Context, key string) (int64, error) {
|
||
return r.getClient(false).Decr(ctx, key).Result()
|
||
}
|
||
|
||
func (r *RedisCache) SetNX(ctx context.Context, key string, value interface{}, expiration ...time.Duration) (bool, error) {
|
||
var exp time.Duration
|
||
if len(expiration) > 0 {
|
||
exp = expiration[0]
|
||
}
|
||
return r.getClient(false).SetNX(ctx, key, value, exp).Result()
|
||
}
|
||
|
||
func (r *RedisCache) Keys(ctx context.Context, pattern string) ([]string, error) {
|
||
return r.getClient(true).Keys(ctx, pattern).Result()
|
||
}
|
||
|
||
func (r *RedisCache) HSet(ctx context.Context, key string, field string, value interface{}) error {
|
||
return r.getClient(false).HSet(ctx, key, field, value).Err()
|
||
}
|
||
|
||
func (r *RedisCache) HGet(ctx context.Context, key string, field string) (string, error) {
|
||
result, err := r.getClient(true).HGet(ctx, key, field).Result()
|
||
if err == redis.Nil {
|
||
return "", ErrKeyNotFound
|
||
}
|
||
return result, err
|
||
}
|
||
|
||
func (r *RedisCache) HGetAll(ctx context.Context, key string) (map[string]string, error) {
|
||
result, err := r.getClient(true).HGetAll(ctx, key).Result()
|
||
if err == redis.Nil {
|
||
return nil, ErrKeyNotFound
|
||
}
|
||
return result, err
|
||
}
|
||
|
||
func (r *RedisCache) HDel(ctx context.Context, key string, fields ...string) error {
|
||
return r.getClient(false).HDel(ctx, key, fields...).Err()
|
||
}
|
||
|
||
func (r *RedisCache) HExists(ctx context.Context, key string, field string) (bool, error) {
|
||
return r.getClient(true).HExists(ctx, key, field).Result()
|
||
}
|
||
|
||
func (r *RedisCache) HKeys(ctx context.Context, key string) ([]string, error) {
|
||
return r.getClient(true).HKeys(ctx, key).Result()
|
||
}
|
||
|
||
func (r *RedisCache) HLen(ctx context.Context, key string) (int64, error) {
|
||
return r.getClient(true).HLen(ctx, key).Result()
|
||
}
|
||
|
||
func (r *RedisCache) HIncrBy(ctx context.Context, key string, field string, increment int64) (int64, error) {
|
||
return r.getClient(false).HIncrBy(ctx, key, field, increment).Result()
|
||
}
|
||
|
||
func (r *RedisCache) Close() error {
|
||
r.mu.Lock()
|
||
defer r.mu.Unlock()
|
||
|
||
if r.closed {
|
||
return nil
|
||
}
|
||
|
||
r.closed = true
|
||
|
||
// 停止重连定时器
|
||
if r.ticker != nil {
|
||
r.ticker.Stop()
|
||
}
|
||
close(r.done)
|
||
|
||
// 关闭连接
|
||
if r.master != nil {
|
||
r.master.Close()
|
||
}
|
||
if r.replica != nil && r.replica != r.master {
|
||
r.replica.Close()
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// ParseHeadlessServiceAddr derives per-pod Redis addresses from a Kubernetes
// headless Service address of the form
// "<service>.<namespace>.svc.<cluster-domain>[:port]", for example
// "my-redis-headless.my-namespace.svc.cluster.local:6379".
//
// An address that does not contain ".svc." is treated as a single node: it is
// returned unchanged as master with no replicas.
//
// Otherwise pod ordinal 0 is reported as the master and ordinals 1 and 2 as
// replicas. This is a simplification: no DNS lookup is performed, so the
// replica count is assumed rather than discovered. The port and cluster
// domain of addr are preserved; the port defaults to 6379 when omitted.
//
// An error is returned when the part before ".svc." is not exactly
// "<service>.<namespace>".
func ParseHeadlessServiceAddr(addr string) (master string, replicas []string, err error) {
	const (
		svcMarker    = ".svc."
		defaultPort  = "6379"
		replicaCount = 2 // assumed number of replicas
	)

	if !strings.Contains(addr, svcMarker) {
		// Not a cluster-internal service address: single node.
		return addr, nil, nil
	}

	// Split off an optional ":port" so it can be carried over to the pod
	// addresses instead of being replaced by the default.
	host, port := addr, defaultPort
	if i := strings.LastIndex(addr, ":"); i >= 0 {
		host = addr[:i]
		if p := addr[i+1:]; p != "" {
			port = p
		}
	}

	svcIdx := strings.Index(host, svcMarker)
	if svcIdx < 0 {
		return "", nil, fmt.Errorf("invalid headless service address %q", addr)
	}
	// prefix is "<service>.<namespace>"; domain keeps its leading dot, e.g.
	// ".svc.cluster.local".
	prefix, domain := host[:svcIdx], host[svcIdx:]

	dot := strings.Index(prefix, ".")
	if dot <= 0 || dot == len(prefix)-1 || strings.Contains(prefix[dot+1:], ".") {
		return "", nil, fmt.Errorf("invalid headless service address %q", addr)
	}
	serviceName, namespace := prefix[:dot], prefix[dot+1:]

	podAddr := func(ordinal int) string {
		return fmt.Sprintf("%s-%d.%s.%s%s:%s", serviceName, ordinal, serviceName, namespace, domain, port)
	}

	// In headless-service mode the first pod is used as master and the
	// remaining ones as replicas.
	master = podAddr(0)
	replicas = make([]string, 0, replicaCount)
	for i := 1; i <= replicaCount; i++ {
		replicas = append(replicas, podAddr(i))
	}

	return master, replicas, nil
}
|
||
|
||
func (r *RedisCache) startReconnect() {
|
||
r.ticker = time.NewTicker(r.config.ReconnectInterval)
|
||
|
||
go func() {
|
||
for {
|
||
select {
|
||
case <-r.done:
|
||
return
|
||
case <-r.ticker.C:
|
||
r.checkAndReconnect()
|
||
}
|
||
}
|
||
}()
|
||
}
|
||
|
||
func (r *RedisCache) checkAndReconnect() {
|
||
r.mu.Lock()
|
||
defer r.mu.Unlock()
|
||
|
||
if r.closed {
|
||
return
|
||
}
|
||
|
||
// 检查主节点连接
|
||
if r.master != nil {
|
||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||
if err := r.master.Ping(ctx).Err(); err != nil {
|
||
fmt.Printf("Master connection lost: %v, attempting reconnect...\n", err)
|
||
if r.config.MasterAddr != "" {
|
||
r.master.Close()
|
||
r.master = r.createClient(r.config.MasterAddr, true)
|
||
} else {
|
||
r.master.Close()
|
||
r.master = r.createClient(r.config.Addr, true)
|
||
}
|
||
}
|
||
cancel()
|
||
}
|
||
|
||
// 检查副本节点连接(如果与master不同)
|
||
if r.replica != nil && r.replica != r.master {
|
||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||
if err := r.replica.Ping(ctx).Err(); err != nil {
|
||
fmt.Printf("Replica connection lost: %v, attempting reconnect...\n", err)
|
||
if len(r.config.ReplicaAddrs) > 0 {
|
||
r.replica.Close()
|
||
r.replica = r.createClient(r.config.ReplicaAddrs[0], false)
|
||
}
|
||
}
|
||
cancel()
|
||
}
|
||
}
|
||
|
||
// 从 Headless Service 自动创建 Redis 连接
|
||
func NewRedisFromHeadlessService(headlessAddr string, password string) (Cache, error) {
|
||
master, replicas, err := ParseHeadlessServiceAddr(headlessAddr)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
config := NewConfig("redis", headlessAddr)
|
||
config.MasterAddr = master
|
||
config.ReplicaAddrs = replicas
|
||
config.Password = password
|
||
|
||
return NewRedis(config)
|
||
}
|