feat: implement suspect node handling with callbacks for DNS record management

anonpenguin23 committed 2026-02-20 09:27:35 +02:00
parent 2b0bfaaa12
commit 4ebf558719
9 changed files with 372 additions and 9 deletions

View File

@@ -40,7 +40,7 @@ func parseConfig(c *caddy.Controller) (*RQLitePlugin, error) {
 	var (
 		dsn         = "http://localhost:5001"
 		refreshRate = 10 * time.Second
-		cacheTTL    = 60 * time.Second
+		cacheTTL    = 30 * time.Second
 		cacheSize   = 10000
 		zones       []string
 	)

View File

@@ -333,7 +333,7 @@ func (ci *CoreDNSInstaller) generateCorefile(domain, rqliteDSN string) string {
 	rqlite {
 		dsn %s
 		refresh 5s
-		ttl 60
+		ttl 30
 		cache_size 10000
 	}
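Note: the parser default above and this generated Corefile now agree on a 30-second TTL, halving how long a freshly disabled record can keep serving from cache. For readers unfamiliar with the template, a minimal sketch of the rendered stanza, assuming the %s placeholder is filled via fmt (the example DSN is illustrative, not from the commit):

package main

import "fmt"

// Sketch only: reproduces the rqlite stanza from the template above.
// The real generateCorefile emits a full Corefile around it.
func main() {
	stanza := `rqlite {
	dsn %s
	refresh 5s
	ttl 30
	cache_size 10000
}`
	fmt.Printf(stanza+"\n", "http://localhost:5001")
}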

View File

@@ -580,12 +580,20 @@ func New(logger *logging.ColoredLogger, cfg *Config) (*Gateway, error) {
 		}
 	})
 	gw.healthMonitor.OnNodeRecovered(func(nodeID string) {
-		logger.ComponentInfo(logging.ComponentGeneral, "Previously dead node recovered — checking for orphaned services",
+		logger.ComponentInfo(logging.ComponentGeneral, "Node recovered — re-enabling DNS and checking for orphaned services",
 			zap.String("node_id", nodeID))
 		if gw.nodeRecoverer != nil {
+			go gw.nodeRecoverer.HandleSuspectRecovery(context.Background(), nodeID)
 			go gw.nodeRecoverer.HandleRecoveredNode(context.Background(), nodeID)
 		}
 	})
+	gw.healthMonitor.OnNodeSuspect(func(nodeID string) {
+		logger.ComponentWarn(logging.ComponentGeneral, "Node SUSPECT — disabling DNS records",
+			zap.String("suspect_node", nodeID))
+		if gw.nodeRecoverer != nil {
+			go gw.nodeRecoverer.HandleSuspectNode(context.Background(), nodeID)
+		}
+	})
 	go gw.healthMonitor.Start(context.Background())
 	logger.ComponentInfo(logging.ComponentGeneral, "Node health monitor started",
 		zap.String("node_id", cfg.NodePeerID))

View File

@@ -53,6 +53,8 @@ type ClusterProvisioner interface {
 type NodeRecoverer interface {
 	HandleDeadNode(ctx context.Context, deadNodeID string)
 	HandleRecoveredNode(ctx context.Context, nodeID string)
+	HandleSuspectNode(ctx context.Context, suspectNodeID string)
+	HandleSuspectRecovery(ctx context.Context, nodeID string)
 	RepairCluster(ctx context.Context, namespaceName string) error
 }
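Note: the two new methods widen the interface, so every existing implementation and test double must add them to keep compiling. A minimal no-op stub with a compile-time check, as a sketch (the type name is hypothetical, and a context import is assumed):

// noopRecoverer is a hypothetical test double satisfying NodeRecoverer.
type noopRecoverer struct{}

func (noopRecoverer) HandleDeadNode(ctx context.Context, deadNodeID string)       {}
func (noopRecoverer) HandleRecoveredNode(ctx context.Context, nodeID string)      {}
func (noopRecoverer) HandleSuspectNode(ctx context.Context, suspectNodeID string) {}
func (noopRecoverer) HandleSuspectRecovery(ctx context.Context, nodeID string)    {}
func (noopRecoverer) RepairCluster(ctx context.Context, namespaceName string) error {
	return nil
}

var _ NodeRecoverer = noopRecoverer{} // compile-time interface check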

View File

@@ -191,6 +191,148 @@ func (cm *ClusterManager) HandleRecoveredNode(ctx context.Context, nodeID string
 		zap.Int("namespaces_cleaned", len(events)))
 }
+
+// HandleSuspectNode disables DNS records for a suspect node to prevent traffic
+// from being routed to it. Called early (T+30s) when the node first becomes suspect,
+// before confirming it's actually dead. If the node recovers, HandleSuspectRecovery
+// will re-enable the records.
+//
+// Safety: never disables the last active record for a namespace.
+func (cm *ClusterManager) HandleSuspectNode(ctx context.Context, suspectNodeID string) {
+	cm.logger.Warn("Handling suspect node — disabling DNS records",
+		zap.String("suspect_node", suspectNodeID),
+	)
+
+	// Acquire per-node lock to prevent concurrent suspect handling
+	suspectKey := "suspect:" + suspectNodeID
+	cm.provisioningMu.Lock()
+	if cm.provisioning[suspectKey] {
+		cm.provisioningMu.Unlock()
+		cm.logger.Info("Suspect handling already in progress for node, skipping",
+			zap.String("node_id", suspectNodeID))
+		return
+	}
+	cm.provisioning[suspectKey] = true
+	cm.provisioningMu.Unlock()
+	defer func() {
+		cm.provisioningMu.Lock()
+		delete(cm.provisioning, suspectKey)
+		cm.provisioningMu.Unlock()
+	}()
+
+	// Find all clusters this node belongs to
+	clusters, err := cm.getClustersByNodeID(ctx, suspectNodeID)
+	if err != nil {
+		cm.logger.Warn("Failed to find clusters for suspect node",
+			zap.String("suspect_node", suspectNodeID), zap.Error(err))
+		return
+	}
+	if len(clusters) == 0 {
+		cm.logger.Info("Suspect node has no namespace cluster assignments",
+			zap.String("suspect_node", suspectNodeID))
+		return
+	}
+
+	// Get suspect node's public IP (DNS A records contain public IPs)
+	ips, err := cm.getNodeIPs(ctx, suspectNodeID)
+	if err != nil {
+		cm.logger.Warn("Failed to get suspect node IPs",
+			zap.String("suspect_node", suspectNodeID), zap.Error(err))
+		return
+	}
+
+	dnsManager := NewDNSRecordManager(cm.db, cm.baseDomain, cm.logger)
+	disabledCount := 0
+	for _, cluster := range clusters {
+		// Safety check: never disable the last active record
+		activeCount, err := dnsManager.CountActiveNamespaceRecords(ctx, cluster.NamespaceName)
+		if err != nil {
+			cm.logger.Warn("Failed to count active DNS records, skipping namespace",
+				zap.String("namespace", cluster.NamespaceName),
+				zap.Error(err))
+			continue
+		}
+		if activeCount <= 1 {
+			cm.logger.Warn("Not disabling DNS — would leave namespace with no active records",
+				zap.String("namespace", cluster.NamespaceName),
+				zap.String("suspect_node", suspectNodeID),
+				zap.Int("active_records", activeCount))
+			continue
+		}
+		if err := dnsManager.DisableNamespaceRecord(ctx, cluster.NamespaceName, ips.IPAddress); err != nil {
+			cm.logger.Warn("Failed to disable DNS record for suspect node",
+				zap.String("namespace", cluster.NamespaceName),
+				zap.String("ip", ips.IPAddress),
+				zap.Error(err))
+			continue
+		}
+		disabledCount++
+		cm.logger.Info("Disabled DNS record for suspect node",
+			zap.String("namespace", cluster.NamespaceName),
+			zap.String("ip", ips.IPAddress))
+	}
+
+	cm.logger.Info("Suspect node DNS handling completed",
+		zap.String("suspect_node", suspectNodeID),
+		zap.Int("namespaces_affected", len(clusters)),
+		zap.Int("records_disabled", disabledCount))
+}
+
+// HandleSuspectRecovery re-enables DNS records for a node that recovered from
+// suspect state without going dead. Called when the health monitor detects
+// that a previously suspect node is responding to probes again.
+func (cm *ClusterManager) HandleSuspectRecovery(ctx context.Context, nodeID string) {
+	cm.logger.Info("Handling suspect recovery — re-enabling DNS records",
+		zap.String("node_id", nodeID),
+	)
+
+	// Find all clusters this node belongs to
+	clusters, err := cm.getClustersByNodeID(ctx, nodeID)
+	if err != nil {
+		cm.logger.Warn("Failed to find clusters for recovered node",
+			zap.String("node_id", nodeID), zap.Error(err))
+		return
+	}
+	if len(clusters) == 0 {
+		return
+	}
+
+	// Get node's public IP (DNS A records contain public IPs)
+	ips, err := cm.getNodeIPs(ctx, nodeID)
+	if err != nil {
+		cm.logger.Warn("Failed to get recovered node IPs",
+			zap.String("node_id", nodeID), zap.Error(err))
+		return
+	}
+
+	dnsManager := NewDNSRecordManager(cm.db, cm.baseDomain, cm.logger)
+	enabledCount := 0
+	for _, cluster := range clusters {
+		if err := dnsManager.EnableNamespaceRecord(ctx, cluster.NamespaceName, ips.IPAddress); err != nil {
+			cm.logger.Warn("Failed to re-enable DNS record for recovered node",
+				zap.String("namespace", cluster.NamespaceName),
+				zap.String("ip", ips.IPAddress),
+				zap.Error(err))
+			continue
+		}
+		enabledCount++
+		cm.logger.Info("Re-enabled DNS record for recovered node",
+			zap.String("namespace", cluster.NamespaceName),
+			zap.String("ip", ips.IPAddress))
+	}
+
+	cm.logger.Info("Suspect recovery DNS handling completed",
+		zap.String("node_id", nodeID),
+		zap.Int("records_enabled", enabledCount))
+}
+
 // ReplaceClusterNode replaces a dead node in a specific namespace cluster.
 // It selects a new node, allocates ports, spawns services, updates DNS, and cleans up.
 func (cm *ClusterManager) ReplaceClusterNode(ctx context.Context, cluster *NamespaceCluster, deadNodeID string) error {
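Note: both handlers lean on the provisioning map as a per-key, non-blocking in-flight guard. Extracted as a standalone sketch (the type and names are illustrative, not from the commit):

package cluster // hypothetical package name

import "sync"

// inflightGuard reproduces the pattern above: at most one goroutine may
// hold a given key at a time; latecomers are told to skip, not to queue.
type inflightGuard struct {
	mu   sync.Mutex
	keys map[string]bool
}

func newInflightGuard() *inflightGuard {
	return &inflightGuard{keys: make(map[string]bool)}
}

// tryAcquire reports whether the caller won the key.
func (g *inflightGuard) tryAcquire(key string) bool {
	g.mu.Lock()
	defer g.mu.Unlock()
	if g.keys[key] {
		return false
	}
	g.keys[key] = true
	return true
}

func (g *inflightGuard) release(key string) {
	g.mu.Lock()
	defer g.mu.Unlock()
	delete(g.keys, key)
}

Skipping rather than waiting fits this use case: a second suspect event for the same node carries no new information, so queuing it would only repeat the DNS work.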

View File

@@ -182,6 +182,34 @@ func (drm *DNSRecordManager) GetNamespaceGatewayIPs(ctx context.Context, namespa
 	return ips, nil
 }
+
+// CountActiveNamespaceRecords returns the number of active A records for a namespace's main FQDN.
+// Used as a safety check before disabling records to prevent disabling the last one.
+func (drm *DNSRecordManager) CountActiveNamespaceRecords(ctx context.Context, namespaceName string) (int, error) {
+	internalCtx := client.WithInternalAuth(ctx)
+	fqdn := fmt.Sprintf("ns-%s.%s.", namespaceName, drm.baseDomain)
+
+	type countResult struct {
+		Count int `db:"count"`
+	}
+	var results []countResult
+	query := `SELECT COUNT(*) as count FROM dns_records WHERE fqdn = ? AND record_type = 'A' AND is_active = TRUE`
+	err := drm.db.Query(internalCtx, &results, query, fqdn)
+	if err != nil {
+		return 0, &ClusterError{
+			Message: "failed to count active namespace DNS records",
+			Cause:   err,
+		}
+	}
+	if len(results) == 0 {
+		return 0, nil
+	}
+	return results[0].Count, nil
+}
+
 // AddNamespaceRecord adds DNS A records for a single IP to an existing namespace.
 // Unlike CreateNamespaceRecords, this does NOT delete existing records — it's purely additive.
 // Used when adding a new node to an under-provisioned cluster (repair).

@@ -239,7 +267,7 @@ func (drm *DNSRecordManager) UpdateNamespaceRecord(ctx context.Context, namespac
 	// Update both the main record and wildcard record
 	for _, f := range []string{fqdn, wildcardFqdn} {
-		updateQuery := `UPDATE dns_records SET value = ?, updated_at = ? WHERE fqdn = ? AND value = ?`
+		updateQuery := `UPDATE dns_records SET value = ?, is_active = TRUE, updated_at = ? WHERE fqdn = ? AND value = ?`
 		_, err := drm.db.Exec(internalCtx, updateQuery, newIP, time.Now(), f, oldIP)
 		if err != nil {
 			drm.logger.Warn("Failed to update DNS record",
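Note: DisableNamespaceRecord and EnableNamespaceRecord are called by the cluster manager but are not shown in this diff. Assuming they mirror UpdateNamespaceRecord's shape, the disable side might look roughly like the sketch below; the exact SQL, the wildcard FQDN format, and the method body are all assumptions, not the commit's code:

// disableRecordSketch is a hypothetical reconstruction, NOT the real
// DisableNamespaceRecord. Assumed to flip is_active on the main and
// wildcard FQDNs for one IP, mirroring UpdateNamespaceRecord above.
func (drm *DNSRecordManager) disableRecordSketch(ctx context.Context, namespaceName, ip string) error {
	internalCtx := client.WithInternalAuth(ctx)
	fqdn := fmt.Sprintf("ns-%s.%s.", namespaceName, drm.baseDomain)
	wildcardFqdn := fmt.Sprintf("*.ns-%s.%s.", namespaceName, drm.baseDomain) // wildcard format is a guess
	query := `UPDATE dns_records SET is_active = FALSE, updated_at = ? WHERE fqdn = ? AND value = ?`
	for _, f := range []string{fqdn, wildcardFqdn} {
		if _, err := drm.db.Exec(internalCtx, query, time.Now(), f, ip); err != nil {
			return err
		}
	}
	return nil
}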

View File

@@ -1,7 +1,9 @@
 package namespace

 import (
+	"context"
 	"fmt"
+	"strings"
 	"testing"

 	"go.uber.org/zap"

@@ -215,3 +217,60 @@ func TestDNSRecordManager_FQDNWithTrailingDot(t *testing.T) {
 		})
 	}
 }
+
+func TestUpdateNamespaceRecord_SetsActiveTrue(t *testing.T) {
+	mockDB := newMockRQLiteClient()
+	logger := zap.NewNop()
+	manager := NewDNSRecordManager(mockDB, "orama-devnet.network", logger)
+	ctx := context.Background()
+
+	err := manager.UpdateNamespaceRecord(ctx, "alice", "1.2.3.4", "5.6.7.8")
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	// Verify the SQL contains is_active = TRUE for both FQDN and wildcard
+	activeCount := 0
+	for _, call := range mockDB.execCalls {
+		if strings.Contains(call.Query, "is_active = TRUE") && strings.Contains(call.Query, "UPDATE dns_records") {
+			activeCount++
+		}
+	}
+	if activeCount != 2 {
+		t.Fatalf("expected 2 UPDATE queries with is_active = TRUE (fqdn + wildcard), got %d", activeCount)
+	}
+}
+
+func TestCountActiveNamespaceRecords(t *testing.T) {
+	mockDB := newMockRQLiteClient()
+	logger := zap.NewNop()
+	manager := NewDNSRecordManager(mockDB, "orama-devnet.network", logger)
+	ctx := context.Background()
+
+	count, err := manager.CountActiveNamespaceRecords(ctx, "alice")
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	// With mock returning empty results, count should be 0
+	if count != 0 {
+		t.Fatalf("expected 0, got %d", count)
+	}
+
+	// Verify the correct query was made
+	if len(mockDB.queryCalls) == 0 {
+		t.Fatal("expected a query call")
+	}
+	lastCall := mockDB.queryCalls[len(mockDB.queryCalls)-1]
+	if !strings.Contains(lastCall.Query, "COUNT(*)") || !strings.Contains(lastCall.Query, "is_active = TRUE") {
+		t.Fatalf("unexpected query: %s", lastCall.Query)
+	}
+
+	// Verify the FQDN arg
+	expectedFQDN := "ns-alice.orama-devnet.network."
+	if len(lastCall.Args) == 0 {
+		t.Fatal("expected query args")
+	}
+	if fqdn, ok := lastCall.Args[0].(string); !ok || fqdn != expectedFQDN {
+		t.Fatalf("expected FQDN arg %q, got %v", expectedFQDN, lastCall.Args[0])
+	}
+}
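Note: these tests depend on a newMockRQLiteClient helper defined elsewhere in the package. A plausible minimal shape, inferred from the calls above; the method signatures and field names are inferred from usage, not confirmed by the diff:

// Inferred sketch of the mock used above — not the actual helper.
type recordedCall struct {
	Query string
	Args  []interface{}
}

type mockRQLiteClient struct {
	execCalls  []recordedCall
	queryCalls []recordedCall
}

func newMockRQLiteClient() *mockRQLiteClient {
	return &mockRQLiteClient{}
}

func (m *mockRQLiteClient) Exec(ctx context.Context, query string, args ...interface{}) (interface{}, error) {
	m.execCalls = append(m.execCalls, recordedCall{Query: query, Args: args})
	return nil, nil
}

// Query records the call and leaves dest untouched, so callers observe
// "no rows" — which is why TestCountActiveNamespaceRecords expects 0.
func (m *mockRQLiteClient) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
	m.queryCalls = append(m.queryCalls, recordedCall{Query: query, Args: args})
	return nil
}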

View File

@@ -83,7 +83,8 @@ type Monitor struct {
 	peers map[string]*peerState // nodeID → state
 	onDeadFn      func(nodeID string) // callback when quorum confirms death
-	onRecoveredFn func(nodeID string) // callback when node transitions from dead → healthy
+	onRecoveredFn func(nodeID string) // callback when node transitions from suspect/dead → healthy
+	onSuspectFn   func(nodeID string) // callback when node transitions healthy → suspect
 }

 // NewMonitor creates a new health monitor.

@@ -121,12 +122,19 @@ func (m *Monitor) OnNodeDead(fn func(nodeID string)) {
 	m.onDeadFn = fn
 }

-// OnNodeRecovered registers a callback invoked when a previously dead node
-// transitions back to healthy. The callback runs with the monitor lock released.
+// OnNodeRecovered registers a callback invoked when a previously suspect or dead
+// node transitions back to healthy. The callback runs with the monitor lock released.
 func (m *Monitor) OnNodeRecovered(fn func(nodeID string)) {
 	m.onRecoveredFn = fn
 }

+// OnNodeSuspect registers a callback invoked when a node transitions from
+// healthy to suspect (3 consecutive missed probes). The callback runs with
+// the monitor lock released.
+func (m *Monitor) OnNodeSuspect(fn func(nodeID string)) {
+	m.onSuspectFn = fn
+}
+
 // Start runs the monitor loop until ctx is cancelled.
 func (m *Monitor) Start(ctx context.Context) {
 	m.logger.Info("Starting node health monitor",

@@ -232,8 +240,8 @@ func (m *Monitor) updateState(ctx context.Context, nodeID string, healthy bool) {
 	}

 	if healthy {
-		wasDead := ps.status == "dead"
-		shouldCallback := wasDead && m.onRecoveredFn != nil
+		wasUnhealthy := ps.status == "suspect" || ps.status == "dead"
+		shouldCallback := wasUnhealthy && m.onRecoveredFn != nil
 		prevStatus := ps.status

 		// Update state BEFORE releasing lock (C2 fix)

@@ -299,6 +307,7 @@ func (m *Monitor) updateState(ctx context.Context, nodeID string, healthy bool) {
 	case ps.missCount >= DefaultSuspectAfter && ps.status == "healthy":
 		ps.status = "suspect"
 		ps.suspectAt = time.Now()
+		shouldCallSuspect := m.onSuspectFn != nil
 		m.mu.Unlock()

 		m.logger.Warn("Node SUSPECT",

@@ -306,6 +315,10 @@ func (m *Monitor) updateState(ctx context.Context, nodeID string, healthy bool) {
 			zap.Int("misses", ps.missCount),
 		)
 		m.writeEvent(ctx, nodeID, "suspect")
+
+		if shouldCallSuspect {
+			m.onSuspectFn(nodeID)
+		}
 		return
 	}
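Note: for orientation, a sketch of the peer state machine implied by updateState, with this commit's new transition. Only DefaultSuspectAfter and the healthy/suspect/dead status strings are confirmed by the diff; the dead transition is quorum-driven in code not shown here:

// Illustrative summary of peerState transitions (not the monitor's code).
const (
	statusHealthy = "healthy"
	statusSuspect = "suspect"
	statusDead    = "dead"
)

// nextStatus sketches the transition rules implied by updateState.
func nextStatus(current string, healthy bool, missCount int) string {
	switch {
	case healthy:
		return statusHealthy // suspect/dead back to healthy fires onRecoveredFn
	case current == statusHealthy && missCount >= DefaultSuspectAfter:
		return statusSuspect // fires onSuspectFn exactly once
	default:
		return current // further misses stay suspect until quorum confirms dead
	}
}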

View File

@@ -521,3 +521,114 @@ func TestRecoveryCallback_InvokedWithoutLock(t *testing.T) {
 	}
 	m.mu.Unlock()
 }
+
+// ---------------------------------------------------------------
+// OnNodeSuspect callback
+// ---------------------------------------------------------------
+
+func TestOnNodeSuspect_Callback(t *testing.T) {
+	m := NewMonitor(Config{
+		NodeID:             "self",
+		Neighbors:          3,
+		StartupGracePeriod: 1 * time.Millisecond,
+	})
+	time.Sleep(2 * time.Millisecond)
+
+	var suspectNode string
+	m.OnNodeSuspect(func(nodeID string) {
+		suspectNode = nodeID
+	})
+
+	ctx := context.Background()
+	// Drive 3 misses (DefaultSuspectAfter) → healthy → suspect
+	for i := 0; i < DefaultSuspectAfter; i++ {
+		m.updateState(ctx, "peer1", false)
+	}
+
+	if suspectNode != "peer1" {
+		t.Fatalf("expected suspect callback for peer1, got %q", suspectNode)
+	}
+
+	m.mu.Lock()
+	if m.peers["peer1"].status != "suspect" {
+		m.mu.Unlock()
+		t.Fatalf("expected suspect state, got %s", m.peers["peer1"].status)
+	}
+	m.mu.Unlock()
+}
+
+func TestOnNodeSuspect_DoesNotFireOnSubsequentMisses(t *testing.T) {
+	m := NewMonitor(Config{
+		NodeID:             "self",
+		Neighbors:          3,
+		StartupGracePeriod: 1 * time.Millisecond,
+	})
+	time.Sleep(2 * time.Millisecond)
+
+	var callCount int32
+	m.OnNodeSuspect(func(nodeID string) {
+		callCount++
+	})
+
+	ctx := context.Background()
+	// Drive to suspect (3 misses)
+	for i := 0; i < DefaultSuspectAfter; i++ {
+		m.updateState(ctx, "peer1", false)
+	}
+	if callCount != 1 {
+		t.Fatalf("expected suspect callback to fire once after 3 misses, got %d", callCount)
+	}
+
+	// Keep missing (4th through 11th miss) — should NOT fire suspect again
+	for i := 0; i < 8; i++ {
+		m.updateState(ctx, "peer1", false)
+	}
+	if callCount != 1 {
+		t.Fatalf("expected suspect callback to fire exactly once, got %d", callCount)
+	}
+}
+
+func TestRecoveredFromSuspect_Callback(t *testing.T) {
+	m := NewMonitor(Config{
+		NodeID:             "self",
+		Neighbors:          3,
+		StartupGracePeriod: 1 * time.Millisecond,
+	})
+	time.Sleep(2 * time.Millisecond)
+
+	var recoveredNode string
+	m.OnNodeRecovered(func(nodeID string) {
+		recoveredNode = nodeID
+	})
+
+	ctx := context.Background()
+	// Drive to suspect (3 misses, NOT dead)
+	for i := 0; i < DefaultSuspectAfter; i++ {
+		m.updateState(ctx, "peer1", false)
+	}
+	m.mu.Lock()
+	if m.peers["peer1"].status != "suspect" {
+		m.mu.Unlock()
+		t.Fatal("expected suspect state before recovery")
+	}
+	m.mu.Unlock()
+
+	// Recover from suspect
+	m.updateState(ctx, "peer1", true)
+
+	if recoveredNode != "peer1" {
+		t.Fatalf("expected recovery callback for peer1 after suspect, got %q", recoveredNode)
+	}
+
+	m.mu.Lock()
+	if m.peers["peer1"].status != "healthy" {
+		m.mu.Unlock()
+		t.Fatal("expected healthy after recovery from suspect")
+	}
+	m.mu.Unlock()
+}
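Note: a natural follow-on, not part of this commit, would assert the full round trip in one test. A sketch reusing the same setup as the tests above (the test name is hypothetical):

// Sketch: suspect then recover, asserting both callbacks observe peer1.
func TestSuspectThenRecover_RoundTrip(t *testing.T) {
	m := NewMonitor(Config{
		NodeID:             "self",
		Neighbors:          3,
		StartupGracePeriod: 1 * time.Millisecond,
	})
	time.Sleep(2 * time.Millisecond)

	var gotSuspect, gotRecovered string
	m.OnNodeSuspect(func(id string) { gotSuspect = id })
	m.OnNodeRecovered(func(id string) { gotRecovered = id })

	ctx := context.Background()
	for i := 0; i < DefaultSuspectAfter; i++ {
		m.updateState(ctx, "peer1", false)
	}
	m.updateState(ctx, "peer1", true)

	if gotSuspect != "peer1" || gotRecovered != "peer1" {
		t.Fatalf("suspect=%q recovered=%q, want peer1 for both", gotSuspect, gotRecovered)
	}
}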