diff --git a/pkg/coredns/rqlite/setup.go b/pkg/coredns/rqlite/setup.go
index 12b9382..abcb1c1 100644
--- a/pkg/coredns/rqlite/setup.go
+++ b/pkg/coredns/rqlite/setup.go
@@ -40,7 +40,7 @@ func parseConfig(c *caddy.Controller) (*RQLitePlugin, error) {
 	var (
 		dsn         = "http://localhost:5001"
 		refreshRate = 10 * time.Second
-		cacheTTL    = 60 * time.Second
+		cacheTTL    = 30 * time.Second
 		cacheSize   = 10000
 		zones       []string
 	)
diff --git a/pkg/environments/production/installers/coredns.go b/pkg/environments/production/installers/coredns.go
index 96291ce..348a447 100644
--- a/pkg/environments/production/installers/coredns.go
+++ b/pkg/environments/production/installers/coredns.go
@@ -333,7 +333,7 @@ func (ci *CoreDNSInstaller) generateCorefile(domain, rqliteDSN string) string {
     rqlite {
         dsn %s
         refresh 5s
-        ttl 60
+        ttl 30
         cache_size 10000
     }
diff --git a/pkg/gateway/gateway.go b/pkg/gateway/gateway.go
index 994b4a8..8d85382 100644
--- a/pkg/gateway/gateway.go
+++ b/pkg/gateway/gateway.go
@@ -580,12 +580,20 @@ func New(logger *logging.ColoredLogger, cfg *Config) (*Gateway, error) {
 		}
 	})
 	gw.healthMonitor.OnNodeRecovered(func(nodeID string) {
-		logger.ComponentInfo(logging.ComponentGeneral, "Previously dead node recovered — checking for orphaned services",
+		logger.ComponentInfo(logging.ComponentGeneral, "Node recovered — re-enabling DNS and checking for orphaned services",
 			zap.String("node_id", nodeID))
 		if gw.nodeRecoverer != nil {
+			go gw.nodeRecoverer.HandleSuspectRecovery(context.Background(), nodeID)
 			go gw.nodeRecoverer.HandleRecoveredNode(context.Background(), nodeID)
 		}
 	})
+	gw.healthMonitor.OnNodeSuspect(func(nodeID string) {
+		logger.ComponentWarn(logging.ComponentGeneral, "Node SUSPECT — disabling DNS records",
+			zap.String("suspect_node", nodeID))
+		if gw.nodeRecoverer != nil {
+			go gw.nodeRecoverer.HandleSuspectNode(context.Background(), nodeID)
+		}
+	})
 	go gw.healthMonitor.Start(context.Background())
 	logger.ComponentInfo(logging.ComponentGeneral, "Node health monitor started",
 		zap.String("node_id", cfg.NodePeerID))
diff --git a/pkg/gateway/handlers/auth/handlers.go b/pkg/gateway/handlers/auth/handlers.go
index 4a3f4d6..beb4733 100644
--- a/pkg/gateway/handlers/auth/handlers.go
+++ b/pkg/gateway/handlers/auth/handlers.go
@@ -53,6 +53,8 @@ type ClusterProvisioner interface {
 type NodeRecoverer interface {
 	HandleDeadNode(ctx context.Context, deadNodeID string)
 	HandleRecoveredNode(ctx context.Context, nodeID string)
+	HandleSuspectNode(ctx context.Context, suspectNodeID string)
+	HandleSuspectRecovery(ctx context.Context, nodeID string)
 	RepairCluster(ctx context.Context, namespaceName string) error
 }
diff --git a/pkg/namespace/cluster_recovery.go b/pkg/namespace/cluster_recovery.go
index 35607b3..b61e222 100644
--- a/pkg/namespace/cluster_recovery.go
+++ b/pkg/namespace/cluster_recovery.go
@@ -191,6 +191,148 @@ func (cm *ClusterManager) HandleRecoveredNode(ctx context.Context, nodeID string
 		zap.Int("namespaces_cleaned", len(events)))
 }
 
+// HandleSuspectNode disables DNS records for a suspect node to prevent traffic
+// from being routed to it. Called early (T+30s) when the node first becomes suspect,
+// before confirming it's actually dead. If the node recovers, HandleSuspectRecovery
+// will re-enable the records.
+//
+// Safety: never disables the last active record for a namespace.
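+//
+// Concurrency: a per-node "suspect:<nodeID>" key in cm.provisioning ensures
+// only one suspect handler runs for a given node at a time.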
+func (cm *ClusterManager) HandleSuspectNode(ctx context.Context, suspectNodeID string) {
+	cm.logger.Warn("Handling suspect node — disabling DNS records",
+		zap.String("suspect_node", suspectNodeID),
+	)
+
+	// Acquire per-node lock to prevent concurrent suspect handling
+	suspectKey := "suspect:" + suspectNodeID
+	cm.provisioningMu.Lock()
+	if cm.provisioning[suspectKey] {
+		cm.provisioningMu.Unlock()
+		cm.logger.Info("Suspect handling already in progress for node, skipping",
+			zap.String("node_id", suspectNodeID))
+		return
+	}
+	cm.provisioning[suspectKey] = true
+	cm.provisioningMu.Unlock()
+	defer func() {
+		cm.provisioningMu.Lock()
+		delete(cm.provisioning, suspectKey)
+		cm.provisioningMu.Unlock()
+	}()
+
+	// Find all clusters this node belongs to
+	clusters, err := cm.getClustersByNodeID(ctx, suspectNodeID)
+	if err != nil {
+		cm.logger.Warn("Failed to find clusters for suspect node",
+			zap.String("suspect_node", suspectNodeID), zap.Error(err))
+		return
+	}
+
+	if len(clusters) == 0 {
+		cm.logger.Info("Suspect node has no namespace cluster assignments",
+			zap.String("suspect_node", suspectNodeID))
+		return
+	}
+
+	// Get suspect node's public IP (DNS A records contain public IPs)
+	ips, err := cm.getNodeIPs(ctx, suspectNodeID)
+	if err != nil {
+		cm.logger.Warn("Failed to get suspect node IPs",
+			zap.String("suspect_node", suspectNodeID), zap.Error(err))
+		return
+	}
+
+	dnsManager := NewDNSRecordManager(cm.db, cm.baseDomain, cm.logger)
+	disabledCount := 0
+
+	for _, cluster := range clusters {
+		// Safety check: never disable the last active record
+		activeCount, err := dnsManager.CountActiveNamespaceRecords(ctx, cluster.NamespaceName)
+		if err != nil {
+			cm.logger.Warn("Failed to count active DNS records, skipping namespace",
+				zap.String("namespace", cluster.NamespaceName),
+				zap.Error(err))
+			continue
+		}
+
+		if activeCount <= 1 {
+			cm.logger.Warn("Not disabling DNS — would leave namespace with no active records",
+				zap.String("namespace", cluster.NamespaceName),
+				zap.String("suspect_node", suspectNodeID),
+				zap.Int("active_records", activeCount))
+			continue
+		}
+
+		if err := dnsManager.DisableNamespaceRecord(ctx, cluster.NamespaceName, ips.IPAddress); err != nil {
+			cm.logger.Warn("Failed to disable DNS record for suspect node",
+				zap.String("namespace", cluster.NamespaceName),
+				zap.String("ip", ips.IPAddress),
+				zap.Error(err))
+			continue
+		}
+
+		disabledCount++
+		cm.logger.Info("Disabled DNS record for suspect node",
+			zap.String("namespace", cluster.NamespaceName),
+			zap.String("ip", ips.IPAddress))
+	}
+
+	cm.logger.Info("Suspect node DNS handling completed",
+		zap.String("suspect_node", suspectNodeID),
+		zap.Int("namespaces_affected", len(clusters)),
+		zap.Int("records_disabled", disabledCount))
+}
+
+// HandleSuspectRecovery re-enables DNS records for a node that recovered from
+// suspect state without going dead. Called when the health monitor detects
+// that a previously suspect node is responding to probes again.
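+//
+// Unlike HandleSuspectNode, no last-record safety gate is needed here:
+// re-enabling a record can only grow the namespace's active set.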
+func (cm *ClusterManager) HandleSuspectRecovery(ctx context.Context, nodeID string) {
+	cm.logger.Info("Handling suspect recovery — re-enabling DNS records",
+		zap.String("node_id", nodeID),
+	)
+
+	// Find all clusters this node belongs to
+	clusters, err := cm.getClustersByNodeID(ctx, nodeID)
+	if err != nil {
+		cm.logger.Warn("Failed to find clusters for recovered node",
+			zap.String("node_id", nodeID), zap.Error(err))
+		return
+	}
+
+	if len(clusters) == 0 {
+		return
+	}
+
+	// Get node's public IP (DNS A records contain public IPs)
+	ips, err := cm.getNodeIPs(ctx, nodeID)
+	if err != nil {
+		cm.logger.Warn("Failed to get recovered node IPs",
+			zap.String("node_id", nodeID), zap.Error(err))
+		return
+	}
+
+	dnsManager := NewDNSRecordManager(cm.db, cm.baseDomain, cm.logger)
+	enabledCount := 0
+
+	for _, cluster := range clusters {
+		if err := dnsManager.EnableNamespaceRecord(ctx, cluster.NamespaceName, ips.IPAddress); err != nil {
+			cm.logger.Warn("Failed to re-enable DNS record for recovered node",
+				zap.String("namespace", cluster.NamespaceName),
+				zap.String("ip", ips.IPAddress),
+				zap.Error(err))
+			continue
+		}
+
+		enabledCount++
+		cm.logger.Info("Re-enabled DNS record for recovered node",
+			zap.String("namespace", cluster.NamespaceName),
+			zap.String("ip", ips.IPAddress))
+	}
+
+	cm.logger.Info("Suspect recovery DNS handling completed",
+		zap.String("node_id", nodeID),
+		zap.Int("records_enabled", enabledCount))
+}
+
 // ReplaceClusterNode replaces a dead node in a specific namespace cluster.
 // It selects a new node, allocates ports, spawns services, updates DNS, and cleans up.
 func (cm *ClusterManager) ReplaceClusterNode(ctx context.Context, cluster *NamespaceCluster, deadNodeID string) error {
diff --git a/pkg/namespace/dns_manager.go b/pkg/namespace/dns_manager.go
index 43f0f9c..d14df8b 100644
--- a/pkg/namespace/dns_manager.go
+++ b/pkg/namespace/dns_manager.go
@@ -182,6 +182,34 @@ func (drm *DNSRecordManager) GetNamespaceGatewayIPs(ctx context.Context, namespa
 	return ips, nil
 }
 
+// CountActiveNamespaceRecords returns the number of active A records for a namespace's main FQDN.
+// Used as a safety check before disabling records to prevent disabling the last one.
+func (drm *DNSRecordManager) CountActiveNamespaceRecords(ctx context.Context, namespaceName string) (int, error) {
+	internalCtx := client.WithInternalAuth(ctx)
+
+	fqdn := fmt.Sprintf("ns-%s.%s.", namespaceName, drm.baseDomain)
+
+	type countResult struct {
+		Count int `db:"count"`
+	}
+
+	var results []countResult
+	query := `SELECT COUNT(*) as count FROM dns_records WHERE fqdn = ? AND record_type = 'A' AND is_active = TRUE`
+	err := drm.db.Query(internalCtx, &results, query, fqdn)
+	if err != nil {
+		return 0, &ClusterError{
+			Message: "failed to count active namespace DNS records",
+			Cause:   err,
+		}
+	}
+
+	if len(results) == 0 {
+		return 0, nil
+	}
+
+	return results[0].Count, nil
+}
+
 // AddNamespaceRecord adds DNS A records for a single IP to an existing namespace.
 // Unlike CreateNamespaceRecords, this does NOT delete existing records — it's purely additive.
 // Used when adding a new node to an under-provisioned cluster (repair).
@@ -239,7 +267,7 @@ func (drm *DNSRecordManager) UpdateNamespaceRecord(ctx context.Context, namespac
 
 	// Update both the main record and wildcard record
 	for _, f := range []string{fqdn, wildcardFqdn} {
-		updateQuery := `UPDATE dns_records SET value = ?, updated_at = ? WHERE fqdn = ? AND value = ?`
+		updateQuery := `UPDATE dns_records SET value = ?, is_active = TRUE, updated_at = ? WHERE fqdn = ? AND value = ?`
 		_, err := drm.db.Exec(internalCtx, updateQuery, newIP, time.Now(), f, oldIP)
 		if err != nil {
 			drm.logger.Warn("Failed to update DNS record",
diff --git a/pkg/namespace/dns_manager_test.go b/pkg/namespace/dns_manager_test.go
index 3fe682d..9b0d7e2 100644
--- a/pkg/namespace/dns_manager_test.go
+++ b/pkg/namespace/dns_manager_test.go
@@ -1,7 +1,9 @@
 package namespace
 
 import (
+	"context"
 	"fmt"
+	"strings"
 	"testing"
 
 	"go.uber.org/zap"
@@ -215,3 +217,60 @@ func TestDNSRecordManager_FQDNWithTrailingDot(t *testing.T) {
 		})
 	}
 }
+
+func TestUpdateNamespaceRecord_SetsActiveTrue(t *testing.T) {
+	mockDB := newMockRQLiteClient()
+	logger := zap.NewNop()
+	manager := NewDNSRecordManager(mockDB, "orama-devnet.network", logger)
+
+	ctx := context.Background()
+	err := manager.UpdateNamespaceRecord(ctx, "alice", "1.2.3.4", "5.6.7.8")
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	// Verify the SQL contains is_active = TRUE for both FQDN and wildcard
+	activeCount := 0
+	for _, call := range mockDB.execCalls {
+		if strings.Contains(call.Query, "is_active = TRUE") && strings.Contains(call.Query, "UPDATE dns_records") {
+			activeCount++
+		}
+	}
+	if activeCount != 2 {
+		t.Fatalf("expected 2 UPDATE queries with is_active = TRUE (fqdn + wildcard), got %d", activeCount)
+	}
+}
+
+func TestCountActiveNamespaceRecords(t *testing.T) {
+	mockDB := newMockRQLiteClient()
+	logger := zap.NewNop()
+	manager := NewDNSRecordManager(mockDB, "orama-devnet.network", logger)
+
+	ctx := context.Background()
+	count, err := manager.CountActiveNamespaceRecords(ctx, "alice")
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	// With mock returning empty results, count should be 0
+	if count != 0 {
+		t.Fatalf("expected 0, got %d", count)
+	}
+
+	// Verify the correct query was made
+	if len(mockDB.queryCalls) == 0 {
+		t.Fatal("expected a query call")
+	}
+	lastCall := mockDB.queryCalls[len(mockDB.queryCalls)-1]
+	if !strings.Contains(lastCall.Query, "COUNT(*)") || !strings.Contains(lastCall.Query, "is_active = TRUE") {
+		t.Fatalf("unexpected query: %s", lastCall.Query)
+	}
+	// Verify the FQDN arg
+	expectedFQDN := "ns-alice.orama-devnet.network."
+	if len(lastCall.Args) == 0 {
+		t.Fatal("expected query args")
+	}
+	if fqdn, ok := lastCall.Args[0].(string); !ok || fqdn != expectedFQDN {
+		t.Fatalf("expected FQDN arg %q, got %v", expectedFQDN, lastCall.Args[0])
+	}
+}
diff --git a/pkg/node/health/monitor.go b/pkg/node/health/monitor.go
index c756fa9..15a5329 100644
--- a/pkg/node/health/monitor.go
+++ b/pkg/node/health/monitor.go
@@ -83,7 +83,8 @@ type Monitor struct {
 	peers map[string]*peerState // nodeID → state
 
 	onDeadFn      func(nodeID string) // callback when quorum confirms death
-	onRecoveredFn func(nodeID string) // callback when node transitions from dead → healthy
+	onRecoveredFn func(nodeID string) // callback when node transitions from suspect/dead → healthy
+	onSuspectFn   func(nodeID string) // callback when node transitions healthy → suspect
 }
 
 // NewMonitor creates a new health monitor.
@@ -121,12 +122,19 @@ func (m *Monitor) OnNodeDead(fn func(nodeID string)) {
 	m.onDeadFn = fn
 }
 
-// OnNodeRecovered registers a callback invoked when a previously dead node
-// transitions back to healthy. The callback runs with the monitor lock released.
+// OnNodeRecovered registers a callback invoked when a previously suspect or dead
+// node transitions back to healthy. The callback runs with the monitor lock released.
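+// Callbacks are invoked synchronously from the monitor loop; long-running
+// work should be dispatched to a goroutine, as the gateway wiring does.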
 func (m *Monitor) OnNodeRecovered(fn func(nodeID string)) {
 	m.onRecoveredFn = fn
 }
 
+// OnNodeSuspect registers a callback invoked when a node transitions from
+// healthy to suspect (3 consecutive missed probes). The callback runs with
+// the monitor lock released.
+func (m *Monitor) OnNodeSuspect(fn func(nodeID string)) {
+	m.onSuspectFn = fn
+}
+
 // Start runs the monitor loop until ctx is cancelled.
 func (m *Monitor) Start(ctx context.Context) {
 	m.logger.Info("Starting node health monitor",
@@ -232,8 +240,8 @@ func (m *Monitor) updateState(ctx context.Context, nodeID string, healthy bool)
 	}
 
 	if healthy {
-		wasDead := ps.status == "dead"
-		shouldCallback := wasDead && m.onRecoveredFn != nil
+		wasUnhealthy := ps.status == "suspect" || ps.status == "dead"
+		shouldCallback := wasUnhealthy && m.onRecoveredFn != nil
 		prevStatus := ps.status
 
 		// Update state BEFORE releasing lock (C2 fix)
@@ -299,6 +307,7 @@ func (m *Monitor) updateState(ctx context.Context, nodeID string, healthy bool)
 	case ps.missCount >= DefaultSuspectAfter && ps.status == "healthy":
 		ps.status = "suspect"
 		ps.suspectAt = time.Now()
+		shouldCallSuspect := m.onSuspectFn != nil
 		m.mu.Unlock()
 
 		m.logger.Warn("Node SUSPECT",
 			zap.String("node_id", nodeID),
 			zap.Int("misses", ps.missCount),
 		)
 		m.writeEvent(ctx, nodeID, "suspect")
+
+		if shouldCallSuspect {
+			m.onSuspectFn(nodeID)
+		}
 		return
 	}
diff --git a/pkg/node/health/monitor_test.go b/pkg/node/health/monitor_test.go
index b09ffff..9cd1cf4 100644
--- a/pkg/node/health/monitor_test.go
+++ b/pkg/node/health/monitor_test.go
@@ -521,3 +521,114 @@ func TestRecoveryCallback_InvokedWithoutLock(t *testing.T) {
 	}
 	m.mu.Unlock()
 }
+
+// ---------------------------------------------------------------
+// OnNodeSuspect callback
+// ---------------------------------------------------------------
+
+func TestOnNodeSuspect_Callback(t *testing.T) {
+	m := NewMonitor(Config{
+		NodeID:             "self",
+		Neighbors:          3,
+		StartupGracePeriod: 1 * time.Millisecond,
+	})
+	time.Sleep(2 * time.Millisecond)
+
+	var suspectNode string
+	m.OnNodeSuspect(func(nodeID string) {
+		suspectNode = nodeID
+	})
+
+	ctx := context.Background()
+
+	// Drive 3 misses (DefaultSuspectAfter) → healthy → suspect
+	for i := 0; i < DefaultSuspectAfter; i++ {
+		m.updateState(ctx, "peer1", false)
+	}
+
+	if suspectNode != "peer1" {
+		t.Fatalf("expected suspect callback for peer1, got %q", suspectNode)
+	}
+
+	m.mu.Lock()
+	if m.peers["peer1"].status != "suspect" {
+		m.mu.Unlock()
+		t.Fatalf("expected suspect state, got %s", m.peers["peer1"].status)
+	}
+	m.mu.Unlock()
+}
+
+func TestOnNodeSuspect_DoesNotFireOnSubsequentMisses(t *testing.T) {
+	m := NewMonitor(Config{
+		NodeID:             "self",
+		Neighbors:          3,
+		StartupGracePeriod: 1 * time.Millisecond,
+	})
+	time.Sleep(2 * time.Millisecond)
+
+	var callCount int32
+	m.OnNodeSuspect(func(nodeID string) {
+		callCount++
+	})
+
+	ctx := context.Background()
+
+	// Drive to suspect (3 misses)
+	for i := 0; i < DefaultSuspectAfter; i++ {
+		m.updateState(ctx, "peer1", false)
+	}
+	if callCount != 1 {
+		t.Fatalf("expected suspect callback to fire once after 3 misses, got %d", callCount)
+	}
+
+	// Keep missing (4th through 11th miss) — should NOT fire suspect again
+	for i := 0; i < 8; i++ {
+		m.updateState(ctx, "peer1", false)
+	}
+
+	if callCount != 1 {
+		t.Fatalf("expected suspect callback to fire exactly once, got %d", callCount)
+	}
+}
+
+func TestRecoveredFromSuspect_Callback(t *testing.T) {
+	m := NewMonitor(Config{
+		NodeID:             "self",
"self", + Neighbors: 3, + StartupGracePeriod: 1 * time.Millisecond, + }) + time.Sleep(2 * time.Millisecond) + + var recoveredNode string + m.OnNodeRecovered(func(nodeID string) { + recoveredNode = nodeID + }) + + ctx := context.Background() + + // Drive to suspect (3 misses, NOT dead) + for i := 0; i < DefaultSuspectAfter; i++ { + m.updateState(ctx, "peer1", false) + } + + m.mu.Lock() + if m.peers["peer1"].status != "suspect" { + m.mu.Unlock() + t.Fatal("expected suspect state before recovery") + } + m.mu.Unlock() + + // Recover from suspect + m.updateState(ctx, "peer1", true) + + if recoveredNode != "peer1" { + t.Fatalf("expected recovery callback for peer1 after suspect, got %q", recoveredNode) + } + + m.mu.Lock() + if m.peers["peer1"].status != "healthy" { + m.mu.Unlock() + t.Fatal("expected healthy after recovery from suspect") + } + m.mu.Unlock() +}