feat: implement suspect node handling with callbacks for DNS record management

This commit is contained in:
anonpenguin23 2026-02-20 09:27:35 +02:00
parent 2b0bfaaa12
commit 4ebf558719
9 changed files with 372 additions and 9 deletions

View File

@ -40,7 +40,7 @@ func parseConfig(c *caddy.Controller) (*RQLitePlugin, error) {
var (
dsn = "http://localhost:5001"
refreshRate = 10 * time.Second
cacheTTL = 60 * time.Second
cacheTTL = 30 * time.Second
cacheSize = 10000
zones []string
)

View File

@ -333,7 +333,7 @@ func (ci *CoreDNSInstaller) generateCorefile(domain, rqliteDSN string) string {
rqlite {
dsn %s
refresh 5s
ttl 60
ttl 30
cache_size 10000
}

View File

@ -580,12 +580,20 @@ func New(logger *logging.ColoredLogger, cfg *Config) (*Gateway, error) {
}
})
gw.healthMonitor.OnNodeRecovered(func(nodeID string) {
logger.ComponentInfo(logging.ComponentGeneral, "Previously dead node recovered — checking for orphaned services",
logger.ComponentInfo(logging.ComponentGeneral, "Node recovered — re-enabling DNS and checking for orphaned services",
zap.String("node_id", nodeID))
if gw.nodeRecoverer != nil {
go gw.nodeRecoverer.HandleSuspectRecovery(context.Background(), nodeID)
go gw.nodeRecoverer.HandleRecoveredNode(context.Background(), nodeID)
}
})
gw.healthMonitor.OnNodeSuspect(func(nodeID string) {
logger.ComponentWarn(logging.ComponentGeneral, "Node SUSPECT — disabling DNS records",
zap.String("suspect_node", nodeID))
if gw.nodeRecoverer != nil {
go gw.nodeRecoverer.HandleSuspectNode(context.Background(), nodeID)
}
})
go gw.healthMonitor.Start(context.Background())
logger.ComponentInfo(logging.ComponentGeneral, "Node health monitor started",
zap.String("node_id", cfg.NodePeerID))

View File

@ -53,6 +53,8 @@ type ClusterProvisioner interface {
// NodeRecoverer reacts to node liveness transitions reported by the health
// monitor and maintains cluster/DNS state accordingly.
type NodeRecoverer interface {
	// HandleDeadNode reacts to a node that has been confirmed dead.
	HandleDeadNode(ctx context.Context, deadNodeID string)
	// HandleRecoveredNode reacts to a node that transitioned back to healthy.
	HandleRecoveredNode(ctx context.Context, nodeID string)
	// HandleSuspectNode reacts to a node entering the suspect state
	// (e.g. by disabling its DNS records before death is confirmed).
	HandleSuspectNode(ctx context.Context, suspectNodeID string)
	// HandleSuspectRecovery re-enables resources (DNS records) for a node
	// that recovered from suspect state without being declared dead.
	HandleSuspectRecovery(ctx context.Context, nodeID string)
	// RepairCluster repairs the cluster backing the given namespace.
	RepairCluster(ctx context.Context, namespaceName string) error
}

View File

@ -191,6 +191,148 @@ func (cm *ClusterManager) HandleRecoveredNode(ctx context.Context, nodeID string
zap.Int("namespaces_cleaned", len(events)))
}
// HandleSuspectNode disables DNS records for a suspect node so traffic stops
// being routed to it. It fires early (T+30s), when the node first turns
// suspect and before it is confirmed dead; if the node comes back,
// HandleSuspectRecovery re-enables the records.
//
// Safety: never disables the last active record for a namespace.
func (cm *ClusterManager) HandleSuspectNode(ctx context.Context, suspectNodeID string) {
	cm.logger.Warn("Handling suspect node — disabling DNS records",
		zap.String("suspect_node", suspectNodeID),
	)

	// Per-node in-progress guard: at most one suspect handler per node.
	lockKey := "suspect:" + suspectNodeID
	cm.provisioningMu.Lock()
	alreadyRunning := cm.provisioning[lockKey]
	if !alreadyRunning {
		cm.provisioning[lockKey] = true
	}
	cm.provisioningMu.Unlock()
	if alreadyRunning {
		cm.logger.Info("Suspect handling already in progress for node, skipping",
			zap.String("node_id", suspectNodeID))
		return
	}
	defer func() {
		cm.provisioningMu.Lock()
		delete(cm.provisioning, lockKey)
		cm.provisioningMu.Unlock()
	}()

	// Every cluster assignment of this node is a candidate for DNS disable.
	clusters, err := cm.getClustersByNodeID(ctx, suspectNodeID)
	if err != nil {
		cm.logger.Warn("Failed to find clusters for suspect node",
			zap.String("suspect_node", suspectNodeID), zap.Error(err))
		return
	}
	if len(clusters) == 0 {
		cm.logger.Info("Suspect node has no namespace cluster assignments",
			zap.String("suspect_node", suspectNodeID))
		return
	}

	// DNS A records hold public IPs, so resolve the node's addresses first.
	ips, err := cm.getNodeIPs(ctx, suspectNodeID)
	if err != nil {
		cm.logger.Warn("Failed to get suspect node IPs",
			zap.String("suspect_node", suspectNodeID), zap.Error(err))
		return
	}

	dns := NewDNSRecordManager(cm.db, cm.baseDomain, cm.logger)
	disabled := 0
	for _, cluster := range clusters {
		// Safety check: skip any namespace that would be left with no
		// active records if we disabled this one.
		active, countErr := dns.CountActiveNamespaceRecords(ctx, cluster.NamespaceName)
		if countErr != nil {
			cm.logger.Warn("Failed to count active DNS records, skipping namespace",
				zap.String("namespace", cluster.NamespaceName),
				zap.Error(countErr))
			continue
		}
		if active <= 1 {
			cm.logger.Warn("Not disabling DNS — would leave namespace with no active records",
				zap.String("namespace", cluster.NamespaceName),
				zap.String("suspect_node", suspectNodeID),
				zap.Int("active_records", active))
			continue
		}

		if disableErr := dns.DisableNamespaceRecord(ctx, cluster.NamespaceName, ips.IPAddress); disableErr != nil {
			cm.logger.Warn("Failed to disable DNS record for suspect node",
				zap.String("namespace", cluster.NamespaceName),
				zap.String("ip", ips.IPAddress),
				zap.Error(disableErr))
			continue
		}
		disabled++
		cm.logger.Info("Disabled DNS record for suspect node",
			zap.String("namespace", cluster.NamespaceName),
			zap.String("ip", ips.IPAddress))
	}

	cm.logger.Info("Suspect node DNS handling completed",
		zap.String("suspect_node", suspectNodeID),
		zap.Int("namespaces_affected", len(clusters)),
		zap.Int("records_disabled", disabled))
}
// HandleSuspectRecovery re-enables DNS records for a node that recovered from
// suspect state without going dead. Called when the health monitor detects
// that a previously suspect node is responding to probes again.
//
// Fix: like HandleSuspectNode, this runs as a fire-and-forget goroutine, so a
// per-node in-progress guard (via cm.provisioning) is taken to prevent two
// concurrent recovery handlers from racing on the same node's records.
func (cm *ClusterManager) HandleSuspectRecovery(ctx context.Context, nodeID string) {
	cm.logger.Info("Handling suspect recovery — re-enabling DNS records",
		zap.String("node_id", nodeID),
	)

	// Deduplicate concurrent invocations for the same node, mirroring the
	// "suspect:" guard used by HandleSuspectNode.
	recoveryKey := "suspect-recovery:" + nodeID
	cm.provisioningMu.Lock()
	if cm.provisioning[recoveryKey] {
		cm.provisioningMu.Unlock()
		cm.logger.Info("Suspect recovery already in progress for node, skipping",
			zap.String("node_id", nodeID))
		return
	}
	cm.provisioning[recoveryKey] = true
	cm.provisioningMu.Unlock()
	defer func() {
		cm.provisioningMu.Lock()
		delete(cm.provisioning, recoveryKey)
		cm.provisioningMu.Unlock()
	}()

	// Find all clusters this node belongs to.
	clusters, err := cm.getClustersByNodeID(ctx, nodeID)
	if err != nil {
		cm.logger.Warn("Failed to find clusters for recovered node",
			zap.String("node_id", nodeID), zap.Error(err))
		return
	}
	if len(clusters) == 0 {
		return
	}

	// Get node's public IP (DNS A records contain public IPs).
	ips, err := cm.getNodeIPs(ctx, nodeID)
	if err != nil {
		cm.logger.Warn("Failed to get recovered node IPs",
			zap.String("node_id", nodeID), zap.Error(err))
		return
	}

	dnsManager := NewDNSRecordManager(cm.db, cm.baseDomain, cm.logger)
	enabledCount := 0
	for _, cluster := range clusters {
		if err := dnsManager.EnableNamespaceRecord(ctx, cluster.NamespaceName, ips.IPAddress); err != nil {
			cm.logger.Warn("Failed to re-enable DNS record for recovered node",
				zap.String("namespace", cluster.NamespaceName),
				zap.String("ip", ips.IPAddress),
				zap.Error(err))
			continue
		}
		enabledCount++
		cm.logger.Info("Re-enabled DNS record for recovered node",
			zap.String("namespace", cluster.NamespaceName),
			zap.String("ip", ips.IPAddress))
	}

	cm.logger.Info("Suspect recovery DNS handling completed",
		zap.String("node_id", nodeID),
		zap.Int("records_enabled", enabledCount))
}
// ReplaceClusterNode replaces a dead node in a specific namespace cluster.
// It selects a new node, allocates ports, spawns services, updates DNS, and cleans up.
func (cm *ClusterManager) ReplaceClusterNode(ctx context.Context, cluster *NamespaceCluster, deadNodeID string) error {

View File

@ -182,6 +182,34 @@ func (drm *DNSRecordManager) GetNamespaceGatewayIPs(ctx context.Context, namespa
return ips, nil
}
// CountActiveNamespaceRecords reports how many active A records exist for a
// namespace's main FQDN. It serves as a safety check before disabling
// records, so the last active record is never taken away.
func (drm *DNSRecordManager) CountActiveNamespaceRecords(ctx context.Context, namespaceName string) (int, error) {
	fqdn := fmt.Sprintf("ns-%s.%s.", namespaceName, drm.baseDomain)

	var rows []struct {
		Count int `db:"count"`
	}
	const query = `SELECT COUNT(*) as count FROM dns_records WHERE fqdn = ? AND record_type = 'A' AND is_active = TRUE`
	if err := drm.db.Query(client.WithInternalAuth(ctx), &rows, query, fqdn); err != nil {
		return 0, &ClusterError{
			Message: "failed to count active namespace DNS records",
			Cause:   err,
		}
	}

	// No result rows means no matching records.
	if len(rows) == 0 {
		return 0, nil
	}
	return rows[0].Count, nil
}
// AddNamespaceRecord adds DNS A records for a single IP to an existing namespace.
// Unlike CreateNamespaceRecords, this does NOT delete existing records — it's purely additive.
// Used when adding a new node to an under-provisioned cluster (repair).
@ -239,7 +267,7 @@ func (drm *DNSRecordManager) UpdateNamespaceRecord(ctx context.Context, namespac
// Update both the main record and wildcard record
for _, f := range []string{fqdn, wildcardFqdn} {
updateQuery := `UPDATE dns_records SET value = ?, updated_at = ? WHERE fqdn = ? AND value = ?`
updateQuery := `UPDATE dns_records SET value = ?, is_active = TRUE, updated_at = ? WHERE fqdn = ? AND value = ?`
_, err := drm.db.Exec(internalCtx, updateQuery, newIP, time.Now(), f, oldIP)
if err != nil {
drm.logger.Warn("Failed to update DNS record",

View File

@ -1,7 +1,9 @@
package namespace
import (
"context"
"fmt"
"strings"
"testing"
"go.uber.org/zap"
@ -215,3 +217,60 @@ func TestDNSRecordManager_FQDNWithTrailingDot(t *testing.T) {
})
}
}
// TestUpdateNamespaceRecord_SetsActiveTrue verifies that updating a namespace
// record reactivates it (is_active = TRUE) for both the main FQDN and the
// wildcard record.
func TestUpdateNamespaceRecord_SetsActiveTrue(t *testing.T) {
	db := newMockRQLiteClient()
	mgr := NewDNSRecordManager(db, "orama-devnet.network", zap.NewNop())

	if err := mgr.UpdateNamespaceRecord(context.Background(), "alice", "1.2.3.4", "5.6.7.8"); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// Count UPDATE statements that flip is_active back on.
	matched := 0
	for _, call := range db.execCalls {
		isUpdate := strings.Contains(call.Query, "UPDATE dns_records")
		setsActive := strings.Contains(call.Query, "is_active = TRUE")
		if isUpdate && setsActive {
			matched++
		}
	}
	if matched != 2 {
		t.Fatalf("expected 2 UPDATE queries with is_active = TRUE (fqdn + wildcard), got %d", matched)
	}
}
// TestCountActiveNamespaceRecords verifies the counting query shape, its FQDN
// argument, and that an empty mock result yields zero.
func TestCountActiveNamespaceRecords(t *testing.T) {
	db := newMockRQLiteClient()
	mgr := NewDNSRecordManager(db, "orama-devnet.network", zap.NewNop())

	count, err := mgr.CountActiveNamespaceRecords(context.Background(), "alice")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	// The mock returns no rows, so the count must be zero.
	if count != 0 {
		t.Fatalf("expected 0, got %d", count)
	}

	// Inspect the most recent query issued against the mock.
	if len(db.queryCalls) == 0 {
		t.Fatal("expected a query call")
	}
	last := db.queryCalls[len(db.queryCalls)-1]
	if !strings.Contains(last.Query, "COUNT(*)") || !strings.Contains(last.Query, "is_active = TRUE") {
		t.Fatalf("unexpected query: %s", last.Query)
	}

	// The first bound argument must be the trailing-dot FQDN.
	const expectedFQDN = "ns-alice.orama-devnet.network."
	if len(last.Args) == 0 {
		t.Fatal("expected query args")
	}
	fqdn, ok := last.Args[0].(string)
	if !ok || fqdn != expectedFQDN {
		t.Fatalf("expected FQDN arg %q, got %v", expectedFQDN, last.Args[0])
	}
}

View File

@ -83,7 +83,8 @@ type Monitor struct {
peers map[string]*peerState // nodeID → state
onDeadFn func(nodeID string) // callback when quorum confirms death
onRecoveredFn func(nodeID string) // callback when node transitions from dead → healthy
onRecoveredFn func(nodeID string) // callback when node transitions from suspect/dead → healthy
onSuspectFn func(nodeID string) // callback when node transitions healthy → suspect
}
// NewMonitor creates a new health monitor.
@ -121,12 +122,19 @@ func (m *Monitor) OnNodeDead(fn func(nodeID string)) {
m.onDeadFn = fn
}
// OnNodeRecovered registers a callback invoked when a previously dead node
// transitions back to healthy. The callback runs with the monitor lock released.
// OnNodeRecovered registers a callback invoked when a previously suspect or dead
// node transitions back to healthy. The callback runs with the monitor lock released.
//
// NOTE(review): the write to m.onRecoveredFn is unsynchronized — presumably
// callbacks are meant to be registered before Start; confirm callers do so.
func (m *Monitor) OnNodeRecovered(fn func(nodeID string)) {
	m.onRecoveredFn = fn
}
// OnNodeSuspect registers a callback invoked when a node transitions from
// healthy to suspect (3 consecutive missed probes). The callback runs with
// the monitor lock released.
//
// NOTE(review): the write to m.onSuspectFn is unsynchronized — presumably
// callbacks are meant to be registered before Start; confirm callers do so.
func (m *Monitor) OnNodeSuspect(fn func(nodeID string)) {
	m.onSuspectFn = fn
}
// Start runs the monitor loop until ctx is cancelled.
func (m *Monitor) Start(ctx context.Context) {
m.logger.Info("Starting node health monitor",
@ -232,8 +240,8 @@ func (m *Monitor) updateState(ctx context.Context, nodeID string, healthy bool)
}
if healthy {
wasDead := ps.status == "dead"
shouldCallback := wasDead && m.onRecoveredFn != nil
wasUnhealthy := ps.status == "suspect" || ps.status == "dead"
shouldCallback := wasUnhealthy && m.onRecoveredFn != nil
prevStatus := ps.status
// Update state BEFORE releasing lock (C2 fix)
@ -299,6 +307,7 @@ func (m *Monitor) updateState(ctx context.Context, nodeID string, healthy bool)
case ps.missCount >= DefaultSuspectAfter && ps.status == "healthy":
ps.status = "suspect"
ps.suspectAt = time.Now()
shouldCallSuspect := m.onSuspectFn != nil
m.mu.Unlock()
m.logger.Warn("Node SUSPECT",
@ -306,6 +315,10 @@ func (m *Monitor) updateState(ctx context.Context, nodeID string, healthy bool)
zap.Int("misses", ps.missCount),
)
m.writeEvent(ctx, nodeID, "suspect")
if shouldCallSuspect {
m.onSuspectFn(nodeID)
}
return
}

View File

@ -521,3 +521,114 @@ func TestRecoveryCallback_InvokedWithoutLock(t *testing.T) {
}
m.mu.Unlock()
}
// ---------------------------------------------------------------
// OnNodeSuspect callback
// ---------------------------------------------------------------
// TestOnNodeSuspect_Callback verifies that the suspect callback fires once a
// peer accumulates DefaultSuspectAfter consecutive missed probes.
func TestOnNodeSuspect_Callback(t *testing.T) {
	m := NewMonitor(Config{
		NodeID:             "self",
		Neighbors:          3,
		StartupGracePeriod: 1 * time.Millisecond,
	})
	time.Sleep(2 * time.Millisecond)

	var suspectNode string
	m.OnNodeSuspect(func(nodeID string) { suspectNode = nodeID })

	// healthy → suspect after DefaultSuspectAfter consecutive misses.
	ctx := context.Background()
	for miss := 0; miss < DefaultSuspectAfter; miss++ {
		m.updateState(ctx, "peer1", false)
	}

	if suspectNode != "peer1" {
		t.Fatalf("expected suspect callback for peer1, got %q", suspectNode)
	}

	m.mu.Lock()
	status := m.peers["peer1"].status
	m.mu.Unlock()
	if status != "suspect" {
		t.Fatalf("expected suspect state, got %s", status)
	}
}
// TestOnNodeSuspect_DoesNotFireOnSubsequentMisses verifies the suspect
// callback is edge-triggered: it fires on the healthy→suspect transition
// only, not on further misses while already suspect.
func TestOnNodeSuspect_DoesNotFireOnSubsequentMisses(t *testing.T) {
	m := NewMonitor(Config{
		NodeID:             "self",
		Neighbors:          3,
		StartupGracePeriod: 1 * time.Millisecond,
	})
	time.Sleep(2 * time.Millisecond)

	// Plain int: updateState invokes the callback synchronously in this
	// goroutine, so no atomics are needed.
	fired := 0
	m.OnNodeSuspect(func(string) { fired++ })

	ctx := context.Background()

	// Drive to suspect (3 misses).
	for miss := 0; miss < DefaultSuspectAfter; miss++ {
		m.updateState(ctx, "peer1", false)
	}
	if fired != 1 {
		t.Fatalf("expected suspect callback to fire once after 3 misses, got %d", fired)
	}

	// Keep missing (4th through 11th miss) — should NOT fire suspect again.
	for extra := 0; extra < 8; extra++ {
		m.updateState(ctx, "peer1", false)
	}
	if fired != 1 {
		t.Fatalf("expected suspect callback to fire exactly once, got %d", fired)
	}
}
// TestRecoveredFromSuspect_Callback verifies the recovery callback fires when
// a node returns to healthy directly from suspect (without going dead).
func TestRecoveredFromSuspect_Callback(t *testing.T) {
	m := NewMonitor(Config{
		NodeID:             "self",
		Neighbors:          3,
		StartupGracePeriod: 1 * time.Millisecond,
	})
	time.Sleep(2 * time.Millisecond)

	var recoveredNode string
	m.OnNodeRecovered(func(nodeID string) { recoveredNode = nodeID })

	ctx := context.Background()

	// Drive to suspect (3 misses, NOT dead).
	for miss := 0; miss < DefaultSuspectAfter; miss++ {
		m.updateState(ctx, "peer1", false)
	}
	m.mu.Lock()
	status := m.peers["peer1"].status
	m.mu.Unlock()
	if status != "suspect" {
		t.Fatal("expected suspect state before recovery")
	}

	// A single successful probe recovers the peer from suspect.
	m.updateState(ctx, "peer1", true)
	if recoveredNode != "peer1" {
		t.Fatalf("expected recovery callback for peer1 after suspect, got %q", recoveredNode)
	}

	m.mu.Lock()
	status = m.peers["peer1"].status
	m.mu.Unlock()
	if status != "healthy" {
		t.Fatal("expected healthy after recovery from suspect")
	}
}