wip v1.0.0
This commit is contained in:
427
internal/vrrp/instance.go
Normal file
427
internal/vrrp/instance.go
Normal file
@@ -0,0 +1,427 @@
|
||||
package vrrp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/loveuer/go-alived/pkg/logger"
|
||||
"github.com/loveuer/go-alived/pkg/netif"
|
||||
)
|
||||
|
||||
type Instance struct {
|
||||
Name string
|
||||
VirtualRouterID uint8
|
||||
Priority uint8
|
||||
AdvertInterval uint8
|
||||
Interface string
|
||||
VirtualIPs []net.IP
|
||||
AuthType uint8
|
||||
AuthPass string
|
||||
TrackScripts []string
|
||||
|
||||
state *StateMachine
|
||||
priorityCalc *PriorityCalculator
|
||||
history *StateHistory
|
||||
socket *Socket
|
||||
arpSender *ARPSender
|
||||
netInterface *netif.Interface
|
||||
|
||||
advertTimer *Timer
|
||||
masterDownTimer *Timer
|
||||
|
||||
running bool
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
mu sync.RWMutex
|
||||
|
||||
log *logger.Logger
|
||||
|
||||
onMaster func()
|
||||
onBackup func()
|
||||
onFault func()
|
||||
}
|
||||
|
||||
func NewInstance(
|
||||
name string,
|
||||
vrID uint8,
|
||||
priority uint8,
|
||||
advertInt uint8,
|
||||
iface string,
|
||||
vips []string,
|
||||
authType string,
|
||||
authPass string,
|
||||
trackScripts []string,
|
||||
log *logger.Logger,
|
||||
) (*Instance, error) {
|
||||
if vrID < 1 || vrID > 255 {
|
||||
return nil, fmt.Errorf("invalid virtual router ID: %d", vrID)
|
||||
}
|
||||
|
||||
if priority < 1 || priority > 255 {
|
||||
return nil, fmt.Errorf("invalid priority: %d", priority)
|
||||
}
|
||||
|
||||
virtualIPs := make([]net.IP, 0, len(vips))
|
||||
for _, vip := range vips {
|
||||
ip, _, err := net.ParseCIDR(vip)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid VIP %s: %w", vip, err)
|
||||
}
|
||||
virtualIPs = append(virtualIPs, ip)
|
||||
}
|
||||
|
||||
var authTypeNum uint8
|
||||
switch authType {
|
||||
case "NONE", "":
|
||||
authTypeNum = AuthTypeNone
|
||||
case "PASS":
|
||||
authTypeNum = AuthTypeSimpleText
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported auth type: %s", authType)
|
||||
}
|
||||
|
||||
netInterface, err := netif.GetInterface(iface)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get interface: %w", err)
|
||||
}
|
||||
|
||||
inst := &Instance{
|
||||
Name: name,
|
||||
VirtualRouterID: vrID,
|
||||
Priority: priority,
|
||||
AdvertInterval: advertInt,
|
||||
Interface: iface,
|
||||
VirtualIPs: virtualIPs,
|
||||
AuthType: authTypeNum,
|
||||
AuthPass: authPass,
|
||||
TrackScripts: trackScripts,
|
||||
state: NewStateMachine(StateInit),
|
||||
priorityCalc: NewPriorityCalculator(priority),
|
||||
history: NewStateHistory(100),
|
||||
netInterface: netInterface,
|
||||
stopCh: make(chan struct{}),
|
||||
log: log,
|
||||
}
|
||||
|
||||
inst.advertTimer = NewTimer(time.Duration(advertInt)*time.Second, inst.onAdvertTimer)
|
||||
inst.masterDownTimer = NewTimer(CalculateMasterDownInterval(advertInt), inst.onMasterDownTimer)
|
||||
|
||||
inst.state.OnStateChange(func(old, new State) {
|
||||
inst.history.Add(old, new, "state transition")
|
||||
inst.log.Info("[%s] state changed: %s -> %s", inst.Name, old, new)
|
||||
inst.handleStateChange(old, new)
|
||||
})
|
||||
|
||||
return inst, nil
|
||||
}
|
||||
|
||||
func (inst *Instance) Start() error {
|
||||
inst.mu.Lock()
|
||||
if inst.running {
|
||||
inst.mu.Unlock()
|
||||
return fmt.Errorf("instance %s already running", inst.Name)
|
||||
}
|
||||
inst.running = true
|
||||
inst.mu.Unlock()
|
||||
|
||||
var err error
|
||||
inst.socket, err = NewSocket(inst.Interface)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create socket: %w", err)
|
||||
}
|
||||
|
||||
inst.arpSender, err = NewARPSender(inst.Interface)
|
||||
if err != nil {
|
||||
inst.socket.Close()
|
||||
return fmt.Errorf("failed to create ARP sender: %w", err)
|
||||
}
|
||||
|
||||
inst.log.Info("[%s] starting VRRP instance (VRID=%d, Priority=%d, Interface=%s)",
|
||||
inst.Name, inst.VirtualRouterID, inst.Priority, inst.Interface)
|
||||
|
||||
inst.state.SetState(StateBackup)
|
||||
inst.masterDownTimer.Start()
|
||||
|
||||
inst.wg.Add(1)
|
||||
go inst.receiveLoop()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (inst *Instance) Stop() {
|
||||
inst.mu.Lock()
|
||||
if !inst.running {
|
||||
inst.mu.Unlock()
|
||||
return
|
||||
}
|
||||
inst.running = false
|
||||
inst.mu.Unlock()
|
||||
|
||||
inst.log.Info("[%s] stopping VRRP instance", inst.Name)
|
||||
|
||||
close(inst.stopCh)
|
||||
inst.wg.Wait()
|
||||
|
||||
inst.advertTimer.Stop()
|
||||
inst.masterDownTimer.Stop()
|
||||
|
||||
if inst.state.GetState() == StateMaster {
|
||||
inst.removeVIPs()
|
||||
}
|
||||
|
||||
if inst.socket != nil {
|
||||
inst.socket.Close()
|
||||
}
|
||||
|
||||
if inst.arpSender != nil {
|
||||
inst.arpSender.Close()
|
||||
}
|
||||
|
||||
inst.state.SetState(StateInit)
|
||||
}
|
||||
|
||||
func (inst *Instance) receiveLoop() {
|
||||
defer inst.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-inst.stopCh:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
pkt, srcIP, err := inst.socket.Receive()
|
||||
if err != nil {
|
||||
inst.log.Debug("[%s] failed to receive packet: %v", inst.Name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if pkt.VirtualRtrID != inst.VirtualRouterID {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := pkt.Validate(inst.AuthPass); err != nil {
|
||||
inst.log.Warn("[%s] packet validation failed: %v", inst.Name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
inst.handleAdvertisement(pkt, srcIP)
|
||||
}
|
||||
}
|
||||
|
||||
func (inst *Instance) handleAdvertisement(pkt *VRRPPacket, srcIP net.IP) {
|
||||
currentState := inst.state.GetState()
|
||||
localPriority := inst.priorityCalc.GetPriority()
|
||||
|
||||
inst.log.Debug("[%s] received advertisement from %s (priority=%d, state=%s)",
|
||||
inst.Name, srcIP, pkt.Priority, currentState)
|
||||
|
||||
switch currentState {
|
||||
case StateBackup:
|
||||
if pkt.Priority == 0 {
|
||||
inst.masterDownTimer.SetDuration(CalculateSkewTime(localPriority))
|
||||
inst.masterDownTimer.Reset()
|
||||
} else if !ShouldBecomeMaster(localPriority, pkt.Priority, inst.socket.localIP.String(), srcIP.String()) {
|
||||
inst.masterDownTimer.Reset()
|
||||
}
|
||||
|
||||
case StateMaster:
|
||||
if ShouldBecomeMaster(pkt.Priority, localPriority, srcIP.String(), inst.socket.localIP.String()) {
|
||||
inst.log.Warn("[%s] received higher priority advertisement, stepping down", inst.Name)
|
||||
inst.state.SetState(StateBackup)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (inst *Instance) onAdvertTimer() {
|
||||
if inst.state.GetState() == StateMaster {
|
||||
inst.sendAdvertisement()
|
||||
inst.advertTimer.Start()
|
||||
}
|
||||
}
|
||||
|
||||
func (inst *Instance) onMasterDownTimer() {
|
||||
if inst.state.GetState() == StateBackup {
|
||||
inst.log.Info("[%s] master down timer expired, becoming master", inst.Name)
|
||||
inst.state.SetState(StateMaster)
|
||||
}
|
||||
}
|
||||
|
||||
func (inst *Instance) sendAdvertisement() error {
|
||||
priority := inst.priorityCalc.GetPriority()
|
||||
|
||||
pkt := NewAdvertisement(
|
||||
inst.VirtualRouterID,
|
||||
priority,
|
||||
inst.AdvertInterval,
|
||||
inst.VirtualIPs,
|
||||
inst.AuthType,
|
||||
inst.AuthPass,
|
||||
)
|
||||
|
||||
if err := inst.socket.Send(pkt); err != nil {
|
||||
inst.log.Error("[%s] failed to send advertisement: %v", inst.Name, err)
|
||||
return err
|
||||
}
|
||||
|
||||
inst.log.Debug("[%s] sent advertisement (priority=%d)", inst.Name, priority)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (inst *Instance) handleStateChange(old, new State) {
|
||||
switch new {
|
||||
case StateMaster:
|
||||
inst.becomeMaster()
|
||||
case StateBackup:
|
||||
inst.becomeBackup(old)
|
||||
case StateFault:
|
||||
inst.becomeFault()
|
||||
}
|
||||
}
|
||||
|
||||
func (inst *Instance) becomeMaster() {
|
||||
inst.log.Info("[%s] transitioning to MASTER state", inst.Name)
|
||||
|
||||
if err := inst.addVIPs(); err != nil {
|
||||
inst.log.Error("[%s] failed to add VIPs: %v", inst.Name, err)
|
||||
inst.state.SetState(StateFault)
|
||||
return
|
||||
}
|
||||
|
||||
if err := inst.arpSender.SendGratuitousARPForIPs(inst.VirtualIPs); err != nil {
|
||||
inst.log.Error("[%s] failed to send gratuitous ARP: %v", inst.Name, err)
|
||||
}
|
||||
|
||||
inst.masterDownTimer.Stop()
|
||||
inst.advertTimer.Start()
|
||||
|
||||
inst.sendAdvertisement()
|
||||
|
||||
if inst.onMaster != nil {
|
||||
inst.onMaster()
|
||||
}
|
||||
}
|
||||
|
||||
func (inst *Instance) becomeBackup(oldState State) {
|
||||
inst.log.Info("[%s] transitioning to BACKUP state", inst.Name)
|
||||
|
||||
inst.advertTimer.Stop()
|
||||
|
||||
if oldState == StateMaster {
|
||||
if err := inst.removeVIPs(); err != nil {
|
||||
inst.log.Error("[%s] failed to remove VIPs: %v", inst.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
inst.masterDownTimer.Reset()
|
||||
|
||||
if inst.onBackup != nil {
|
||||
inst.onBackup()
|
||||
}
|
||||
}
|
||||
|
||||
func (inst *Instance) becomeFault() {
|
||||
inst.log.Error("[%s] transitioning to FAULT state", inst.Name)
|
||||
|
||||
inst.advertTimer.Stop()
|
||||
inst.masterDownTimer.Stop()
|
||||
|
||||
if err := inst.removeVIPs(); err != nil {
|
||||
inst.log.Error("[%s] failed to remove VIPs: %v", inst.Name, err)
|
||||
}
|
||||
|
||||
if inst.onFault != nil {
|
||||
inst.onFault()
|
||||
}
|
||||
}
|
||||
|
||||
func (inst *Instance) addVIPs() error {
|
||||
inst.log.Info("[%s] adding virtual IPs", inst.Name)
|
||||
|
||||
for _, vipStr := range inst.getVIPsWithCIDR() {
|
||||
if err := inst.netInterface.AddIP(vipStr); err != nil {
|
||||
inst.log.Error("[%s] failed to add VIP %s: %v", inst.Name, vipStr, err)
|
||||
return err
|
||||
}
|
||||
inst.log.Info("[%s] added VIP %s", inst.Name, vipStr)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (inst *Instance) removeVIPs() error {
|
||||
inst.log.Info("[%s] removing virtual IPs", inst.Name)
|
||||
|
||||
for _, vipStr := range inst.getVIPsWithCIDR() {
|
||||
has, _ := inst.netInterface.HasIP(vipStr)
|
||||
if !has {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := inst.netInterface.DeleteIP(vipStr); err != nil {
|
||||
inst.log.Error("[%s] failed to remove VIP %s: %v", inst.Name, vipStr, err)
|
||||
return err
|
||||
}
|
||||
inst.log.Info("[%s] removed VIP %s", inst.Name, vipStr)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (inst *Instance) getVIPsWithCIDR() []string {
|
||||
result := make([]string, len(inst.VirtualIPs))
|
||||
for i, ip := range inst.VirtualIPs {
|
||||
result[i] = ip.String() + "/32"
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (inst *Instance) GetState() State {
|
||||
return inst.state.GetState()
|
||||
}
|
||||
|
||||
func (inst *Instance) OnMaster(callback func()) {
|
||||
inst.onMaster = callback
|
||||
}
|
||||
|
||||
func (inst *Instance) OnBackup(callback func()) {
|
||||
inst.onBackup = callback
|
||||
}
|
||||
|
||||
func (inst *Instance) OnFault(callback func()) {
|
||||
inst.onFault = callback
|
||||
}
|
||||
|
||||
func (inst *Instance) AdjustPriority(delta int) {
|
||||
inst.mu.Lock()
|
||||
defer inst.mu.Unlock()
|
||||
|
||||
oldPriority := inst.priorityCalc.GetPriority()
|
||||
|
||||
if delta < 0 {
|
||||
inst.priorityCalc.DecreasePriority(uint8(-delta))
|
||||
}
|
||||
|
||||
newPriority := inst.priorityCalc.GetPriority()
|
||||
|
||||
if oldPriority != newPriority {
|
||||
inst.log.Info("[%s] priority adjusted: %d -> %d (delta=%d)",
|
||||
inst.Name, oldPriority, newPriority, delta)
|
||||
}
|
||||
}
|
||||
|
||||
func (inst *Instance) ResetPriority() {
|
||||
inst.mu.Lock()
|
||||
defer inst.mu.Unlock()
|
||||
|
||||
oldPriority := inst.priorityCalc.GetPriority()
|
||||
inst.priorityCalc.ResetPriority()
|
||||
newPriority := inst.priorityCalc.GetPriority()
|
||||
|
||||
if oldPriority != newPriority {
|
||||
inst.log.Info("[%s] priority reset: %d -> %d",
|
||||
inst.Name, oldPriority, newPriority)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user