diff --git a/go.mod b/go.mod index 23af7d4..f12d66d 100644 --- a/go.mod +++ b/go.mod @@ -1,20 +1,23 @@ module github.com/loveuer/go-alived -go 1.25.0 +go 1.24.0 + +require ( + github.com/mdlayher/arp v0.0.0-20220512170110-6706a2966875 + github.com/spf13/cobra v1.10.2 + github.com/vishvananda/netlink v1.3.1 + golang.org/x/net v0.47.0 + gopkg.in/yaml.v3 v3.0.1 +) require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/josharian/native v1.0.0 // indirect - github.com/mdlayher/arp v0.0.0-20220512170110-6706a2966875 // indirect github.com/mdlayher/ethernet v0.0.0-20220221185849-529eae5b6118 // indirect github.com/mdlayher/packet v1.0.0 // indirect github.com/mdlayher/socket v0.2.1 // indirect - github.com/spf13/cobra v1.10.2 // indirect github.com/spf13/pflag v1.0.9 // indirect - github.com/vishvananda/netlink v1.3.1 // indirect github.com/vishvananda/netns v0.0.5 // indirect - golang.org/x/net v0.47.0 // indirect golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect golang.org/x/sys v0.38.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 5730267..4ef65df 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,6 @@ github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.7 h1:81/ik6ipDQS2aGcBfIN5dHDB36BwrStyeAQquSYCV4o= github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= @@ -25,7 +26,6 @@ github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZla go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190603091049-60506f45cf65 h1:+rhAzEzT3f4JtomfC371qB+0Ola2caSKcY69NUBZrRQ= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= @@ -35,12 +35,12 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/cmd/test.go b/internal/cmd/test.go index 0aecdba..e8f3a92 100644 --- a/internal/cmd/test.go +++ b/internal/cmd/test.go @@ -367,7 +367,7 @@ func (t *EnvironmentTest) TestCloudEnvironment() { if err == nil { cloudDetected = true t.AddResult("云环境", !test.isFatal, fmt.Sprintf("检测到%s环境", test.name), test.isFatal) - t.log.Warn(test.solution) + t.log.Warn("%s", test.solution) } } diff --git a/internal/health/factory.go b/internal/health/factory.go index d7b69ac..c6177c3 100644 --- a/internal/health/factory.go +++ b/internal/health/factory.go @@ -8,9 +8,13 @@ import ( ) func CreateChecker(cfg *config.HealthChecker) (Checker, error) { + if cfg.Config == nil { + return nil, fmt.Errorf("missing config for checker %s", cfg.Name) + } + configMap, ok := cfg.Config.(map[string]interface{}) if !ok { - return nil, fmt.Errorf("invalid config for checker %s", cfg.Name) + return nil, fmt.Errorf("invalid config type for checker %s: expected map[string]interface{}", cfg.Name) } switch cfg.Type { @@ -36,6 +40,7 @@ func LoadFromConfig(cfg *config.Config, log *logger.Logger) (*Manager, error) { return nil, fmt.Errorf("failed to create checker %s: %w", healthCfg.Name, err) } + configMap, _ := healthCfg.Config.(map[string]interface{}) monitorCfg := &CheckerConfig{ Name: healthCfg.Name, Type: healthCfg.Type, @@ -43,7 +48,7 @@ func LoadFromConfig(cfg *config.Config, log *logger.Logger) (*Manager, error) { Timeout: healthCfg.Timeout, Rise: healthCfg.Rise, Fall: healthCfg.Fall, - Config: healthCfg.Config.(map[string]interface{}), + Config: configMap, } monitor := NewMonitor(checker, monitorCfg, log) diff --git a/internal/health/manager.go b/internal/health/manager.go new file mode 100644 index 0000000..be32cb4 --- /dev/null +++ b/internal/health/manager.go @@ -0,0 +1,74 @@ +package health + +import ( + "sync" + + "github.com/loveuer/go-alived/pkg/logger" +) + +// Manager manages multiple health check monitors. +type Manager struct { + monitors map[string]*Monitor + mu sync.RWMutex + log *logger.Logger +} + +// NewManager creates a new health check Manager. +func NewManager(log *logger.Logger) *Manager { + return &Manager{ + monitors: make(map[string]*Monitor), + log: log, + } +} + +// AddMonitor adds a monitor to the manager. +func (m *Manager) AddMonitor(monitor *Monitor) { + m.mu.Lock() + defer m.mu.Unlock() + m.monitors[monitor.config.Name] = monitor +} + +// GetMonitor retrieves a monitor by name. +func (m *Manager) GetMonitor(name string) (*Monitor, bool) { + m.mu.RLock() + defer m.mu.RUnlock() + monitor, ok := m.monitors[name] + return monitor, ok +} + +// StartAll starts all registered monitors. +func (m *Manager) StartAll() { + m.mu.RLock() + defer m.mu.RUnlock() + + for _, monitor := range m.monitors { + monitor.Start() + } + + m.log.Info("started %d health check monitor(s)", len(m.monitors)) +} + +// StopAll stops all registered monitors. +func (m *Manager) StopAll() { + m.mu.RLock() + defer m.mu.RUnlock() + + for _, monitor := range m.monitors { + monitor.Stop() + } + + m.log.Info("stopped all health check monitors") +} + +// GetAllStates returns the current state of all monitors. +func (m *Manager) GetAllStates() map[string]*CheckerState { + m.mu.RLock() + defer m.mu.RUnlock() + + states := make(map[string]*CheckerState) + for name, monitor := range m.monitors { + states[name] = monitor.GetState() + } + + return states +} diff --git a/internal/health/monitor.go b/internal/health/monitor.go index 331ddca..ff4f857 100644 --- a/internal/health/monitor.go +++ b/internal/health/monitor.go @@ -8,6 +8,7 @@ import ( "github.com/loveuer/go-alived/pkg/logger" ) +// Monitor runs periodic health checks and tracks state. type Monitor struct { checker Checker config *CheckerConfig @@ -21,6 +22,7 @@ type Monitor struct { mu sync.RWMutex } +// NewMonitor creates a new Monitor for the given checker. func NewMonitor(checker Checker, config *CheckerConfig, log *logger.Logger) *Monitor { return &Monitor{ checker: checker, @@ -35,6 +37,7 @@ func NewMonitor(checker Checker, config *CheckerConfig, log *logger.Logger) *Mon } } +// Start begins the health check loop. func (m *Monitor) Start() { m.mu.Lock() if m.running { @@ -51,6 +54,7 @@ func (m *Monitor) Start() { go m.checkLoop() } +// Stop stops the health check loop. func (m *Monitor) Stop() { m.mu.Lock() if !m.running { @@ -71,6 +75,7 @@ func (m *Monitor) checkLoop() { ticker := time.NewTicker(m.config.Interval) defer ticker.Stop() + // Perform initial check immediately m.performCheck() for { @@ -95,7 +100,8 @@ func (m *Monitor) performCheck() { oldHealthy := m.state.Healthy stateChanged := m.state.Update(result, m.config.Rise, m.config.Fall) newHealthy := m.state.Healthy - callbacks := m.callbacks + callbacks := make([]StateChangeCallback, len(m.callbacks)) + copy(callbacks, m.callbacks) m.mu.Unlock() m.log.Debug("[HealthCheck:%s] check completed: result=%s, duration=%s, healthy=%v", @@ -111,12 +117,14 @@ func (m *Monitor) performCheck() { } } +// OnStateChange registers a callback for health state changes. func (m *Monitor) OnStateChange(callback StateChangeCallback) { m.mu.Lock() defer m.mu.Unlock() m.callbacks = append(m.callbacks, callback) } +// GetState returns a copy of the current checker state. func (m *Monitor) GetState() *CheckerState { m.mu.RLock() defer m.mu.RUnlock() @@ -125,68 +133,9 @@ func (m *Monitor) GetState() *CheckerState { return &stateCopy } +// IsHealthy returns whether the checker is currently healthy. func (m *Monitor) IsHealthy() bool { m.mu.RLock() defer m.mu.RUnlock() return m.state.Healthy } - -type Manager struct { - monitors map[string]*Monitor - mu sync.RWMutex - log *logger.Logger -} - -func NewManager(log *logger.Logger) *Manager { - return &Manager{ - monitors: make(map[string]*Monitor), - log: log, - } -} - -func (m *Manager) AddMonitor(monitor *Monitor) { - m.mu.Lock() - defer m.mu.Unlock() - m.monitors[monitor.config.Name] = monitor -} - -func (m *Manager) GetMonitor(name string) (*Monitor, bool) { - m.mu.RLock() - defer m.mu.RUnlock() - monitor, ok := m.monitors[name] - return monitor, ok -} - -func (m *Manager) StartAll() { - m.mu.RLock() - defer m.mu.RUnlock() - - for _, monitor := range m.monitors { - monitor.Start() - } - - m.log.Info("started %d health check monitor(s)", len(m.monitors)) -} - -func (m *Manager) StopAll() { - m.mu.RLock() - defer m.mu.RUnlock() - - for _, monitor := range m.monitors { - monitor.Stop() - } - - m.log.Info("stopped all health check monitors") -} - -func (m *Manager) GetAllStates() map[string]*CheckerState { - m.mu.RLock() - defer m.mu.RUnlock() - - states := make(map[string]*CheckerState) - for name, monitor := range m.monitors { - states[name] = monitor.GetState() - } - - return states -} diff --git a/internal/vrrp/history.go b/internal/vrrp/history.go new file mode 100644 index 0000000..2b81545 --- /dev/null +++ b/internal/vrrp/history.go @@ -0,0 +1,95 @@ +package vrrp + +import ( + "fmt" + "sync" + "time" +) + +// StateTransition represents a single state transition event. +type StateTransition struct { + From State + To State + Timestamp time.Time + Reason string +} + +// StateHistory maintains a bounded history of state transitions. +type StateHistory struct { + transitions []StateTransition + maxSize int + mu sync.RWMutex +} + +// NewStateHistory creates a new StateHistory with the specified maximum size. +func NewStateHistory(maxSize int) *StateHistory { + return &StateHistory{ + transitions: make([]StateTransition, 0, maxSize), + maxSize: maxSize, + } +} + +// Add records a new state transition. +func (sh *StateHistory) Add(from, to State, reason string) { + sh.mu.Lock() + defer sh.mu.Unlock() + + transition := StateTransition{ + From: from, + To: to, + Timestamp: time.Now(), + Reason: reason, + } + + sh.transitions = append(sh.transitions, transition) + + // Maintain bounded size using ring buffer style + if len(sh.transitions) > sh.maxSize { + // Copy to new slice to allow garbage collection of old backing array + newTransitions := make([]StateTransition, len(sh.transitions)-1, sh.maxSize) + copy(newTransitions, sh.transitions[1:]) + sh.transitions = newTransitions + } +} + +// GetRecent returns the most recent n transitions. +func (sh *StateHistory) GetRecent(n int) []StateTransition { + sh.mu.RLock() + defer sh.mu.RUnlock() + + if n > len(sh.transitions) { + n = len(sh.transitions) + } + + start := len(sh.transitions) - n + result := make([]StateTransition, n) + copy(result, sh.transitions[start:]) + + return result +} + +// Len returns the number of recorded transitions. +func (sh *StateHistory) Len() int { + sh.mu.RLock() + defer sh.mu.RUnlock() + return len(sh.transitions) +} + +// String returns a formatted string representation of the history. +func (sh *StateHistory) String() string { + sh.mu.RLock() + defer sh.mu.RUnlock() + + if len(sh.transitions) == 0 { + return "No state transitions" + } + + result := "State transition history:\n" + for _, t := range sh.transitions { + result += fmt.Sprintf(" %s: %s -> %s (%s)\n", + t.Timestamp.Format("2006-01-02 15:04:05"), + t.From, t.To, t.Reason) + } + + return result +} diff --git a/internal/vrrp/instance.go b/internal/vrrp/instance.go index eb90a4f..3202406 100644 --- a/internal/vrrp/instance.go +++ b/internal/vrrp/instance.go @@ -17,6 +17,7 @@ type Instance struct { AdvertInterval uint8 Interface string VirtualIPs []net.IP + VirtualIPCIDRs []string // preserve original CIDR notation AuthType uint8 AuthPass string TrackScripts []string @@ -64,12 +65,14 @@ func NewInstance( } virtualIPs := make([]net.IP, 0, len(vips)) + virtualIPCIDRs := make([]string, 0, len(vips)) for _, vip := range vips { ip, _, err := net.ParseCIDR(vip) if err != nil { return nil, fmt.Errorf("invalid VIP %s: %w", vip, err) } virtualIPs = append(virtualIPs, ip) + virtualIPCIDRs = append(virtualIPCIDRs, vip) } var authTypeNum uint8 @@ -94,6 +97,7 @@ func NewInstance( AdvertInterval: advertInt, Interface: iface, VirtualIPs: virtualIPs, + VirtualIPCIDRs: virtualIPCIDRs, AuthType: authTypeNum, AuthPass: authPass, TrackScripts: trackScripts, @@ -192,8 +196,15 @@ func (inst *Instance) receiveLoop() { default: } + // Set read deadline to allow periodic check of stop channel + inst.socket.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + pkt, srcIP, err := inst.socket.Receive() if err != nil { + // Check if it's a timeout error, which is expected + if netErr, ok := err.(interface{ Timeout() bool }); ok && netErr.Timeout() { + continue + } inst.log.Debug("[%s] failed to receive packet: %v", inst.Name, err) continue } @@ -371,11 +382,7 @@ func (inst *Instance) removeVIPs() error { } func (inst *Instance) getVIPsWithCIDR() []string { - result := make([]string, len(inst.VirtualIPs)) - for i, ip := range inst.VirtualIPs { - result[i] = ip.String() + "/32" - } - return result + return inst.VirtualIPCIDRs } func (inst *Instance) GetState() State { @@ -399,15 +406,17 @@ func (inst *Instance) AdjustPriority(delta int) { defer inst.mu.Unlock() oldPriority := inst.priorityCalc.GetPriority() - + if delta < 0 { inst.priorityCalc.DecreasePriority(uint8(-delta)) + } else if delta > 0 { + inst.priorityCalc.IncreasePriority(uint8(delta)) } - + newPriority := inst.priorityCalc.GetPriority() - + if oldPriority != newPriority { - inst.log.Info("[%s] priority adjusted: %d -> %d (delta=%d)", + inst.log.Info("[%s] priority adjusted: %d -> %d (delta=%d)", inst.Name, oldPriority, newPriority, delta) } } diff --git a/internal/vrrp/priority.go b/internal/vrrp/priority.go new file mode 100644 index 0000000..368bf05 --- /dev/null +++ b/internal/vrrp/priority.go @@ -0,0 +1,99 @@ +package vrrp + +import ( + "sync" + "time" +) + +// PriorityCalculator manages VRRP priority with support for dynamic adjustment. +type PriorityCalculator struct { + basePriority uint8 + currentPriority uint8 + mu sync.RWMutex +} + +// NewPriorityCalculator creates a new PriorityCalculator with the specified base priority. +func NewPriorityCalculator(basePriority uint8) *PriorityCalculator { + return &PriorityCalculator{ + basePriority: basePriority, + currentPriority: basePriority, + } +} + +// GetPriority returns the current priority. +func (pc *PriorityCalculator) GetPriority() uint8 { + pc.mu.RLock() + defer pc.mu.RUnlock() + return pc.currentPriority +} + +// DecreasePriority decreases the current priority by the specified amount. +// The priority will not go below 0. +func (pc *PriorityCalculator) DecreasePriority(amount uint8) { + pc.mu.Lock() + defer pc.mu.Unlock() + + if pc.currentPriority > amount { + pc.currentPriority -= amount + } else { + pc.currentPriority = 0 + } +} + +// IncreasePriority increases the current priority by the specified amount. +// The priority will not exceed 255 or the base priority. +func (pc *PriorityCalculator) IncreasePriority(amount uint8) { + pc.mu.Lock() + defer pc.mu.Unlock() + + newPriority := pc.currentPriority + amount + if newPriority > pc.basePriority { + newPriority = pc.basePriority + } + if newPriority < pc.currentPriority { // overflow check + newPriority = pc.basePriority + } + pc.currentPriority = newPriority +} + +// ResetPriority resets the priority to the base value. +func (pc *PriorityCalculator) ResetPriority() { + pc.mu.Lock() + defer pc.mu.Unlock() + pc.currentPriority = pc.basePriority +} + +// SetBasePriority sets a new base priority and resets current priority to match. +func (pc *PriorityCalculator) SetBasePriority(priority uint8) { + pc.mu.Lock() + defer pc.mu.Unlock() + pc.basePriority = priority + pc.currentPriority = priority +} + +// ShouldBecomeMaster determines if the local node should become master +// based on priority comparison and IP address tie-breaking. +func ShouldBecomeMaster(localPriority, remotePriority uint8, localIP, remoteIP string) bool { + if localPriority > remotePriority { + return true + } + + if localPriority == remotePriority { + return localIP > remoteIP + } + + return false +} + +// CalculateMasterDownInterval calculates the master down interval +// according to VRRP specification: (3 * Advertisement_Interval). +func CalculateMasterDownInterval(advertInt uint8) time.Duration { + return time.Duration(3*int(advertInt)) * time.Second +} + +// CalculateSkewTime calculates the skew time for master down timer +// according to VRRP specification: ((256 - Priority) / 256). +func CalculateSkewTime(priority uint8) time.Duration { + skew := float64(256-int(priority)) / 256.0 + return time.Duration(skew * float64(time.Second)) +} diff --git a/internal/vrrp/socket.go b/internal/vrrp/socket.go index 39b1937..b1d9dd7 100644 --- a/internal/vrrp/socket.go +++ b/internal/vrrp/socket.go @@ -5,6 +5,7 @@ import ( "net" "os" "syscall" + "time" "golang.org/x/net/ipv4" ) @@ -14,10 +15,11 @@ const ( ) type Socket struct { - conn *ipv4.RawConn - iface *net.Interface - localIP net.IP - groupIP net.IP + conn *ipv4.RawConn + packetConn net.PacketConn + iface *net.Interface + localIP net.IP + groupIP net.IP } func NewSocket(ifaceName string) (*Socket, error) { @@ -56,9 +58,8 @@ func NewSocket(ifaceName string) (*Socket, error) { } file := os.NewFile(uintptr(fd), "vrrp-socket") - defer file.Close() - packetConn, err := net.FilePacketConn(file) + file.Close() if err != nil { return nil, fmt.Errorf("failed to create packet connection: %w", err) } @@ -86,10 +87,11 @@ func NewSocket(ifaceName string) (*Socket, error) { } return &Socket{ - conn: rawConn, - iface: iface, - localIP: localIP, - groupIP: groupIP, + conn: rawConn, + packetConn: packetConn, + iface: iface, + localIP: localIP, + groupIP: groupIP, }, nil } @@ -133,6 +135,10 @@ func (s *Socket) Receive() (*VRRPPacket, net.IP, error) { return pkt, header.Src, nil } +func (s *Socket) SetReadDeadline(t time.Time) error { + return s.packetConn.SetReadDeadline(t) +} + func (s *Socket) Close() error { if err := s.conn.LeaveGroup(s.iface, &net.IPAddr{IP: s.groupIP}); err != nil { return err diff --git a/internal/vrrp/state.go b/internal/vrrp/state.go index efec825..ad1f1ba 100644 --- a/internal/vrrp/state.go +++ b/internal/vrrp/state.go @@ -1,11 +1,8 @@ package vrrp -import ( - "fmt" - "sync" - "time" -) +import "sync" +// State represents the VRRP instance state. type State int const ( @@ -15,6 +12,7 @@ const ( StateFault ) +// String returns the string representation of the state. func (s State) String() string { switch s { case StateInit: @@ -30,33 +28,39 @@ func (s State) String() string { } } +// StateMachine manages VRRP state transitions with thread-safe callbacks. type StateMachine struct { - currentState State - previousState State - mu sync.RWMutex + currentState State + mu sync.RWMutex stateChangeCallbacks []func(old, new State) } +// NewStateMachine creates a new StateMachine with the specified initial state. func NewStateMachine(initialState State) *StateMachine { return &StateMachine{ - currentState: initialState, - previousState: StateInit, + currentState: initialState, stateChangeCallbacks: make([]func(old, new State), 0), } } +// GetState returns the current state. func (sm *StateMachine) GetState() State { sm.mu.RLock() defer sm.mu.RUnlock() return sm.currentState } +// SetState transitions to a new state and triggers registered callbacks. func (sm *StateMachine) SetState(newState State) { sm.mu.Lock() oldState := sm.currentState - sm.previousState = oldState + if oldState == newState { + sm.mu.Unlock() + return + } sm.currentState = newState - callbacks := sm.stateChangeCallbacks + callbacks := make([]func(old, new State), len(sm.stateChangeCallbacks)) + copy(callbacks, sm.stateChangeCallbacks) sm.mu.Unlock() for _, callback := range callbacks { @@ -64,195 +68,9 @@ func (sm *StateMachine) SetState(newState State) { } } +// OnStateChange registers a callback to be invoked on state changes. func (sm *StateMachine) OnStateChange(callback func(old, new State)) { sm.mu.Lock() defer sm.mu.Unlock() sm.stateChangeCallbacks = append(sm.stateChangeCallbacks, callback) } - -type Timer struct { - duration time.Duration - timer *time.Timer - callback func() - mu sync.Mutex -} - -func NewTimer(duration time.Duration, callback func()) *Timer { - return &Timer{ - duration: duration, - callback: callback, - } -} - -func (t *Timer) Start() { - t.mu.Lock() - defer t.mu.Unlock() - - if t.timer != nil { - t.timer.Stop() - } - - t.timer = time.AfterFunc(t.duration, t.callback) -} - -func (t *Timer) Stop() { - t.mu.Lock() - defer t.mu.Unlock() - - if t.timer != nil { - t.timer.Stop() - t.timer = nil - } -} - -func (t *Timer) Reset() { - t.mu.Lock() - defer t.mu.Unlock() - - if t.timer != nil { - t.timer.Stop() - } - - t.timer = time.AfterFunc(t.duration, t.callback) -} - -func (t *Timer) SetDuration(duration time.Duration) { - t.mu.Lock() - defer t.mu.Unlock() - t.duration = duration -} - -type PriorityCalculator struct { - basePriority uint8 - currentPriority uint8 - mu sync.RWMutex -} - -func NewPriorityCalculator(basePriority uint8) *PriorityCalculator { - return &PriorityCalculator{ - basePriority: basePriority, - currentPriority: basePriority, - } -} - -func (pc *PriorityCalculator) GetPriority() uint8 { - pc.mu.RLock() - defer pc.mu.RUnlock() - return pc.currentPriority -} - -func (pc *PriorityCalculator) DecreasePriority(amount uint8) { - pc.mu.Lock() - defer pc.mu.Unlock() - - if pc.currentPriority > amount { - pc.currentPriority -= amount - } else { - pc.currentPriority = 0 - } -} - -func (pc *PriorityCalculator) ResetPriority() { - pc.mu.Lock() - defer pc.mu.Unlock() - pc.currentPriority = pc.basePriority -} - -func (pc *PriorityCalculator) SetBasePriority(priority uint8) { - pc.mu.Lock() - defer pc.mu.Unlock() - pc.basePriority = priority - pc.currentPriority = priority -} - -func ShouldBecomeMaster(localPriority, remotePriority uint8, localIP, remoteIP string) bool { - if localPriority > remotePriority { - return true - } - - if localPriority == remotePriority { - return localIP > remoteIP - } - - return false -} - -func CalculateMasterDownInterval(advertInt uint8) time.Duration { - return time.Duration(3*int(advertInt)) * time.Second -} - -func CalculateSkewTime(priority uint8) time.Duration { - skew := float64(256-int(priority)) / 256.0 - return time.Duration(skew * float64(time.Second)) -} - -type StateTransition struct { - From State - To State - Timestamp time.Time - Reason string -} - -type StateHistory struct { - transitions []StateTransition - maxSize int - mu sync.RWMutex -} - -func NewStateHistory(maxSize int) *StateHistory { - return &StateHistory{ - transitions: make([]StateTransition, 0, maxSize), - maxSize: maxSize, - } -} - -func (sh *StateHistory) Add(from, to State, reason string) { - sh.mu.Lock() - defer sh.mu.Unlock() - - transition := StateTransition{ - From: from, - To: to, - Timestamp: time.Now(), - Reason: reason, - } - - sh.transitions = append(sh.transitions, transition) - - if len(sh.transitions) > sh.maxSize { - sh.transitions = sh.transitions[1:] - } -} - -func (sh *StateHistory) GetRecent(n int) []StateTransition { - sh.mu.RLock() - defer sh.mu.RUnlock() - - if n > len(sh.transitions) { - n = len(sh.transitions) - } - - start := len(sh.transitions) - n - result := make([]StateTransition, n) - copy(result, sh.transitions[start:]) - - return result -} - -func (sh *StateHistory) String() string { - sh.mu.RLock() - defer sh.mu.RUnlock() - - if len(sh.transitions) == 0 { - return "No state transitions" - } - - result := "State transition history:\n" - for _, t := range sh.transitions { - result += fmt.Sprintf(" %s: %s -> %s (%s)\n", - t.Timestamp.Format("2006-01-02 15:04:05"), - t.From, t.To, t.Reason) - } - - return result -} \ No newline at end of file diff --git a/internal/vrrp/timer.go b/internal/vrrp/timer.go new file mode 100644 index 0000000..3ecef07 --- /dev/null +++ b/internal/vrrp/timer.go @@ -0,0 +1,64 @@ +package vrrp + +import ( + "sync" + "time" +) + +// Timer provides a thread-safe timer with callback support. +type Timer struct { + duration time.Duration + timer *time.Timer + callback func() + mu sync.Mutex +} + +// NewTimer creates a new Timer with the specified duration and callback. +func NewTimer(duration time.Duration, callback func()) *Timer { + return &Timer{ + duration: duration, + callback: callback, + } +} + +// Start starts or restarts the timer. +func (t *Timer) Start() { + t.mu.Lock() + defer t.mu.Unlock() + + if t.timer != nil { + t.timer.Stop() + } + + t.timer = time.AfterFunc(t.duration, t.callback) +} + +// Stop stops the timer if it's running. +func (t *Timer) Stop() { + t.mu.Lock() + defer t.mu.Unlock() + + if t.timer != nil { + t.timer.Stop() + t.timer = nil + } +} + +// Reset stops the current timer and starts a new one with the same duration. +func (t *Timer) Reset() { + t.mu.Lock() + defer t.mu.Unlock() + + if t.timer != nil { + t.timer.Stop() + } + + t.timer = time.AfterFunc(t.duration, t.callback) +} + +// SetDuration updates the timer's duration for future starts. +func (t *Timer) SetDuration(duration time.Duration) { + t.mu.Lock() + defer t.mu.Unlock() + t.duration = duration +}