wip v1.0.0
This commit is contained in:
25
internal/cmd/root.go
Normal file
25
internal/cmd/root.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var rootCmd = &cobra.Command{
|
||||
Use: "go-alived",
|
||||
Short: "Go-Alived - VRRP High Availability Service",
|
||||
Long: `go-alived is a lightweight, dependency-free VRRP implementation in Go.
|
||||
It provides high availability for IP addresses with health checking support.`,
|
||||
Version: "1.0.0",
|
||||
}
|
||||
|
||||
func Execute() {
|
||||
if err := rootCmd.Execute(); err != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.CompletionOptions.DisableDefaultCmd = true
|
||||
}
|
||||
133
internal/cmd/run.go
Normal file
133
internal/cmd/run.go
Normal file
@@ -0,0 +1,133 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/loveuer/go-alived/internal/health"
|
||||
"github.com/loveuer/go-alived/internal/vrrp"
|
||||
"github.com/loveuer/go-alived/pkg/config"
|
||||
"github.com/loveuer/go-alived/pkg/logger"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var (
|
||||
configFile string
|
||||
debug bool
|
||||
)
|
||||
|
||||
var runCmd = &cobra.Command{
|
||||
Use: "run",
|
||||
Short: "Run the VRRP service",
|
||||
Long: `Start the go-alived VRRP service with health checking.`,
|
||||
Run: runService,
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(runCmd)
|
||||
|
||||
runCmd.Flags().StringVarP(&configFile, "config", "c", "/etc/go-alived/config.yaml", "path to configuration file")
|
||||
runCmd.Flags().BoolVarP(&debug, "debug", "d", false, "enable debug mode")
|
||||
}
|
||||
|
||||
func runService(cmd *cobra.Command, args []string) {
|
||||
log := logger.New(debug)
|
||||
|
||||
log.Info("starting go-alived...")
|
||||
log.Info("loading configuration from: %s", configFile)
|
||||
|
||||
cfg, err := config.Load(configFile)
|
||||
if err != nil {
|
||||
log.Error("failed to load configuration: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
log.Info("configuration loaded successfully")
|
||||
log.Debug("config: %+v", cfg)
|
||||
|
||||
healthMgr, err := health.LoadFromConfig(cfg, log)
|
||||
if err != nil {
|
||||
log.Error("failed to load health check configuration: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
vrrpMgr := vrrp.NewManager(log)
|
||||
if err := vrrpMgr.LoadFromConfig(cfg); err != nil {
|
||||
log.Error("failed to load VRRP configuration: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
setupHealthTracking(vrrpMgr, healthMgr, log)
|
||||
|
||||
healthMgr.StartAll()
|
||||
|
||||
if err := vrrpMgr.StartAll(); err != nil {
|
||||
log.Error("failed to start VRRP instances: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
|
||||
|
||||
for {
|
||||
sig := <-sigChan
|
||||
switch sig {
|
||||
case syscall.SIGHUP:
|
||||
log.Info("received SIGHUP, reloading configuration...")
|
||||
newCfg, err := config.Load(configFile)
|
||||
if err != nil {
|
||||
log.Error("failed to reload configuration: %v", err)
|
||||
continue
|
||||
}
|
||||
if err := vrrpMgr.Reload(newCfg); err != nil {
|
||||
log.Error("failed to reload VRRP: %v", err)
|
||||
continue
|
||||
}
|
||||
cfg = newCfg
|
||||
log.Info("configuration reloaded successfully")
|
||||
case syscall.SIGINT, syscall.SIGTERM:
|
||||
log.Info("received signal %v, shutting down...", sig)
|
||||
cleanup(log, vrrpMgr, healthMgr)
|
||||
os.Exit(0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func cleanup(log *logger.Logger, vrrpMgr *vrrp.Manager, healthMgr *health.Manager) {
|
||||
log.Info("cleaning up resources...")
|
||||
healthMgr.StopAll()
|
||||
vrrpMgr.StopAll()
|
||||
}
|
||||
|
||||
func setupHealthTracking(vrrpMgr *vrrp.Manager, healthMgr *health.Manager, log *logger.Logger) {
|
||||
instances := vrrpMgr.GetAllInstances()
|
||||
|
||||
for _, inst := range instances {
|
||||
for _, trackScript := range inst.TrackScripts {
|
||||
monitor, ok := healthMgr.GetMonitor(trackScript)
|
||||
if !ok {
|
||||
log.Warn("[%s] track_script '%s' not found in health checkers", inst.Name, trackScript)
|
||||
continue
|
||||
}
|
||||
|
||||
instanceName := inst.Name
|
||||
monitor.OnStateChange(func(checkerName string, oldHealthy, newHealthy bool) {
|
||||
vrrpInst, ok := vrrpMgr.GetInstance(instanceName)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if newHealthy && !oldHealthy {
|
||||
log.Info("[%s] health check '%s' recovered, resetting priority", instanceName, checkerName)
|
||||
vrrpInst.ResetPriority()
|
||||
} else if !newHealthy && oldHealthy {
|
||||
log.Warn("[%s] health check '%s' failed, decreasing priority", instanceName, checkerName)
|
||||
vrrpInst.AdjustPriority(-10)
|
||||
}
|
||||
})
|
||||
|
||||
log.Info("[%s] tracking health check: %s", inst.Name, trackScript)
|
||||
}
|
||||
}
|
||||
}
|
||||
470
internal/cmd/test.go
Normal file
470
internal/cmd/test.go
Normal file
@@ -0,0 +1,470 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/loveuer/go-alived/pkg/logger"
|
||||
"github.com/loveuer/go-alived/pkg/netif"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
type TestResult struct {
|
||||
Name string
|
||||
Pass bool
|
||||
Message string
|
||||
Fatal bool
|
||||
}
|
||||
|
||||
type EnvironmentTest struct {
|
||||
log *logger.Logger
|
||||
results []TestResult
|
||||
errors int
|
||||
warns int
|
||||
}
|
||||
|
||||
func NewEnvironmentTest(log *logger.Logger) *EnvironmentTest {
|
||||
return &EnvironmentTest{
|
||||
log: log,
|
||||
results: make([]TestResult, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *EnvironmentTest) AddResult(name string, pass bool, message string, fatal bool) {
|
||||
t.results = append(t.results, TestResult{
|
||||
Name: name,
|
||||
Pass: pass,
|
||||
Message: message,
|
||||
Fatal: fatal,
|
||||
})
|
||||
|
||||
if !pass {
|
||||
if fatal {
|
||||
t.errors++
|
||||
} else {
|
||||
t.warns++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *EnvironmentTest) TestRootPermission() {
|
||||
t.log.Info("检查运行权限...")
|
||||
|
||||
if os.Geteuid() != 0 {
|
||||
t.AddResult("Root权限", false, "需要root权限运行,请使用sudo", true)
|
||||
} else {
|
||||
t.AddResult("Root权限", true, "以root用户运行", false)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *EnvironmentTest) TestNetworkInterface(ifaceName string) string {
|
||||
t.log.Info("检查网络接口...")
|
||||
|
||||
if ifaceName == "" {
|
||||
interfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
t.AddResult("网络接口", false, "无法获取网络接口列表", true)
|
||||
return ""
|
||||
}
|
||||
|
||||
for _, iface := range interfaces {
|
||||
if iface.Flags&net.FlagUp != 0 && iface.Flags&net.FlagLoopback == 0 {
|
||||
addrs, err := iface.Addrs()
|
||||
if err == nil && len(addrs) > 0 {
|
||||
for _, addr := range addrs {
|
||||
if ipnet, ok := addr.(*net.IPNet); ok && ipnet.IP.To4() != nil {
|
||||
ifaceName = iface.Name
|
||||
t.log.Info("自动选择网卡: %s", ifaceName)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if ifaceName != "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if ifaceName == "" {
|
||||
t.AddResult("网络接口", false, "未找到可用的网络接口", true)
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
iface, err := netif.GetInterface(ifaceName)
|
||||
if err != nil {
|
||||
t.AddResult("网络接口", false, fmt.Sprintf("网卡 %s 不存在", ifaceName), true)
|
||||
return ""
|
||||
}
|
||||
|
||||
if !iface.IsUp() {
|
||||
t.AddResult("网络接口状态", false, fmt.Sprintf("网卡 %s 未启动", ifaceName), true)
|
||||
return ""
|
||||
}
|
||||
|
||||
t.AddResult("网络接口", true, fmt.Sprintf("网卡 %s 存在且已启动", ifaceName), false)
|
||||
return ifaceName
|
||||
}
|
||||
|
||||
func (t *EnvironmentTest) TestVIPOperations(ifaceName, testVIP string) {
|
||||
t.log.Info("测试VIP添加/删除功能...")
|
||||
|
||||
if ifaceName == "" || testVIP == "" {
|
||||
t.AddResult("VIP操作", false, "网卡名或测试VIP为空", true)
|
||||
return
|
||||
}
|
||||
|
||||
iface, err := netif.GetInterface(ifaceName)
|
||||
if err != nil {
|
||||
t.AddResult("VIP操作", false, fmt.Sprintf("获取网卡失败: %v", err), true)
|
||||
return
|
||||
}
|
||||
|
||||
if !strings.Contains(testVIP, "/") {
|
||||
testVIP = testVIP + "/32"
|
||||
}
|
||||
|
||||
exists, _ := iface.HasIP(testVIP)
|
||||
if exists {
|
||||
t.AddResult("VIP操作", false, fmt.Sprintf("VIP %s 已存在,请使用其他IP测试", testVIP), true)
|
||||
return
|
||||
}
|
||||
|
||||
err = iface.AddIP(testVIP)
|
||||
if err != nil {
|
||||
t.AddResult("VIP添加", false, fmt.Sprintf("VIP添加失败: %v", err), true)
|
||||
return
|
||||
}
|
||||
|
||||
t.AddResult("VIP添加", true, fmt.Sprintf("成功添加VIP %s", testVIP), false)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
exists, _ = iface.HasIP(testVIP)
|
||||
if !exists {
|
||||
t.AddResult("VIP验证", false, "VIP添加后无法在网卡上找到", true)
|
||||
iface.DeleteIP(testVIP)
|
||||
return
|
||||
}
|
||||
|
||||
t.AddResult("VIP验证", true, "VIP已成功添加到网卡", false)
|
||||
|
||||
vipAddr := strings.Split(testVIP, "/")[0]
|
||||
cmd := exec.Command("ping", "-c", "1", "-W", "1", vipAddr)
|
||||
err = cmd.Run()
|
||||
if err != nil {
|
||||
t.AddResult("VIP可达性", false, "VIP ping失败(可能需要路由配置)", false)
|
||||
} else {
|
||||
t.AddResult("VIP可达性", true, "VIP可以ping通", false)
|
||||
}
|
||||
|
||||
err = iface.DeleteIP(testVIP)
|
||||
if err != nil {
|
||||
t.AddResult("VIP删除", false, fmt.Sprintf("VIP删除失败: %v", err), false)
|
||||
} else {
|
||||
t.AddResult("VIP删除", true, "VIP删除成功", false)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *EnvironmentTest) TestMulticast(ifaceName string) {
|
||||
t.log.Info("检查组播支持...")
|
||||
|
||||
if ifaceName == "" {
|
||||
t.AddResult("组播支持", false, "网卡名为空,跳过检查", false)
|
||||
return
|
||||
}
|
||||
|
||||
cmd := exec.Command("ip", "maddr", "show", ifaceName)
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
t.AddResult("组播支持", false, "无法查询组播配置", false)
|
||||
return
|
||||
}
|
||||
|
||||
if len(output) > 0 {
|
||||
t.AddResult("组播支持", true, "网卡支持组播", false)
|
||||
} else {
|
||||
t.AddResult("组播支持", false, "网卡组播支持未知", false)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *EnvironmentTest) TestFirewall() {
|
||||
t.log.Info("检查防火墙设置...")
|
||||
|
||||
cmd := exec.Command("iptables", "-L", "INPUT", "-n")
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
t.AddResult("防火墙检查", false, "无法查询iptables规则(可能未安装)", false)
|
||||
return
|
||||
}
|
||||
|
||||
if strings.Contains(string(output), "112") || strings.Contains(string(output), "vrrp") {
|
||||
t.AddResult("防火墙VRRP", true, "防火墙已配置VRRP规则", false)
|
||||
} else {
|
||||
t.AddResult("防火墙VRRP", false, "防火墙未配置VRRP规则,建议添加: iptables -A INPUT -p 112 -j ACCEPT", false)
|
||||
}
|
||||
|
||||
cmd = exec.Command("systemctl", "is-active", "firewalld")
|
||||
err = cmd.Run()
|
||||
if err == nil {
|
||||
cmd = exec.Command("firewall-cmd", "--list-protocols")
|
||||
output, err = cmd.CombinedOutput()
|
||||
if err == nil {
|
||||
if strings.Contains(string(output), "vrrp") {
|
||||
t.AddResult("Firewalld VRRP", true, "firewalld已允许VRRP协议", false)
|
||||
} else {
|
||||
t.AddResult("Firewalld VRRP", false, "firewalld未配置VRRP,建议: firewall-cmd --permanent --add-protocol=vrrp", false)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *EnvironmentTest) TestKernelParameters() {
|
||||
t.log.Info("检查内核参数...")
|
||||
|
||||
params := map[string]string{
|
||||
"/proc/sys/net/ipv4/ip_forward": "1",
|
||||
"/proc/sys/net/ipv4/conf/all/arp_ignore": "0",
|
||||
"/proc/sys/net/ipv4/conf/all/arp_announce": "0",
|
||||
}
|
||||
|
||||
for path, expected := range params {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
value := strings.TrimSpace(string(data))
|
||||
name := strings.TrimPrefix(path, "/proc/sys/net/ipv4/")
|
||||
|
||||
if value == expected {
|
||||
t.AddResult(name, true, fmt.Sprintf("%s = %s (正常)", name, value), false)
|
||||
} else {
|
||||
if name == "ip_forward" && value != "1" {
|
||||
t.AddResult(name, false, fmt.Sprintf("%s = %s (建议设置为1)", name, value), false)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *EnvironmentTest) TestConflictingServices() {
|
||||
t.log.Info("检查冲突服务...")
|
||||
|
||||
services := []string{"keepalived"}
|
||||
hasConflict := false
|
||||
|
||||
for _, service := range services {
|
||||
cmd := exec.Command("systemctl", "is-active", service)
|
||||
err := cmd.Run()
|
||||
if err == nil {
|
||||
t.AddResult("服务冲突", false, fmt.Sprintf("发现运行中的%s服务,可能冲突", service), false)
|
||||
hasConflict = true
|
||||
}
|
||||
}
|
||||
|
||||
cmd := exec.Command("pgrep", "-x", "keepalived")
|
||||
err := cmd.Run()
|
||||
if err == nil {
|
||||
t.AddResult("进程冲突", false, "发现运行中的keepalived进程", false)
|
||||
hasConflict = true
|
||||
}
|
||||
|
||||
if !hasConflict {
|
||||
t.AddResult("服务冲突", true, "未发现冲突的服务", false)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *EnvironmentTest) TestVirtualization() {
|
||||
t.log.Info("检查虚拟化环境...")
|
||||
|
||||
productFile := "/sys/class/dmi/id/product_name"
|
||||
data, err := os.ReadFile(productFile)
|
||||
if err != nil {
|
||||
cmd := exec.Command("systemd-detect-virt")
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err == nil {
|
||||
virt := strings.TrimSpace(string(output))
|
||||
if virt != "none" {
|
||||
t.AddResult("虚拟化", true, fmt.Sprintf("检测到虚拟化环境: %s", virt), false)
|
||||
t.log.Warn("虚拟化环境可能需要特殊配置(如启用混杂模式)")
|
||||
} else {
|
||||
t.AddResult("虚拟化", true, "物理机环境", false)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
product := strings.TrimSpace(string(data))
|
||||
switch {
|
||||
case strings.Contains(product, "VMware"):
|
||||
t.AddResult("虚拟化", true, "VMware虚拟机(需要启用混杂模式)", false)
|
||||
t.log.Warn("VMware需要配置: 虚拟机设置 -> 网络适配器 -> 高级 -> 混杂模式: 允许全部")
|
||||
case strings.Contains(product, "VirtualBox"):
|
||||
t.AddResult("虚拟化", true, "VirtualBox虚拟机(需要桥接模式+混杂模式)", false)
|
||||
t.log.Warn("VirtualBox需要配置: 网络 -> 桥接网卡 -> 高级 -> 混杂模式: 全部允许")
|
||||
case strings.Contains(product, "KVM") || strings.Contains(product, "QEMU"):
|
||||
t.AddResult("虚拟化", true, "KVM/QEMU虚拟机(通常支持良好)", false)
|
||||
case strings.Contains(product, "Amazon") || strings.Contains(product, "EC2"):
|
||||
t.AddResult("虚拟化", false, "AWS EC2环境 - 不支持VRRP", true)
|
||||
t.log.Error("AWS不支持组播协议,无法运行VRRP,请使用Elastic IP或负载均衡")
|
||||
default:
|
||||
t.AddResult("虚拟化", true, fmt.Sprintf("环境: %s", product), false)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *EnvironmentTest) TestCloudEnvironment() {
|
||||
t.log.Info("检查云环境...")
|
||||
|
||||
cloudTests := []struct {
|
||||
name string
|
||||
url string
|
||||
headers map[string]string
|
||||
isFatal bool
|
||||
solution string
|
||||
}{
|
||||
{
|
||||
name: "AWS",
|
||||
url: "http://169.254.169.254/latest/meta-data/instance-id",
|
||||
solution: "AWS不支持VRRP,请使用: Elastic IP、ALB或NLB",
|
||||
isFatal: true,
|
||||
},
|
||||
{
|
||||
name: "阿里云",
|
||||
url: "http://100.100.100.200/latest/meta-data/instance-id",
|
||||
solution: "阿里云ECS不支持VRRP,请使用: 负载均衡SLB或高可用虚拟IP(HaVip)",
|
||||
isFatal: true,
|
||||
},
|
||||
{
|
||||
name: "Azure",
|
||||
url: "http://169.254.169.254/metadata/instance?api-version=2021-02-01",
|
||||
headers: map[string]string{"Metadata": "true"},
|
||||
solution: "Azure建议使用: Azure Load Balancer或Traffic Manager",
|
||||
isFatal: false,
|
||||
},
|
||||
{
|
||||
name: "Google Cloud",
|
||||
url: "http://metadata.google.internal/computeMetadata/v1/instance/id",
|
||||
headers: map[string]string{"Metadata-Flavor": "Google"},
|
||||
solution: "GCP建议使用: Cloud Load Balancing",
|
||||
isFatal: false,
|
||||
},
|
||||
}
|
||||
|
||||
cloudDetected := false
|
||||
for _, test := range cloudTests {
|
||||
cmd := exec.Command("curl", "-s", "-m", "1", test.url)
|
||||
if len(test.headers) > 0 {
|
||||
for k, v := range test.headers {
|
||||
cmd.Args = append(cmd.Args, "-H", fmt.Sprintf("%s: %s", k, v))
|
||||
}
|
||||
}
|
||||
|
||||
err := cmd.Run()
|
||||
if err == nil {
|
||||
cloudDetected = true
|
||||
t.AddResult("云环境", !test.isFatal, fmt.Sprintf("检测到%s环境", test.name), test.isFatal)
|
||||
t.log.Warn(test.solution)
|
||||
}
|
||||
}
|
||||
|
||||
if !cloudDetected {
|
||||
t.AddResult("云环境", true, "未检测到公有云环境限制", false)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *EnvironmentTest) PrintResults() {
|
||||
fmt.Println()
|
||||
fmt.Println("=== 测试结果 ===")
|
||||
fmt.Println()
|
||||
|
||||
for _, result := range t.results {
|
||||
status := "✓"
|
||||
if !result.Pass {
|
||||
if result.Fatal {
|
||||
status = "✗"
|
||||
} else {
|
||||
status = "⚠"
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Printf("%s %-20s %s\n", status, result.Name, result.Message)
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
fmt.Println("=== 总结 ===")
|
||||
fmt.Println()
|
||||
|
||||
if t.errors == 0 && t.warns == 0 {
|
||||
fmt.Println("✓ 环境完全支持 go-alived")
|
||||
fmt.Println(" 可以正常使用所有功能")
|
||||
} else if t.errors == 0 {
|
||||
fmt.Printf("⚠ 环境基本支持,但有 %d 个警告\n", t.warns)
|
||||
fmt.Println(" 建议修复警告项以获得更好的稳定性")
|
||||
} else {
|
||||
fmt.Printf("✗ 发现 %d 个错误, %d 个警告\n", t.errors, t.warns)
|
||||
fmt.Println(" 请修复错误后再使用 go-alived")
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
func (t *EnvironmentTest) HasErrors() bool {
|
||||
return t.errors > 0
|
||||
}
|
||||
|
||||
var (
|
||||
testIface string
|
||||
testVIP string
|
||||
)
|
||||
|
||||
var testCmd = &cobra.Command{
|
||||
Use: "test",
|
||||
Short: "Test environment for VRRP support",
|
||||
Long: `Test the current environment to verify if it supports VRRP functionality.
|
||||
This includes checking permissions, network interfaces, VIP operations, multicast support, and more.`,
|
||||
Run: runTest,
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(testCmd)
|
||||
|
||||
testCmd.Flags().StringVarP(&testIface, "interface", "i", "", "network interface to test (auto-detect if not specified)")
|
||||
testCmd.Flags().StringVarP(&testVIP, "vip", "v", "", "test VIP address (e.g., 192.168.1.100/24)")
|
||||
}
|
||||
|
||||
func runTest(cmd *cobra.Command, args []string) {
|
||||
log := logger.New(false)
|
||||
|
||||
fmt.Println("=== go-alived 环境测试 ===")
|
||||
fmt.Println()
|
||||
|
||||
test := NewEnvironmentTest(log)
|
||||
|
||||
test.TestRootPermission()
|
||||
|
||||
selectedIface := test.TestNetworkInterface(testIface)
|
||||
|
||||
if selectedIface != "" && testVIP != "" {
|
||||
test.TestVIPOperations(selectedIface, testVIP)
|
||||
}
|
||||
|
||||
if selectedIface != "" {
|
||||
test.TestMulticast(selectedIface)
|
||||
}
|
||||
|
||||
test.TestFirewall()
|
||||
test.TestKernelParameters()
|
||||
test.TestConflictingServices()
|
||||
test.TestVirtualization()
|
||||
test.TestCloudEnvironment()
|
||||
|
||||
test.PrintResults()
|
||||
|
||||
if test.HasErrors() {
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
89
internal/health/checker.go
Normal file
89
internal/health/checker.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package health
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
type CheckResult int
|
||||
|
||||
const (
|
||||
CheckResultUnknown CheckResult = iota
|
||||
CheckResultSuccess
|
||||
CheckResultFailure
|
||||
)
|
||||
|
||||
func (r CheckResult) String() string {
|
||||
switch r {
|
||||
case CheckResultSuccess:
|
||||
return "SUCCESS"
|
||||
case CheckResultFailure:
|
||||
return "FAILURE"
|
||||
default:
|
||||
return "UNKNOWN"
|
||||
}
|
||||
}
|
||||
|
||||
type Checker interface {
|
||||
Check(ctx context.Context) CheckResult
|
||||
Name() string
|
||||
Type() string
|
||||
}
|
||||
|
||||
type CheckerConfig struct {
|
||||
Name string
|
||||
Type string
|
||||
Interval time.Duration
|
||||
Timeout time.Duration
|
||||
Rise int
|
||||
Fall int
|
||||
Config map[string]interface{}
|
||||
}
|
||||
|
||||
type CheckerState struct {
|
||||
Name string
|
||||
Healthy bool
|
||||
LastResult CheckResult
|
||||
LastCheckTime time.Time
|
||||
SuccessCount int
|
||||
FailureCount int
|
||||
TotalChecks int
|
||||
ConsecutiveOK int
|
||||
ConsecutiveFail int
|
||||
}
|
||||
|
||||
func (s *CheckerState) IsHealthy() bool {
|
||||
return s.Healthy
|
||||
}
|
||||
|
||||
func (s *CheckerState) Update(result CheckResult, rise, fall int) bool {
|
||||
s.LastResult = result
|
||||
s.LastCheckTime = time.Now()
|
||||
s.TotalChecks++
|
||||
|
||||
oldHealthy := s.Healthy
|
||||
|
||||
switch result {
|
||||
case CheckResultSuccess:
|
||||
s.SuccessCount++
|
||||
s.ConsecutiveOK++
|
||||
s.ConsecutiveFail = 0
|
||||
|
||||
if !s.Healthy && s.ConsecutiveOK >= rise {
|
||||
s.Healthy = true
|
||||
}
|
||||
|
||||
case CheckResultFailure:
|
||||
s.FailureCount++
|
||||
s.ConsecutiveFail++
|
||||
s.ConsecutiveOK = 0
|
||||
|
||||
if s.Healthy && s.ConsecutiveFail >= fall {
|
||||
s.Healthy = false
|
||||
}
|
||||
}
|
||||
|
||||
return s.Healthy != oldHealthy
|
||||
}
|
||||
|
||||
type StateChangeCallback func(name string, oldHealthy, newHealthy bool)
|
||||
56
internal/health/factory.go
Normal file
56
internal/health/factory.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package health
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/loveuer/go-alived/pkg/config"
|
||||
"github.com/loveuer/go-alived/pkg/logger"
|
||||
)
|
||||
|
||||
func CreateChecker(cfg *config.HealthChecker) (Checker, error) {
|
||||
configMap, ok := cfg.Config.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid config for checker %s", cfg.Name)
|
||||
}
|
||||
|
||||
switch cfg.Type {
|
||||
case "tcp":
|
||||
return NewTCPChecker(cfg.Name, configMap)
|
||||
case "http", "https":
|
||||
return NewHTTPChecker(cfg.Name, configMap)
|
||||
case "ping", "icmp":
|
||||
return NewPingChecker(cfg.Name, configMap)
|
||||
case "script":
|
||||
return NewScriptChecker(cfg.Name, configMap)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported checker type: %s", cfg.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func LoadFromConfig(cfg *config.Config, log *logger.Logger) (*Manager, error) {
|
||||
manager := NewManager(log)
|
||||
|
||||
for _, healthCfg := range cfg.Health {
|
||||
checker, err := CreateChecker(&healthCfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create checker %s: %w", healthCfg.Name, err)
|
||||
}
|
||||
|
||||
monitorCfg := &CheckerConfig{
|
||||
Name: healthCfg.Name,
|
||||
Type: healthCfg.Type,
|
||||
Interval: healthCfg.Interval,
|
||||
Timeout: healthCfg.Timeout,
|
||||
Rise: healthCfg.Rise,
|
||||
Fall: healthCfg.Fall,
|
||||
Config: healthCfg.Config.(map[string]interface{}),
|
||||
}
|
||||
|
||||
monitor := NewMonitor(checker, monitorCfg, log)
|
||||
manager.AddMonitor(monitor)
|
||||
|
||||
log.Info("loaded health checker: %s (type=%s)", healthCfg.Name, healthCfg.Type)
|
||||
}
|
||||
|
||||
return manager, nil
|
||||
}
|
||||
90
internal/health/http.go
Normal file
90
internal/health/http.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package health
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type HTTPChecker struct {
|
||||
name string
|
||||
url string
|
||||
method string
|
||||
expectedStatus int
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func NewHTTPChecker(name string, config map[string]interface{}) (*HTTPChecker, error) {
|
||||
url, ok := config["url"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("http checker: missing or invalid 'url' field")
|
||||
}
|
||||
|
||||
method := "GET"
|
||||
if m, ok := config["method"].(string); ok {
|
||||
method = m
|
||||
}
|
||||
|
||||
expectedStatus := 200
|
||||
if status, ok := config["expected_status"]; ok {
|
||||
switch v := status.(type) {
|
||||
case int:
|
||||
expectedStatus = v
|
||||
case float64:
|
||||
expectedStatus = int(v)
|
||||
}
|
||||
}
|
||||
|
||||
insecureSkipVerify := false
|
||||
if skip, ok := config["insecure_skip_verify"].(bool); ok {
|
||||
insecureSkipVerify = skip
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: insecureSkipVerify,
|
||||
},
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
return &HTTPChecker{
|
||||
name: name,
|
||||
url: url,
|
||||
method: method,
|
||||
expectedStatus: expectedStatus,
|
||||
client: client,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *HTTPChecker) Name() string {
|
||||
return c.name
|
||||
}
|
||||
|
||||
func (c *HTTPChecker) Type() string {
|
||||
return "http"
|
||||
}
|
||||
|
||||
func (c *HTTPChecker) Check(ctx context.Context) CheckResult {
|
||||
req, err := http.NewRequestWithContext(ctx, c.method, c.url, nil)
|
||||
if err != nil {
|
||||
return CheckResultFailure
|
||||
}
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return CheckResultFailure
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == c.expectedStatus {
|
||||
return CheckResultSuccess
|
||||
}
|
||||
|
||||
return CheckResultFailure
|
||||
}
|
||||
192
internal/health/monitor.go
Normal file
192
internal/health/monitor.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package health
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/loveuer/go-alived/pkg/logger"
|
||||
)
|
||||
|
||||
type Monitor struct {
|
||||
checker Checker
|
||||
config *CheckerConfig
|
||||
state *CheckerState
|
||||
log *logger.Logger
|
||||
callbacks []StateChangeCallback
|
||||
|
||||
running bool
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewMonitor(checker Checker, config *CheckerConfig, log *logger.Logger) *Monitor {
|
||||
return &Monitor{
|
||||
checker: checker,
|
||||
config: config,
|
||||
state: &CheckerState{
|
||||
Name: config.Name,
|
||||
Healthy: false,
|
||||
},
|
||||
log: log,
|
||||
callbacks: make([]StateChangeCallback, 0),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Monitor) Start() {
|
||||
m.mu.Lock()
|
||||
if m.running {
|
||||
m.mu.Unlock()
|
||||
return
|
||||
}
|
||||
m.running = true
|
||||
m.mu.Unlock()
|
||||
|
||||
m.log.Info("[HealthCheck:%s] starting health check monitor (interval=%s, timeout=%s)",
|
||||
m.config.Name, m.config.Interval, m.config.Timeout)
|
||||
|
||||
m.wg.Add(1)
|
||||
go m.checkLoop()
|
||||
}
|
||||
|
||||
func (m *Monitor) Stop() {
|
||||
m.mu.Lock()
|
||||
if !m.running {
|
||||
m.mu.Unlock()
|
||||
return
|
||||
}
|
||||
m.running = false
|
||||
m.mu.Unlock()
|
||||
|
||||
m.log.Info("[HealthCheck:%s] stopping health check monitor", m.config.Name)
|
||||
close(m.stopCh)
|
||||
m.wg.Wait()
|
||||
}
|
||||
|
||||
func (m *Monitor) checkLoop() {
|
||||
defer m.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(m.config.Interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
m.performCheck()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-m.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
m.performCheck()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Monitor) performCheck() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), m.config.Timeout)
|
||||
defer cancel()
|
||||
|
||||
startTime := time.Now()
|
||||
result := m.checker.Check(ctx)
|
||||
duration := time.Since(startTime)
|
||||
|
||||
m.mu.Lock()
|
||||
oldHealthy := m.state.Healthy
|
||||
stateChanged := m.state.Update(result, m.config.Rise, m.config.Fall)
|
||||
newHealthy := m.state.Healthy
|
||||
callbacks := m.callbacks
|
||||
m.mu.Unlock()
|
||||
|
||||
m.log.Debug("[HealthCheck:%s] check completed: result=%s, duration=%s, healthy=%v",
|
||||
m.config.Name, result, duration, newHealthy)
|
||||
|
||||
if stateChanged {
|
||||
m.log.Info("[HealthCheck:%s] health state changed: %v -> %v (consecutive_ok=%d, consecutive_fail=%d)",
|
||||
m.config.Name, oldHealthy, newHealthy, m.state.ConsecutiveOK, m.state.ConsecutiveFail)
|
||||
|
||||
for _, callback := range callbacks {
|
||||
callback(m.config.Name, oldHealthy, newHealthy)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Monitor) OnStateChange(callback StateChangeCallback) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.callbacks = append(m.callbacks, callback)
|
||||
}
|
||||
|
||||
func (m *Monitor) GetState() *CheckerState {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
stateCopy := *m.state
|
||||
return &stateCopy
|
||||
}
|
||||
|
||||
func (m *Monitor) IsHealthy() bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.state.Healthy
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
monitors map[string]*Monitor
|
||||
mu sync.RWMutex
|
||||
log *logger.Logger
|
||||
}
|
||||
|
||||
func NewManager(log *logger.Logger) *Manager {
|
||||
return &Manager{
|
||||
monitors: make(map[string]*Monitor),
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) AddMonitor(monitor *Monitor) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.monitors[monitor.config.Name] = monitor
|
||||
}
|
||||
|
||||
func (m *Manager) GetMonitor(name string) (*Monitor, bool) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
monitor, ok := m.monitors[name]
|
||||
return monitor, ok
|
||||
}
|
||||
|
||||
func (m *Manager) StartAll() {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
for _, monitor := range m.monitors {
|
||||
monitor.Start()
|
||||
}
|
||||
|
||||
m.log.Info("started %d health check monitor(s)", len(m.monitors))
|
||||
}
|
||||
|
||||
func (m *Manager) StopAll() {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
for _, monitor := range m.monitors {
|
||||
monitor.Stop()
|
||||
}
|
||||
|
||||
m.log.Info("stopped all health check monitors")
|
||||
}
|
||||
|
||||
func (m *Manager) GetAllStates() map[string]*CheckerState {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
states := make(map[string]*CheckerState)
|
||||
for name, monitor := range m.monitors {
|
||||
states[name] = monitor.GetState()
|
||||
}
|
||||
|
||||
return states
|
||||
}
|
||||
129
internal/health/ping.go
Normal file
129
internal/health/ping.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package health
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/icmp"
|
||||
"golang.org/x/net/ipv4"
|
||||
)
|
||||
|
||||
type PingChecker struct {
|
||||
name string
|
||||
host string
|
||||
count int
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
func NewPingChecker(name string, config map[string]interface{}) (*PingChecker, error) {
|
||||
host, ok := config["host"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("ping checker: missing or invalid 'host' field")
|
||||
}
|
||||
|
||||
count := 1
|
||||
if c, ok := config["count"]; ok {
|
||||
switch v := c.(type) {
|
||||
case int:
|
||||
count = v
|
||||
case float64:
|
||||
count = int(v)
|
||||
}
|
||||
}
|
||||
|
||||
timeout := 2 * time.Second
|
||||
if t, ok := config["timeout"].(string); ok {
|
||||
if d, err := time.ParseDuration(t); err == nil {
|
||||
timeout = d
|
||||
}
|
||||
}
|
||||
|
||||
return &PingChecker{
|
||||
name: name,
|
||||
host: host,
|
||||
count: count,
|
||||
timeout: timeout,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *PingChecker) Name() string {
|
||||
return c.name
|
||||
}
|
||||
|
||||
func (c *PingChecker) Type() string {
|
||||
return "ping"
|
||||
}
|
||||
|
||||
func (c *PingChecker) Check(ctx context.Context) CheckResult {
|
||||
addr, err := net.ResolveIPAddr("ip4", c.host)
|
||||
if err != nil {
|
||||
return CheckResultFailure
|
||||
}
|
||||
|
||||
conn, err := icmp.ListenPacket("ip4:icmp", "0.0.0.0")
|
||||
if err != nil {
|
||||
return CheckResultFailure
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
successCount := 0
|
||||
for i := 0; i < c.count; i++ {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return CheckResultFailure
|
||||
default:
|
||||
}
|
||||
|
||||
if c.sendPing(conn, addr) {
|
||||
successCount++
|
||||
}
|
||||
}
|
||||
|
||||
if successCount > 0 {
|
||||
return CheckResultSuccess
|
||||
}
|
||||
|
||||
return CheckResultFailure
|
||||
}
|
||||
|
||||
func (c *PingChecker) sendPing(conn *icmp.PacketConn, addr *net.IPAddr) bool {
|
||||
msg := icmp.Message{
|
||||
Type: ipv4.ICMPTypeEcho,
|
||||
Code: 0,
|
||||
Body: &icmp.Echo{
|
||||
ID: 1234,
|
||||
Seq: 1,
|
||||
Data: []byte("go-alived-ping"),
|
||||
},
|
||||
}
|
||||
|
||||
msgBytes, err := msg.Marshal(nil)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if _, err := conn.WriteTo(msgBytes, addr); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
conn.SetReadDeadline(time.Now().Add(c.timeout))
|
||||
|
||||
reply := make([]byte, 1500)
|
||||
n, _, err := conn.ReadFrom(reply)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
parsedMsg, err := icmp.ParseMessage(ipv4.ICMPTypeEchoReply.Protocol(), reply[:n])
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if parsedMsg.Type == ipv4.ICMPTypeEchoReply {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
73
internal/health/script.go
Normal file
73
internal/health/script.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package health
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ScriptChecker struct {
|
||||
name string
|
||||
script string
|
||||
args []string
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
func NewScriptChecker(name string, config map[string]interface{}) (*ScriptChecker, error) {
|
||||
script, ok := config["script"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("script checker: missing or invalid 'script' field")
|
||||
}
|
||||
|
||||
var args []string
|
||||
if argsInterface, ok := config["args"].([]interface{}); ok {
|
||||
args = make([]string, len(argsInterface))
|
||||
for i, arg := range argsInterface {
|
||||
if argStr, ok := arg.(string); ok {
|
||||
args[i] = argStr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
timeout := 10 * time.Second
|
||||
if t, ok := config["timeout"].(string); ok {
|
||||
if d, err := time.ParseDuration(t); err == nil {
|
||||
timeout = d
|
||||
}
|
||||
}
|
||||
|
||||
return &ScriptChecker{
|
||||
name: name,
|
||||
script: script,
|
||||
args: args,
|
||||
timeout: timeout,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *ScriptChecker) Name() string {
|
||||
return c.name
|
||||
}
|
||||
|
||||
func (c *ScriptChecker) Type() string {
|
||||
return "script"
|
||||
}
|
||||
|
||||
func (c *ScriptChecker) Check(ctx context.Context) CheckResult {
|
||||
cmdCtx, cancel := context.WithTimeout(ctx, c.timeout)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(cmdCtx, c.script, c.args...)
|
||||
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
if exitErr.ExitCode() != 0 {
|
||||
return CheckResultFailure
|
||||
}
|
||||
}
|
||||
return CheckResultFailure
|
||||
}
|
||||
|
||||
return CheckResultSuccess
|
||||
}
|
||||
61
internal/health/tcp.go
Normal file
61
internal/health/tcp.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package health
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
type TCPChecker struct {
|
||||
name string
|
||||
host string
|
||||
port int
|
||||
}
|
||||
|
||||
func NewTCPChecker(name string, config map[string]interface{}) (*TCPChecker, error) {
|
||||
host, ok := config["host"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("tcp checker: missing or invalid 'host' field")
|
||||
}
|
||||
|
||||
var port int
|
||||
switch v := config["port"].(type) {
|
||||
case int:
|
||||
port = v
|
||||
case float64:
|
||||
port = int(v)
|
||||
default:
|
||||
return nil, fmt.Errorf("tcp checker: missing or invalid 'port' field")
|
||||
}
|
||||
|
||||
if port < 1 || port > 65535 {
|
||||
return nil, fmt.Errorf("tcp checker: invalid port number: %d", port)
|
||||
}
|
||||
|
||||
return &TCPChecker{
|
||||
name: name,
|
||||
host: host,
|
||||
port: port,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *TCPChecker) Name() string {
|
||||
return c.name
|
||||
}
|
||||
|
||||
func (c *TCPChecker) Type() string {
|
||||
return "tcp"
|
||||
}
|
||||
|
||||
func (c *TCPChecker) Check(ctx context.Context) CheckResult {
|
||||
addr := fmt.Sprintf("%s:%d", c.host, c.port)
|
||||
|
||||
var dialer net.Dialer
|
||||
conn, err := dialer.DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
return CheckResultFailure
|
||||
}
|
||||
|
||||
conn.Close()
|
||||
return CheckResultSuccess
|
||||
}
|
||||
72
internal/vrrp/arp.go
Normal file
72
internal/vrrp/arp.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package vrrp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/mdlayher/arp"
|
||||
)
|
||||
|
||||
type ARPSender struct {
|
||||
client *arp.Client
|
||||
iface *net.Interface
|
||||
}
|
||||
|
||||
func NewARPSender(ifaceName string) (*ARPSender, error) {
|
||||
iface, err := net.InterfaceByName(ifaceName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get interface %s: %w", ifaceName, err)
|
||||
}
|
||||
|
||||
client, err := arp.Dial(iface)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create ARP client: %w", err)
|
||||
}
|
||||
|
||||
return &ARPSender{
|
||||
client: client,
|
||||
iface: iface,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *ARPSender) SendGratuitousARP(ip net.IP) error {
|
||||
if ip4 := ip.To4(); ip4 == nil {
|
||||
return fmt.Errorf("invalid IPv4 address: %s", ip)
|
||||
}
|
||||
|
||||
addr, err := netip.ParseAddr(ip.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse IP: %w", err)
|
||||
}
|
||||
|
||||
pkt, err := arp.NewPacket(
|
||||
arp.OperationRequest,
|
||||
a.iface.HardwareAddr,
|
||||
addr,
|
||||
net.HardwareAddr{0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
|
||||
addr,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create ARP packet: %w", err)
|
||||
}
|
||||
|
||||
if err := a.client.WriteTo(pkt, net.HardwareAddr{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}); err != nil {
|
||||
return fmt.Errorf("failed to send gratuitous ARP: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *ARPSender) SendGratuitousARPForIPs(ips []net.IP) error {
|
||||
for _, ip := range ips {
|
||||
if err := a.SendGratuitousARP(ip); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *ARPSender) Close() error {
|
||||
return a.client.Close()
|
||||
}
|
||||
427
internal/vrrp/instance.go
Normal file
427
internal/vrrp/instance.go
Normal file
@@ -0,0 +1,427 @@
|
||||
package vrrp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/loveuer/go-alived/pkg/logger"
|
||||
"github.com/loveuer/go-alived/pkg/netif"
|
||||
)
|
||||
|
||||
type Instance struct {
|
||||
Name string
|
||||
VirtualRouterID uint8
|
||||
Priority uint8
|
||||
AdvertInterval uint8
|
||||
Interface string
|
||||
VirtualIPs []net.IP
|
||||
AuthType uint8
|
||||
AuthPass string
|
||||
TrackScripts []string
|
||||
|
||||
state *StateMachine
|
||||
priorityCalc *PriorityCalculator
|
||||
history *StateHistory
|
||||
socket *Socket
|
||||
arpSender *ARPSender
|
||||
netInterface *netif.Interface
|
||||
|
||||
advertTimer *Timer
|
||||
masterDownTimer *Timer
|
||||
|
||||
running bool
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
mu sync.RWMutex
|
||||
|
||||
log *logger.Logger
|
||||
|
||||
onMaster func()
|
||||
onBackup func()
|
||||
onFault func()
|
||||
}
|
||||
|
||||
func NewInstance(
|
||||
name string,
|
||||
vrID uint8,
|
||||
priority uint8,
|
||||
advertInt uint8,
|
||||
iface string,
|
||||
vips []string,
|
||||
authType string,
|
||||
authPass string,
|
||||
trackScripts []string,
|
||||
log *logger.Logger,
|
||||
) (*Instance, error) {
|
||||
if vrID < 1 || vrID > 255 {
|
||||
return nil, fmt.Errorf("invalid virtual router ID: %d", vrID)
|
||||
}
|
||||
|
||||
if priority < 1 || priority > 255 {
|
||||
return nil, fmt.Errorf("invalid priority: %d", priority)
|
||||
}
|
||||
|
||||
virtualIPs := make([]net.IP, 0, len(vips))
|
||||
for _, vip := range vips {
|
||||
ip, _, err := net.ParseCIDR(vip)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid VIP %s: %w", vip, err)
|
||||
}
|
||||
virtualIPs = append(virtualIPs, ip)
|
||||
}
|
||||
|
||||
var authTypeNum uint8
|
||||
switch authType {
|
||||
case "NONE", "":
|
||||
authTypeNum = AuthTypeNone
|
||||
case "PASS":
|
||||
authTypeNum = AuthTypeSimpleText
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported auth type: %s", authType)
|
||||
}
|
||||
|
||||
netInterface, err := netif.GetInterface(iface)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get interface: %w", err)
|
||||
}
|
||||
|
||||
inst := &Instance{
|
||||
Name: name,
|
||||
VirtualRouterID: vrID,
|
||||
Priority: priority,
|
||||
AdvertInterval: advertInt,
|
||||
Interface: iface,
|
||||
VirtualIPs: virtualIPs,
|
||||
AuthType: authTypeNum,
|
||||
AuthPass: authPass,
|
||||
TrackScripts: trackScripts,
|
||||
state: NewStateMachine(StateInit),
|
||||
priorityCalc: NewPriorityCalculator(priority),
|
||||
history: NewStateHistory(100),
|
||||
netInterface: netInterface,
|
||||
stopCh: make(chan struct{}),
|
||||
log: log,
|
||||
}
|
||||
|
||||
inst.advertTimer = NewTimer(time.Duration(advertInt)*time.Second, inst.onAdvertTimer)
|
||||
inst.masterDownTimer = NewTimer(CalculateMasterDownInterval(advertInt), inst.onMasterDownTimer)
|
||||
|
||||
inst.state.OnStateChange(func(old, new State) {
|
||||
inst.history.Add(old, new, "state transition")
|
||||
inst.log.Info("[%s] state changed: %s -> %s", inst.Name, old, new)
|
||||
inst.handleStateChange(old, new)
|
||||
})
|
||||
|
||||
return inst, nil
|
||||
}
|
||||
|
||||
func (inst *Instance) Start() error {
|
||||
inst.mu.Lock()
|
||||
if inst.running {
|
||||
inst.mu.Unlock()
|
||||
return fmt.Errorf("instance %s already running", inst.Name)
|
||||
}
|
||||
inst.running = true
|
||||
inst.mu.Unlock()
|
||||
|
||||
var err error
|
||||
inst.socket, err = NewSocket(inst.Interface)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create socket: %w", err)
|
||||
}
|
||||
|
||||
inst.arpSender, err = NewARPSender(inst.Interface)
|
||||
if err != nil {
|
||||
inst.socket.Close()
|
||||
return fmt.Errorf("failed to create ARP sender: %w", err)
|
||||
}
|
||||
|
||||
inst.log.Info("[%s] starting VRRP instance (VRID=%d, Priority=%d, Interface=%s)",
|
||||
inst.Name, inst.VirtualRouterID, inst.Priority, inst.Interface)
|
||||
|
||||
inst.state.SetState(StateBackup)
|
||||
inst.masterDownTimer.Start()
|
||||
|
||||
inst.wg.Add(1)
|
||||
go inst.receiveLoop()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (inst *Instance) Stop() {
|
||||
inst.mu.Lock()
|
||||
if !inst.running {
|
||||
inst.mu.Unlock()
|
||||
return
|
||||
}
|
||||
inst.running = false
|
||||
inst.mu.Unlock()
|
||||
|
||||
inst.log.Info("[%s] stopping VRRP instance", inst.Name)
|
||||
|
||||
close(inst.stopCh)
|
||||
inst.wg.Wait()
|
||||
|
||||
inst.advertTimer.Stop()
|
||||
inst.masterDownTimer.Stop()
|
||||
|
||||
if inst.state.GetState() == StateMaster {
|
||||
inst.removeVIPs()
|
||||
}
|
||||
|
||||
if inst.socket != nil {
|
||||
inst.socket.Close()
|
||||
}
|
||||
|
||||
if inst.arpSender != nil {
|
||||
inst.arpSender.Close()
|
||||
}
|
||||
|
||||
inst.state.SetState(StateInit)
|
||||
}
|
||||
|
||||
func (inst *Instance) receiveLoop() {
|
||||
defer inst.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-inst.stopCh:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
pkt, srcIP, err := inst.socket.Receive()
|
||||
if err != nil {
|
||||
inst.log.Debug("[%s] failed to receive packet: %v", inst.Name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if pkt.VirtualRtrID != inst.VirtualRouterID {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := pkt.Validate(inst.AuthPass); err != nil {
|
||||
inst.log.Warn("[%s] packet validation failed: %v", inst.Name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
inst.handleAdvertisement(pkt, srcIP)
|
||||
}
|
||||
}
|
||||
|
||||
func (inst *Instance) handleAdvertisement(pkt *VRRPPacket, srcIP net.IP) {
|
||||
currentState := inst.state.GetState()
|
||||
localPriority := inst.priorityCalc.GetPriority()
|
||||
|
||||
inst.log.Debug("[%s] received advertisement from %s (priority=%d, state=%s)",
|
||||
inst.Name, srcIP, pkt.Priority, currentState)
|
||||
|
||||
switch currentState {
|
||||
case StateBackup:
|
||||
if pkt.Priority == 0 {
|
||||
inst.masterDownTimer.SetDuration(CalculateSkewTime(localPriority))
|
||||
inst.masterDownTimer.Reset()
|
||||
} else if !ShouldBecomeMaster(localPriority, pkt.Priority, inst.socket.localIP.String(), srcIP.String()) {
|
||||
inst.masterDownTimer.Reset()
|
||||
}
|
||||
|
||||
case StateMaster:
|
||||
if ShouldBecomeMaster(pkt.Priority, localPriority, srcIP.String(), inst.socket.localIP.String()) {
|
||||
inst.log.Warn("[%s] received higher priority advertisement, stepping down", inst.Name)
|
||||
inst.state.SetState(StateBackup)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (inst *Instance) onAdvertTimer() {
|
||||
if inst.state.GetState() == StateMaster {
|
||||
inst.sendAdvertisement()
|
||||
inst.advertTimer.Start()
|
||||
}
|
||||
}
|
||||
|
||||
func (inst *Instance) onMasterDownTimer() {
|
||||
if inst.state.GetState() == StateBackup {
|
||||
inst.log.Info("[%s] master down timer expired, becoming master", inst.Name)
|
||||
inst.state.SetState(StateMaster)
|
||||
}
|
||||
}
|
||||
|
||||
func (inst *Instance) sendAdvertisement() error {
|
||||
priority := inst.priorityCalc.GetPriority()
|
||||
|
||||
pkt := NewAdvertisement(
|
||||
inst.VirtualRouterID,
|
||||
priority,
|
||||
inst.AdvertInterval,
|
||||
inst.VirtualIPs,
|
||||
inst.AuthType,
|
||||
inst.AuthPass,
|
||||
)
|
||||
|
||||
if err := inst.socket.Send(pkt); err != nil {
|
||||
inst.log.Error("[%s] failed to send advertisement: %v", inst.Name, err)
|
||||
return err
|
||||
}
|
||||
|
||||
inst.log.Debug("[%s] sent advertisement (priority=%d)", inst.Name, priority)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (inst *Instance) handleStateChange(old, new State) {
|
||||
switch new {
|
||||
case StateMaster:
|
||||
inst.becomeMaster()
|
||||
case StateBackup:
|
||||
inst.becomeBackup(old)
|
||||
case StateFault:
|
||||
inst.becomeFault()
|
||||
}
|
||||
}
|
||||
|
||||
func (inst *Instance) becomeMaster() {
|
||||
inst.log.Info("[%s] transitioning to MASTER state", inst.Name)
|
||||
|
||||
if err := inst.addVIPs(); err != nil {
|
||||
inst.log.Error("[%s] failed to add VIPs: %v", inst.Name, err)
|
||||
inst.state.SetState(StateFault)
|
||||
return
|
||||
}
|
||||
|
||||
if err := inst.arpSender.SendGratuitousARPForIPs(inst.VirtualIPs); err != nil {
|
||||
inst.log.Error("[%s] failed to send gratuitous ARP: %v", inst.Name, err)
|
||||
}
|
||||
|
||||
inst.masterDownTimer.Stop()
|
||||
inst.advertTimer.Start()
|
||||
|
||||
inst.sendAdvertisement()
|
||||
|
||||
if inst.onMaster != nil {
|
||||
inst.onMaster()
|
||||
}
|
||||
}
|
||||
|
||||
func (inst *Instance) becomeBackup(oldState State) {
|
||||
inst.log.Info("[%s] transitioning to BACKUP state", inst.Name)
|
||||
|
||||
inst.advertTimer.Stop()
|
||||
|
||||
if oldState == StateMaster {
|
||||
if err := inst.removeVIPs(); err != nil {
|
||||
inst.log.Error("[%s] failed to remove VIPs: %v", inst.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
inst.masterDownTimer.Reset()
|
||||
|
||||
if inst.onBackup != nil {
|
||||
inst.onBackup()
|
||||
}
|
||||
}
|
||||
|
||||
func (inst *Instance) becomeFault() {
|
||||
inst.log.Error("[%s] transitioning to FAULT state", inst.Name)
|
||||
|
||||
inst.advertTimer.Stop()
|
||||
inst.masterDownTimer.Stop()
|
||||
|
||||
if err := inst.removeVIPs(); err != nil {
|
||||
inst.log.Error("[%s] failed to remove VIPs: %v", inst.Name, err)
|
||||
}
|
||||
|
||||
if inst.onFault != nil {
|
||||
inst.onFault()
|
||||
}
|
||||
}
|
||||
|
||||
func (inst *Instance) addVIPs() error {
|
||||
inst.log.Info("[%s] adding virtual IPs", inst.Name)
|
||||
|
||||
for _, vipStr := range inst.getVIPsWithCIDR() {
|
||||
if err := inst.netInterface.AddIP(vipStr); err != nil {
|
||||
inst.log.Error("[%s] failed to add VIP %s: %v", inst.Name, vipStr, err)
|
||||
return err
|
||||
}
|
||||
inst.log.Info("[%s] added VIP %s", inst.Name, vipStr)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (inst *Instance) removeVIPs() error {
|
||||
inst.log.Info("[%s] removing virtual IPs", inst.Name)
|
||||
|
||||
for _, vipStr := range inst.getVIPsWithCIDR() {
|
||||
has, _ := inst.netInterface.HasIP(vipStr)
|
||||
if !has {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := inst.netInterface.DeleteIP(vipStr); err != nil {
|
||||
inst.log.Error("[%s] failed to remove VIP %s: %v", inst.Name, vipStr, err)
|
||||
return err
|
||||
}
|
||||
inst.log.Info("[%s] removed VIP %s", inst.Name, vipStr)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (inst *Instance) getVIPsWithCIDR() []string {
|
||||
result := make([]string, len(inst.VirtualIPs))
|
||||
for i, ip := range inst.VirtualIPs {
|
||||
result[i] = ip.String() + "/32"
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (inst *Instance) GetState() State {
|
||||
return inst.state.GetState()
|
||||
}
|
||||
|
||||
func (inst *Instance) OnMaster(callback func()) {
|
||||
inst.onMaster = callback
|
||||
}
|
||||
|
||||
func (inst *Instance) OnBackup(callback func()) {
|
||||
inst.onBackup = callback
|
||||
}
|
||||
|
||||
func (inst *Instance) OnFault(callback func()) {
|
||||
inst.onFault = callback
|
||||
}
|
||||
|
||||
func (inst *Instance) AdjustPriority(delta int) {
|
||||
inst.mu.Lock()
|
||||
defer inst.mu.Unlock()
|
||||
|
||||
oldPriority := inst.priorityCalc.GetPriority()
|
||||
|
||||
if delta < 0 {
|
||||
inst.priorityCalc.DecreasePriority(uint8(-delta))
|
||||
}
|
||||
|
||||
newPriority := inst.priorityCalc.GetPriority()
|
||||
|
||||
if oldPriority != newPriority {
|
||||
inst.log.Info("[%s] priority adjusted: %d -> %d (delta=%d)",
|
||||
inst.Name, oldPriority, newPriority, delta)
|
||||
}
|
||||
}
|
||||
|
||||
func (inst *Instance) ResetPriority() {
|
||||
inst.mu.Lock()
|
||||
defer inst.mu.Unlock()
|
||||
|
||||
oldPriority := inst.priorityCalc.GetPriority()
|
||||
inst.priorityCalc.ResetPriority()
|
||||
newPriority := inst.priorityCalc.GetPriority()
|
||||
|
||||
if oldPriority != newPriority {
|
||||
inst.log.Info("[%s] priority reset: %d -> %d",
|
||||
inst.Name, oldPriority, newPriority)
|
||||
}
|
||||
}
|
||||
116
internal/vrrp/manager.go
Normal file
116
internal/vrrp/manager.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package vrrp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/loveuer/go-alived/pkg/config"
|
||||
"github.com/loveuer/go-alived/pkg/logger"
|
||||
)
|
||||
|
||||
type Manager struct {
|
||||
instances map[string]*Instance
|
||||
mu sync.RWMutex
|
||||
log *logger.Logger
|
||||
}
|
||||
|
||||
func NewManager(log *logger.Logger) *Manager {
|
||||
return &Manager{
|
||||
instances: make(map[string]*Instance),
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) LoadFromConfig(cfg *config.Config) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
for _, vrrpCfg := range cfg.VRRP {
|
||||
inst, err := NewInstance(
|
||||
vrrpCfg.Name,
|
||||
uint8(vrrpCfg.VirtualRouterID),
|
||||
uint8(vrrpCfg.Priority),
|
||||
uint8(vrrpCfg.AdvertInterval),
|
||||
vrrpCfg.Interface,
|
||||
vrrpCfg.VirtualIPs,
|
||||
vrrpCfg.AuthType,
|
||||
vrrpCfg.AuthPass,
|
||||
vrrpCfg.TrackScripts,
|
||||
m.log,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create instance %s: %w", vrrpCfg.Name, err)
|
||||
}
|
||||
|
||||
m.instances[vrrpCfg.Name] = inst
|
||||
m.log.Info("loaded VRRP instance: %s", vrrpCfg.Name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) StartAll() error {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
for name, inst := range m.instances {
|
||||
if err := inst.Start(); err != nil {
|
||||
return fmt.Errorf("failed to start instance %s: %w", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
m.log.Info("started %d VRRP instance(s)", len(m.instances))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) StopAll() {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
for _, inst := range m.instances {
|
||||
inst.Stop()
|
||||
}
|
||||
|
||||
m.log.Info("stopped all VRRP instances")
|
||||
}
|
||||
|
||||
func (m *Manager) GetInstance(name string) (*Instance, bool) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
inst, ok := m.instances[name]
|
||||
return inst, ok
|
||||
}
|
||||
|
||||
func (m *Manager) GetAllInstances() []*Instance {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
result := make([]*Instance, 0, len(m.instances))
|
||||
for _, inst := range m.instances {
|
||||
result = append(result, inst)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (m *Manager) Reload(cfg *config.Config) error {
|
||||
m.log.Info("reloading VRRP configuration...")
|
||||
|
||||
m.StopAll()
|
||||
|
||||
m.mu.Lock()
|
||||
m.instances = make(map[string]*Instance)
|
||||
m.mu.Unlock()
|
||||
|
||||
if err := m.LoadFromConfig(cfg); err != nil {
|
||||
return fmt.Errorf("failed to load config: %w", err)
|
||||
}
|
||||
|
||||
if err := m.StartAll(); err != nil {
|
||||
return fmt.Errorf("failed to start instances: %w", err)
|
||||
}
|
||||
|
||||
m.log.Info("VRRP configuration reloaded successfully")
|
||||
return nil
|
||||
}
|
||||
184
internal/vrrp/packet.go
Normal file
184
internal/vrrp/packet.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package vrrp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
const (
|
||||
VRRPVersion = 2
|
||||
VRRPProtocolNumber = 112
|
||||
)
|
||||
|
||||
type VRRPPacket struct {
|
||||
Version uint8
|
||||
Type uint8
|
||||
VirtualRtrID uint8
|
||||
Priority uint8
|
||||
CountIPAddrs uint8
|
||||
AuthType uint8
|
||||
AdvertInt uint8
|
||||
Checksum uint16
|
||||
IPAddresses []net.IP
|
||||
AuthData [8]byte
|
||||
}
|
||||
|
||||
const (
|
||||
VRRPTypeAdvertisement = 1
|
||||
)
|
||||
|
||||
const (
|
||||
AuthTypeNone = 0
|
||||
AuthTypeSimpleText = 1
|
||||
AuthTypeIPAH = 2
|
||||
)
|
||||
|
||||
func NewAdvertisement(vrID uint8, priority uint8, advertInt uint8, ips []net.IP, authType uint8, authPass string) *VRRPPacket {
|
||||
pkt := &VRRPPacket{
|
||||
Version: VRRPVersion,
|
||||
Type: VRRPTypeAdvertisement,
|
||||
VirtualRtrID: vrID,
|
||||
Priority: priority,
|
||||
CountIPAddrs: uint8(len(ips)),
|
||||
AuthType: authType,
|
||||
AdvertInt: advertInt,
|
||||
IPAddresses: ips,
|
||||
}
|
||||
|
||||
if authType == AuthTypeSimpleText && authPass != "" {
|
||||
copy(pkt.AuthData[:], authPass)
|
||||
}
|
||||
|
||||
return pkt
|
||||
}
|
||||
|
||||
func (p *VRRPPacket) Marshal() ([]byte, error) {
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
versionType := (p.Version << 4) | p.Type
|
||||
if err := binary.Write(buf, binary.BigEndian, versionType); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := binary.Write(buf, binary.BigEndian, p.VirtualRtrID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := binary.Write(buf, binary.BigEndian, p.Priority); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := binary.Write(buf, binary.BigEndian, p.CountIPAddrs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := binary.Write(buf, binary.BigEndian, p.AuthType); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := binary.Write(buf, binary.BigEndian, p.AdvertInt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := binary.Write(buf, binary.BigEndian, uint16(0)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, ip := range p.IPAddresses {
|
||||
ip4 := ip.To4()
|
||||
if ip4 == nil {
|
||||
return nil, fmt.Errorf("invalid IPv4 address: %s", ip)
|
||||
}
|
||||
if err := binary.Write(buf, binary.BigEndian, ip4); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := binary.Write(buf, binary.BigEndian, p.AuthData); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
data := buf.Bytes()
|
||||
|
||||
checksum := calculateChecksum(data)
|
||||
binary.BigEndian.PutUint16(data[6:8], checksum)
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func Unmarshal(data []byte) (*VRRPPacket, error) {
|
||||
if len(data) < 20 {
|
||||
return nil, fmt.Errorf("packet too short: %d bytes", len(data))
|
||||
}
|
||||
|
||||
pkt := &VRRPPacket{}
|
||||
|
||||
versionType := data[0]
|
||||
pkt.Version = versionType >> 4
|
||||
pkt.Type = versionType & 0x0F
|
||||
pkt.VirtualRtrID = data[1]
|
||||
pkt.Priority = data[2]
|
||||
pkt.CountIPAddrs = data[3]
|
||||
pkt.AuthType = data[4]
|
||||
pkt.AdvertInt = data[5]
|
||||
pkt.Checksum = binary.BigEndian.Uint16(data[6:8])
|
||||
|
||||
offset := 8
|
||||
pkt.IPAddresses = make([]net.IP, pkt.CountIPAddrs)
|
||||
for i := 0; i < int(pkt.CountIPAddrs); i++ {
|
||||
if offset+4 > len(data) {
|
||||
return nil, fmt.Errorf("packet too short for IP addresses")
|
||||
}
|
||||
pkt.IPAddresses[i] = net.IPv4(data[offset], data[offset+1], data[offset+2], data[offset+3])
|
||||
offset += 4
|
||||
}
|
||||
|
||||
if offset+8 > len(data) {
|
||||
return nil, fmt.Errorf("packet too short for auth data")
|
||||
}
|
||||
copy(pkt.AuthData[:], data[offset:offset+8])
|
||||
|
||||
return pkt, nil
|
||||
}
|
||||
|
||||
func calculateChecksum(data []byte) uint16 {
|
||||
sum := uint32(0)
|
||||
|
||||
for i := 0; i < len(data)-1; i += 2 {
|
||||
sum += uint32(data[i])<<8 | uint32(data[i+1])
|
||||
}
|
||||
|
||||
if len(data)%2 == 1 {
|
||||
sum += uint32(data[len(data)-1]) << 8
|
||||
}
|
||||
|
||||
for sum > 0xFFFF {
|
||||
sum = (sum & 0xFFFF) + (sum >> 16)
|
||||
}
|
||||
|
||||
return uint16(^sum)
|
||||
}
|
||||
|
||||
func (p *VRRPPacket) Validate(authPass string) error {
|
||||
if p.Version != VRRPVersion {
|
||||
return fmt.Errorf("unsupported VRRP version: %d", p.Version)
|
||||
}
|
||||
|
||||
if p.Type != VRRPTypeAdvertisement {
|
||||
return fmt.Errorf("unsupported VRRP type: %d", p.Type)
|
||||
}
|
||||
|
||||
if p.AuthType == AuthTypeSimpleText {
|
||||
if authPass != "" {
|
||||
var expectedAuth [8]byte
|
||||
copy(expectedAuth[:], authPass)
|
||||
if !bytes.Equal(p.AuthData[:], expectedAuth[:]) {
|
||||
return fmt.Errorf("authentication failed")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
141
internal/vrrp/socket.go
Normal file
141
internal/vrrp/socket.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package vrrp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/net/ipv4"
|
||||
)
|
||||
|
||||
const (
|
||||
VRRPMulticastAddr = "224.0.0.18"
|
||||
)
|
||||
|
||||
type Socket struct {
|
||||
conn *ipv4.RawConn
|
||||
iface *net.Interface
|
||||
localIP net.IP
|
||||
groupIP net.IP
|
||||
}
|
||||
|
||||
func NewSocket(ifaceName string) (*Socket, error) {
|
||||
iface, err := net.InterfaceByName(ifaceName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get interface %s: %w", ifaceName, err)
|
||||
}
|
||||
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get addresses for %s: %w", ifaceName, err)
|
||||
}
|
||||
|
||||
var localIP net.IP
|
||||
for _, addr := range addrs {
|
||||
if ipNet, ok := addr.(*net.IPNet); ok {
|
||||
if ipv4 := ipNet.IP.To4(); ipv4 != nil {
|
||||
localIP = ipv4
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if localIP == nil {
|
||||
return nil, fmt.Errorf("no IPv4 address found on interface %s", ifaceName)
|
||||
}
|
||||
|
||||
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, VRRPProtocolNumber)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create raw socket: %w", err)
|
||||
}
|
||||
|
||||
if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1); err != nil {
|
||||
syscall.Close(fd)
|
||||
return nil, fmt.Errorf("failed to set SO_REUSEADDR: %w", err)
|
||||
}
|
||||
|
||||
file := os.NewFile(uintptr(fd), "vrrp-socket")
|
||||
defer file.Close()
|
||||
|
||||
packetConn, err := net.FilePacketConn(file)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create packet connection: %w", err)
|
||||
}
|
||||
|
||||
rawConn, err := ipv4.NewRawConn(packetConn)
|
||||
if err != nil {
|
||||
packetConn.Close()
|
||||
return nil, fmt.Errorf("failed to create raw connection: %w", err)
|
||||
}
|
||||
|
||||
groupIP := net.ParseIP(VRRPMulticastAddr).To4()
|
||||
if groupIP == nil {
|
||||
rawConn.Close()
|
||||
return nil, fmt.Errorf("invalid multicast address: %s", VRRPMulticastAddr)
|
||||
}
|
||||
|
||||
if err := rawConn.JoinGroup(iface, &net.IPAddr{IP: groupIP}); err != nil {
|
||||
rawConn.Close()
|
||||
return nil, fmt.Errorf("failed to join multicast group: %w", err)
|
||||
}
|
||||
|
||||
if err := rawConn.SetControlMessage(ipv4.FlagTTL|ipv4.FlagSrc|ipv4.FlagDst|ipv4.FlagInterface, true); err != nil {
|
||||
rawConn.Close()
|
||||
return nil, fmt.Errorf("failed to set control message: %w", err)
|
||||
}
|
||||
|
||||
return &Socket{
|
||||
conn: rawConn,
|
||||
iface: iface,
|
||||
localIP: localIP,
|
||||
groupIP: groupIP,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Socket) Send(pkt *VRRPPacket) error {
|
||||
data, err := pkt.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal packet: %w", err)
|
||||
}
|
||||
|
||||
header := &ipv4.Header{
|
||||
Version: ipv4.Version,
|
||||
Len: ipv4.HeaderLen,
|
||||
TOS: 0xC0,
|
||||
TotalLen: ipv4.HeaderLen + len(data),
|
||||
TTL: 255,
|
||||
Protocol: VRRPProtocolNumber,
|
||||
Dst: s.groupIP,
|
||||
Src: s.localIP,
|
||||
}
|
||||
|
||||
if err := s.conn.WriteTo(header, data, nil); err != nil {
|
||||
return fmt.Errorf("failed to send packet: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Socket) Receive() (*VRRPPacket, net.IP, error) {
|
||||
buf := make([]byte, 1500)
|
||||
|
||||
header, payload, _, err := s.conn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to receive packet: %w", err)
|
||||
}
|
||||
|
||||
pkt, err := Unmarshal(payload)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to unmarshal packet: %w", err)
|
||||
}
|
||||
|
||||
return pkt, header.Src, nil
|
||||
}
|
||||
|
||||
func (s *Socket) Close() error {
|
||||
if err := s.conn.LeaveGroup(s.iface, &net.IPAddr{IP: s.groupIP}); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.conn.Close()
|
||||
}
|
||||
258
internal/vrrp/state.go
Normal file
258
internal/vrrp/state.go
Normal file
@@ -0,0 +1,258 @@
|
||||
package vrrp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type State int
|
||||
|
||||
const (
|
||||
StateInit State = iota
|
||||
StateBackup
|
||||
StateMaster
|
||||
StateFault
|
||||
)
|
||||
|
||||
func (s State) String() string {
|
||||
switch s {
|
||||
case StateInit:
|
||||
return "INIT"
|
||||
case StateBackup:
|
||||
return "BACKUP"
|
||||
case StateMaster:
|
||||
return "MASTER"
|
||||
case StateFault:
|
||||
return "FAULT"
|
||||
default:
|
||||
return "UNKNOWN"
|
||||
}
|
||||
}
|
||||
|
||||
type StateMachine struct {
|
||||
currentState State
|
||||
previousState State
|
||||
mu sync.RWMutex
|
||||
stateChangeCallbacks []func(old, new State)
|
||||
}
|
||||
|
||||
func NewStateMachine(initialState State) *StateMachine {
|
||||
return &StateMachine{
|
||||
currentState: initialState,
|
||||
previousState: StateInit,
|
||||
stateChangeCallbacks: make([]func(old, new State), 0),
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *StateMachine) GetState() State {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
return sm.currentState
|
||||
}
|
||||
|
||||
func (sm *StateMachine) SetState(newState State) {
|
||||
sm.mu.Lock()
|
||||
oldState := sm.currentState
|
||||
sm.previousState = oldState
|
||||
sm.currentState = newState
|
||||
callbacks := sm.stateChangeCallbacks
|
||||
sm.mu.Unlock()
|
||||
|
||||
for _, callback := range callbacks {
|
||||
callback(oldState, newState)
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *StateMachine) OnStateChange(callback func(old, new State)) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
sm.stateChangeCallbacks = append(sm.stateChangeCallbacks, callback)
|
||||
}
|
||||
|
||||
type Timer struct {
|
||||
duration time.Duration
|
||||
timer *time.Timer
|
||||
callback func()
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewTimer(duration time.Duration, callback func()) *Timer {
|
||||
return &Timer{
|
||||
duration: duration,
|
||||
callback: callback,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Timer) Start() {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
if t.timer != nil {
|
||||
t.timer.Stop()
|
||||
}
|
||||
|
||||
t.timer = time.AfterFunc(t.duration, t.callback)
|
||||
}
|
||||
|
||||
func (t *Timer) Stop() {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
if t.timer != nil {
|
||||
t.timer.Stop()
|
||||
t.timer = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Timer) Reset() {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
if t.timer != nil {
|
||||
t.timer.Stop()
|
||||
}
|
||||
|
||||
t.timer = time.AfterFunc(t.duration, t.callback)
|
||||
}
|
||||
|
||||
func (t *Timer) SetDuration(duration time.Duration) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
t.duration = duration
|
||||
}
|
||||
|
||||
type PriorityCalculator struct {
|
||||
basePriority uint8
|
||||
currentPriority uint8
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewPriorityCalculator(basePriority uint8) *PriorityCalculator {
|
||||
return &PriorityCalculator{
|
||||
basePriority: basePriority,
|
||||
currentPriority: basePriority,
|
||||
}
|
||||
}
|
||||
|
||||
func (pc *PriorityCalculator) GetPriority() uint8 {
|
||||
pc.mu.RLock()
|
||||
defer pc.mu.RUnlock()
|
||||
return pc.currentPriority
|
||||
}
|
||||
|
||||
func (pc *PriorityCalculator) DecreasePriority(amount uint8) {
|
||||
pc.mu.Lock()
|
||||
defer pc.mu.Unlock()
|
||||
|
||||
if pc.currentPriority > amount {
|
||||
pc.currentPriority -= amount
|
||||
} else {
|
||||
pc.currentPriority = 0
|
||||
}
|
||||
}
|
||||
|
||||
func (pc *PriorityCalculator) ResetPriority() {
|
||||
pc.mu.Lock()
|
||||
defer pc.mu.Unlock()
|
||||
pc.currentPriority = pc.basePriority
|
||||
}
|
||||
|
||||
func (pc *PriorityCalculator) SetBasePriority(priority uint8) {
|
||||
pc.mu.Lock()
|
||||
defer pc.mu.Unlock()
|
||||
pc.basePriority = priority
|
||||
pc.currentPriority = priority
|
||||
}
|
||||
|
||||
func ShouldBecomeMaster(localPriority, remotePriority uint8, localIP, remoteIP string) bool {
|
||||
if localPriority > remotePriority {
|
||||
return true
|
||||
}
|
||||
|
||||
if localPriority == remotePriority {
|
||||
return localIP > remoteIP
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func CalculateMasterDownInterval(advertInt uint8) time.Duration {
|
||||
return time.Duration(3*int(advertInt)) * time.Second
|
||||
}
|
||||
|
||||
func CalculateSkewTime(priority uint8) time.Duration {
|
||||
skew := float64(256-int(priority)) / 256.0
|
||||
return time.Duration(skew * float64(time.Second))
|
||||
}
|
||||
|
||||
type StateTransition struct {
|
||||
From State
|
||||
To State
|
||||
Timestamp time.Time
|
||||
Reason string
|
||||
}
|
||||
|
||||
type StateHistory struct {
|
||||
transitions []StateTransition
|
||||
maxSize int
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewStateHistory(maxSize int) *StateHistory {
|
||||
return &StateHistory{
|
||||
transitions: make([]StateTransition, 0, maxSize),
|
||||
maxSize: maxSize,
|
||||
}
|
||||
}
|
||||
|
||||
func (sh *StateHistory) Add(from, to State, reason string) {
|
||||
sh.mu.Lock()
|
||||
defer sh.mu.Unlock()
|
||||
|
||||
transition := StateTransition{
|
||||
From: from,
|
||||
To: to,
|
||||
Timestamp: time.Now(),
|
||||
Reason: reason,
|
||||
}
|
||||
|
||||
sh.transitions = append(sh.transitions, transition)
|
||||
|
||||
if len(sh.transitions) > sh.maxSize {
|
||||
sh.transitions = sh.transitions[1:]
|
||||
}
|
||||
}
|
||||
|
||||
func (sh *StateHistory) GetRecent(n int) []StateTransition {
|
||||
sh.mu.RLock()
|
||||
defer sh.mu.RUnlock()
|
||||
|
||||
if n > len(sh.transitions) {
|
||||
n = len(sh.transitions)
|
||||
}
|
||||
|
||||
start := len(sh.transitions) - n
|
||||
result := make([]StateTransition, n)
|
||||
copy(result, sh.transitions[start:])
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (sh *StateHistory) String() string {
|
||||
sh.mu.RLock()
|
||||
defer sh.mu.RUnlock()
|
||||
|
||||
if len(sh.transitions) == 0 {
|
||||
return "No state transitions"
|
||||
}
|
||||
|
||||
result := "State transition history:\n"
|
||||
for _, t := range sh.transitions {
|
||||
result += fmt.Sprintf(" %s: %s -> %s (%s)\n",
|
||||
t.Timestamp.Format("2006-01-02 15:04:05"),
|
||||
t.From, t.To, t.Reason)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
Reference in New Issue
Block a user