Files
upkg/database/cache/redis.go
2026-01-28 10:28:13 +08:00

316 lines
8.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package cache
import (
"context"
"encoding/json"
"fmt"
"strings"
"sync"
"time"
"github.com/redis/go-redis/v9"
)
// RedisCache is a Redis-backed cache. It keeps one client for the master node
// and one for a replica, together with the state used by the background
// reconnect loop.
type RedisCache struct {
	// master is the client handed out for non-read-only operations.
	master *redis.Client
	// replica is the client preferred for read-only operations. It may be the
	// very same client as master when no replica address is configured.
	replica *redis.Client
	// config holds the settings the clients are created from.
	config *Config

	// Reconnect and lifecycle state.

	// mu guards master, replica and closed.
	mu sync.RWMutex
	// closed is set once Close has run.
	closed bool
	// ticker drives the periodic health check; it stays nil unless the
	// reconnect loop has been started.
	ticker *time.Ticker
	// done is closed by Close to stop the reconnect goroutine.
	done chan struct{}
}
// NewRedis builds a Redis-backed Cache from config.
//
// The master client connects to config.MasterAddr, falling back to
// config.Addr when no master address is set. The replica client connects to
// the first entry of config.ReplicaAddrs; when that list is empty the master
// client is reused for reads. If config.Reconnect is set, the periodic
// health-check loop is started before returning.
//
// It returns an error when config is nil.
func NewRedis(config *Config) (Cache, error) {
	if config == nil {
		return nil, fmt.Errorf("redis cache: config must not be nil")
	}

	rdb := &RedisCache{
		config: config,
		done:   make(chan struct{}),
	}

	// Master connection: an explicit MasterAddr wins over the generic Addr.
	masterAddr := config.MasterAddr
	if masterAddr == "" {
		masterAddr = config.Addr
	}
	rdb.master = rdb.createClient(masterAddr, true)

	// Replica connection (used for read-only operations). Only the first
	// configured replica is used; without any, reads share the master client.
	if len(config.ReplicaAddrs) > 0 {
		rdb.replica = rdb.createClient(config.ReplicaAddrs[0], false)
	} else {
		rdb.replica = rdb.master
	}

	// Start the automatic reconnect loop when requested.
	if config.Reconnect {
		rdb.startReconnect()
	}
	return rdb, nil
}
// createClient returns a new go-redis client for addr, taking the password,
// database index, timeouts and pool size from r.config.
//
// isMaster is accepted for the callers' benefit but is not used here.
func (r *RedisCache) createClient(addr string, isMaster bool) *redis.Client {
	cfg := r.config
	opts := redis.Options{
		Addr:         addr,
		Password:     cfg.Password,
		DB:           cfg.DB,
		DialTimeout:  cfg.DialTimeout,
		ReadTimeout:  cfg.ReadTimeout,
		WriteTimeout: cfg.WriteTimeout,
		PoolSize:     cfg.PoolSize,
	}
	return redis.NewClient(&opts)
}
// getClient picks the client for an operation while holding the read lock.
//
// The replica is returned only when the caller asks for a read-only client, a
// replica client exists, and config.ReadOnly is false; in every other case
// the master is returned.
// NOTE(review): routing reads to the master when config.ReadOnly is true
// looks surprising — confirm the intended meaning of that flag.
func (r *RedisCache) getClient(readOnly bool) *redis.Client {
	r.mu.RLock()
	defer r.mu.RUnlock()

	if !readOnly || r.replica == nil || r.config.ReadOnly {
		return r.master
	}
	return r.replica
}
// Set stores value under key using the write client. Only the first element
// of expiration is honoured; when none is given a zero duration is passed
// through to go-redis unchanged.
func (r *RedisCache) Set(ctx context.Context, key string, value interface{}, expiration ...time.Duration) error {
	ttl := time.Duration(0)
	if len(expiration) != 0 {
		ttl = expiration[0]
	}

	writer := r.getClient(false)
	return writer.Set(ctx, key, value, ttl).Err()
}
// Get reads the string stored under key using the read client. A redis.Nil
// reply is translated into ErrKeyNotFound; any other error is returned as-is
// together with whatever value go-redis produced.
func (r *RedisCache) Get(ctx context.Context, key string) (string, error) {
	reader := r.getClient(true)
	val, err := reader.Get(ctx, key).Result()
	if err != redis.Nil {
		return val, err
	}
	return "", ErrKeyNotFound
}
// GetBytes behaves like Get but returns the value as a byte slice. On error
// the slice is nil and the error from Get is returned unchanged.
func (r *RedisCache) GetBytes(ctx context.Context, key string) ([]byte, error) {
	s, err := r.Get(ctx, key)
	if err == nil {
		return []byte(s), nil
	}
	return nil, err
}
// GetScan fetches the value stored under key with Get and JSON-decodes it
// into dest. It returns the error from Get, or otherwise the result of
// json.Unmarshal.
func (r *RedisCache) GetScan(ctx context.Context, key string, dest interface{}) error {
	payload, err := r.Get(ctx, key)
	if err != nil {
		return err
	}

	data := []byte(payload)
	return json.Unmarshal(data, dest)
}
// Del issues a Del for the given keys on the write client and returns the
// command's error.
func (r *RedisCache) Del(ctx context.Context, keys ...string) error {
	writer := r.getClient(false)
	cmd := writer.Del(ctx, keys...)
	return cmd.Err()
}
// Exists queries the read client for key and reports true when the returned
// count is greater than zero. The command's error is returned alongside.
func (r *RedisCache) Exists(ctx context.Context, key string) (bool, error) {
	reader := r.getClient(true)
	count, err := reader.Exists(ctx, key).Result()
	found := count > 0
	return found, err
}
// Expire issues an Expire for key with the given duration on the write client
// and returns the command's error.
func (r *RedisCache) Expire(ctx context.Context, key string, expiration time.Duration) error {
	writer := r.getClient(false)
	cmd := writer.Expire(ctx, key, expiration)
	return cmd.Err()
}
// TTL asks the read client for the TTL of key and returns the command's
// result unchanged.
func (r *RedisCache) TTL(ctx context.Context, key string) (time.Duration, error) {
	reader := r.getClient(true)
	cmd := reader.TTL(ctx, key)
	return cmd.Result()
}
// Inc issues an Incr for key on the write client and returns the command's
// result.
func (r *RedisCache) Inc(ctx context.Context, key string) (int64, error) {
	writer := r.getClient(false)
	cmd := writer.Incr(ctx, key)
	return cmd.Result()
}
// IncBy issues an IncrBy for key with the given amount on the write client
// and returns the command's result.
func (r *RedisCache) IncBy(ctx context.Context, key string, value int64) (int64, error) {
	writer := r.getClient(false)
	cmd := writer.IncrBy(ctx, key, value)
	return cmd.Result()
}
// Dec issues a Decr for key on the write client and returns the command's
// result.
func (r *RedisCache) Dec(ctx context.Context, key string) (int64, error) {
	writer := r.getClient(false)
	cmd := writer.Decr(ctx, key)
	return cmd.Result()
}
// SetNX issues a SetNX for key on the write client and returns the command's
// result. Only the first element of expiration is honoured; when none is
// given a zero duration is passed through to go-redis unchanged.
func (r *RedisCache) SetNX(ctx context.Context, key string, value interface{}, expiration ...time.Duration) (bool, error) {
	ttl := time.Duration(0)
	if len(expiration) != 0 {
		ttl = expiration[0]
	}

	writer := r.getClient(false)
	return writer.SetNX(ctx, key, value, ttl).Result()
}
// Keys returns the keys matching pattern, read through the read client.
//
// The keyspace is walked incrementally with SCAN rather than with a single
// KEYS command, because KEYS examines the whole keyspace in one blocking call
// and can stall the server on large databases. SCAN may report a key more
// than once, so results are de-duplicated. Unlike KEYS the walk is not an
// atomic snapshot: keys added or removed while it runs may or may not be
// included.
//
// On success the returned slice is non-nil (possibly empty); on error it is
// nil.
func (r *RedisCache) Keys(ctx context.Context, pattern string) ([]string, error) {
	// COUNT is only a hint for how much work each SCAN call does; a larger
	// value cuts round trips compared with the server default.
	const scanCountHint = 512

	seen := make(map[string]struct{})
	keys := make([]string, 0)

	iter := r.getClient(true).Scan(ctx, 0, pattern, scanCountHint).Iterator()
	for iter.Next(ctx) {
		k := iter.Val()
		if _, dup := seen[k]; dup {
			continue
		}
		seen[k] = struct{}{}
		keys = append(keys, k)
	}
	if err := iter.Err(); err != nil {
		return nil, err
	}
	return keys, nil
}
// HSet issues an HSet of field to value in the hash at key on the write
// client and returns the command's error.
func (r *RedisCache) HSet(ctx context.Context, key string, field string, value interface{}) error {
	writer := r.getClient(false)
	cmd := writer.HSet(ctx, key, field, value)
	return cmd.Err()
}
// HGet reads field from the hash at key using the read client. A redis.Nil
// reply is translated into ErrKeyNotFound; any other error is returned as-is
// together with whatever value go-redis produced.
func (r *RedisCache) HGet(ctx context.Context, key string, field string) (string, error) {
	reader := r.getClient(true)
	val, err := reader.HGet(ctx, key, field).Result()
	if err != redis.Nil {
		return val, err
	}
	return "", ErrKeyNotFound
}
// HGetAll reads every field/value pair of the hash at key using the read
// client. A redis.Nil error is translated into (nil, ErrKeyNotFound); any
// other outcome is returned unchanged.
// NOTE(review): HGETALL appears to reply with an empty map rather than
// redis.Nil for a missing key, so the ErrKeyNotFound branch may never be
// taken — confirm against go-redis behaviour.
func (r *RedisCache) HGetAll(ctx context.Context, key string) (map[string]string, error) {
	reader := r.getClient(true)
	fields, err := reader.HGetAll(ctx, key).Result()
	if err != redis.Nil {
		return fields, err
	}
	return nil, ErrKeyNotFound
}
// HDel issues an HDel for the given fields of the hash at key on the write
// client and returns the command's error.
func (r *RedisCache) HDel(ctx context.Context, key string, fields ...string) error {
	writer := r.getClient(false)
	cmd := writer.HDel(ctx, key, fields...)
	return cmd.Err()
}
// HExists issues an HExists for field in the hash at key on the read client
// and returns the command's result.
func (r *RedisCache) HExists(ctx context.Context, key string, field string) (bool, error) {
	reader := r.getClient(true)
	cmd := reader.HExists(ctx, key, field)
	return cmd.Result()
}
// HKeys issues an HKeys for the hash at key on the read client and returns
// the command's result.
func (r *RedisCache) HKeys(ctx context.Context, key string) ([]string, error) {
	reader := r.getClient(true)
	cmd := reader.HKeys(ctx, key)
	return cmd.Result()
}
// HLen issues an HLen for the hash at key on the read client and returns the
// command's result.
func (r *RedisCache) HLen(ctx context.Context, key string) (int64, error) {
	reader := r.getClient(true)
	cmd := reader.HLen(ctx, key)
	return cmd.Result()
}
// HIncrBy issues an HIncrBy for field in the hash at key with the given
// increment on the write client and returns the command's result.
func (r *RedisCache) HIncrBy(ctx context.Context, key string, field string, increment int64) (int64, error) {
	writer := r.getClient(false)
	cmd := writer.HIncrBy(ctx, key, field, increment)
	return cmd.Result()
}
// Close stops the reconnect loop and closes the underlying clients.
//
// It is safe to call more than once: calls after the first return nil
// without doing anything. Both clients are always closed, even if closing the
// first one fails; the first error encountered is returned, wrapped with the
// name of the client that produced it.
func (r *RedisCache) Close() error {
	r.mu.Lock()
	defer r.mu.Unlock()

	if r.closed {
		return nil
	}
	r.closed = true

	// Stop the reconnect ticker and signal the reconnect goroutine to exit.
	if r.ticker != nil {
		r.ticker.Stop()
	}
	close(r.done)

	// Close the connections, remembering the first failure instead of
	// dropping it.
	var firstErr error
	if r.master != nil {
		if err := r.master.Close(); err != nil {
			firstErr = fmt.Errorf("closing master client: %w", err)
		}
	}
	// The replica may be the very same client as the master; close it only
	// when it is a distinct one.
	if r.replica != nil && r.replica != r.master {
		if err := r.replica.Close(); err != nil && firstErr == nil {
			firstErr = fmt.Errorf("closing replica client: %w", err)
		}
	}
	return firstErr
}
// ParseHeadlessServiceAddr derives per-pod Redis addresses from a Kubernetes
// headless Service address such as
// "my-redis-headless.my-namespace.svc.cluster.local:6379".
//
// An address that does not contain ".svc." is treated as a single node: it is
// returned unchanged as master, with no replicas.
//
// Otherwise no DNS lookup is performed. Pod ordinal 0 is assumed to be the
// master and ordinals 1 and 2 the replicas, each addressed as
// "<service>-<ordinal>.<service>.<namespace>.<dns-suffix>:<port>". The
// namespace, DNS suffix and port are taken from addr; the port defaults to
// 6379 when addr carries none.
//
// An error is returned when the service or namespace label of addr is empty.
func ParseHeadlessServiceAddr(addr string) (master string, replicas []string, err error) {
	const (
		defaultPort  = "6379"
		replicaCount = 2 // simplification: a fixed number of replicas is assumed
	)

	if !strings.Contains(addr, ".svc.") {
		// Not a cluster-internal service address: single-node setup.
		return addr, nil, nil
	}

	// Separate the optional ":port" so it can be re-applied to every pod
	// address instead of being replaced by a hard-coded one.
	host, port := addr, defaultPort
	if i := strings.LastIndex(addr, ":"); i >= 0 {
		host = addr[:i]
		if p := addr[i+1:]; p != "" {
			port = p
		}
	}

	// host is "<service>.<namespace>.<dns-suffix>"; everything after the
	// service label is reused verbatim so custom cluster domains survive.
	serviceName, domain, _ := strings.Cut(host, ".")
	namespace, _, _ := strings.Cut(domain, ".")
	if serviceName == "" || namespace == "" {
		return "", nil, fmt.Errorf("invalid headless service address %q", addr)
	}

	podAddr := func(ordinal int) string {
		return fmt.Sprintf("%s-%d.%s.%s:%s", serviceName, ordinal, serviceName, domain, port)
	}

	// The first pod is taken as the master, the following ones as replicas.
	master = podAddr(0)
	replicas = make([]string, 0, replicaCount)
	for i := 1; i <= replicaCount; i++ {
		replicas = append(replicas, podAddr(i))
	}
	return master, replicas, nil
}
// startReconnect launches the background goroutine that runs
// checkAndReconnect on every tick until r.done is closed.
//
// The tick period is config.ReconnectInterval. time.NewTicker panics on a
// non-positive duration, so an unset or negative interval falls back to a
// default instead of crashing the constructor.
func (r *RedisCache) startReconnect() {
	const fallbackInterval = 5 * time.Second

	interval := r.config.ReconnectInterval
	if interval <= 0 {
		interval = fallbackInterval
	}

	ticker := time.NewTicker(interval)
	r.ticker = ticker

	go func() {
		for {
			select {
			case <-r.done:
				return
			case <-ticker.C:
				r.checkAndReconnect()
			}
		}
	}()
}
// checkAndReconnect pings the master client and, when it is a distinct
// client, the replica, replacing any client whose ping fails with a freshly
// created one.
//
// The pings run without holding r.mu so that regular cache operations are not
// blocked for the duration of the ping timeouts; the write lock is taken only
// to swap clients. When the replica shares the master's client, it is
// re-pointed at the new master client as well, so reads do not keep using the
// closed one.
func (r *RedisCache) checkAndReconnect() {
	const pingTimeout = 3 * time.Second

	// Snapshot the current clients under the read lock.
	r.mu.RLock()
	closed, master, replica := r.closed, r.master, r.replica
	r.mu.RUnlock()
	if closed {
		return
	}

	// Health-check the master.
	var masterErr error
	if master != nil {
		ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
		masterErr = master.Ping(ctx).Err()
		cancel()
	}

	// Health-check the replica only when it is a different client.
	var replicaErr error
	if replica != nil && replica != master {
		ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
		replicaErr = replica.Ping(ctx).Err()
		cancel()
	}

	if masterErr == nil && replicaErr == nil {
		return
	}

	r.mu.Lock()
	defer r.mu.Unlock()
	// Close may have run while the pings were in flight.
	if r.closed {
		return
	}

	if masterErr != nil && r.master == master {
		fmt.Printf("Master connection lost: %v, attempting reconnect...\n", masterErr)
		addr := r.config.MasterAddr
		if addr == "" {
			addr = r.config.Addr
		}
		shared := r.replica == r.master
		// Best effort: the old client is already failing, so its close error
		// carries no extra information.
		_ = r.master.Close()
		r.master = r.createClient(addr, true)
		if shared {
			r.replica = r.master
		}
	}

	if replicaErr != nil && r.replica == replica {
		fmt.Printf("Replica connection lost: %v, attempting reconnect...\n", replicaErr)
		if len(r.config.ReplicaAddrs) > 0 {
			// Best effort, as for the master above.
			_ = r.replica.Close()
			r.replica = r.createClient(r.config.ReplicaAddrs[0], false)
		}
	}
}
// NewRedisFromHeadlessService creates a Redis cache from a headless Service
// address. The address is expanded with ParseHeadlessServiceAddr, the
// resulting master and replica addresses plus the password are placed on a
// config obtained from NewConfig("redis", headlessAddr), and that config is
// handed to NewRedis. A parse error is returned unchanged.
func NewRedisFromHeadlessService(headlessAddr string, password string) (Cache, error) {
	masterAddr, replicaAddrs, parseErr := ParseHeadlessServiceAddr(headlessAddr)
	if parseErr != nil {
		return nil, parseErr
	}

	cfg := NewConfig("redis", headlessAddr)
	cfg.Password = password
	cfg.MasterAddr = masterAddr
	cfg.ReplicaAddrs = replicaAddrs

	return NewRedis(cfg)
}