From 749d5ed5e7ada48701409c50d952a7ac144f6530 Mon Sep 17 00:00:00 2001 From: anonpenguin23 Date: Sat, 14 Feb 2026 10:56:26 +0200 Subject: [PATCH] Bro i did so many things to fix the problematic discovery and redeployment and i dont even remember what i did --- pkg/cli/production/commands.go | 8 + pkg/client/client.go | 25 ++ pkg/client/config.go | 9 +- pkg/client/identity_test.go | 92 ++++++ pkg/config/database_config.go | 7 + pkg/discovery/discovery.go | 44 ++- pkg/discovery/metadata_publisher.go | 81 ++++++ pkg/discovery/rqlite_metadata.go | 93 +++++- pkg/discovery/rqlite_metadata_test.go | 235 +++++++++++++++ pkg/encryption/wallet_keygen.go | 194 +++++++++++++ pkg/encryption/wallet_keygen_test.go | 202 +++++++++++++ pkg/gateway/auth/jwt.go | 166 ++++++++--- pkg/gateway/auth/service.go | 23 +- pkg/gateway/auth/service_test.go | 252 ++++++++++++++++ pkg/gateway/dependencies.go | 13 +- pkg/gateway/namespace_health.go | 2 +- pkg/gateway/signing_key.go | 55 ++++ pkg/namespace/cluster_recovery.go | 36 ++- pkg/node/health/monitor.go | 136 +++++++-- pkg/node/health/monitor_test.go | 215 +++++++++++++- pkg/node/lifecycle/manager.go | 184 ++++++++++++ pkg/node/lifecycle/manager_test.go | 320 +++++++++++++++++++++ pkg/node/node.go | 31 +- pkg/node/rqlite.go | 1 + pkg/rqlite/cluster.go | 29 +- pkg/rqlite/cluster_discovery.go | 34 +++ pkg/rqlite/cluster_discovery_membership.go | 121 ++++++-- pkg/rqlite/cluster_discovery_queries.go | 67 ++++- pkg/rqlite/instance_spawner.go | 8 + pkg/rqlite/leadership.go | 131 +++++++++ pkg/rqlite/process.go | 47 +++ pkg/rqlite/rqlite.go | 17 +- pkg/rqlite/voter_reconciliation.go | 191 ++++++++++-- pkg/rqlite/watchdog.go | 42 ++- scripts/clean-testnet.sh | 208 ++++++++++++++ scripts/recover-rqlite.sh | 289 +++++++++++++++++++ scripts/redeploy.sh | 152 ++++++++-- 37 files changed, 3559 insertions(+), 201 deletions(-) create mode 100644 pkg/client/identity_test.go create mode 100644 pkg/discovery/metadata_publisher.go create mode 100644 
pkg/discovery/rqlite_metadata_test.go create mode 100644 pkg/encryption/wallet_keygen.go create mode 100644 pkg/encryption/wallet_keygen_test.go create mode 100644 pkg/node/lifecycle/manager.go create mode 100644 pkg/node/lifecycle/manager_test.go create mode 100644 pkg/rqlite/leadership.go create mode 100755 scripts/clean-testnet.sh create mode 100644 scripts/recover-rqlite.sh diff --git a/pkg/cli/production/commands.go b/pkg/cli/production/commands.go index 6230757..b961520 100644 --- a/pkg/cli/production/commands.go +++ b/pkg/cli/production/commands.go @@ -43,6 +43,10 @@ func HandleCommand(args []string) { case "restart": force := hasFlag(subargs, "--force") lifecycle.HandleRestartWithFlags(force) + case "pre-upgrade": + lifecycle.HandlePreUpgrade() + case "post-upgrade": + lifecycle.HandlePostUpgrade() case "logs": logs.Handle(subargs) case "uninstall": @@ -105,6 +109,10 @@ func ShowHelp() { fmt.Printf(" restart - Restart all production services (requires root/sudo)\n") fmt.Printf(" Options:\n") fmt.Printf(" --force - Bypass quorum safety check\n") + fmt.Printf(" pre-upgrade - Prepare node for safe restart (requires root/sudo)\n") + fmt.Printf(" Transfers leadership, enters maintenance mode, waits for propagation\n") + fmt.Printf(" post-upgrade - Bring node back online after restart (requires root/sudo)\n") + fmt.Printf(" Starts services, verifies RQLite health, exits maintenance\n") fmt.Printf(" logs - View production service logs\n") fmt.Printf(" Service aliases: node, ipfs, cluster, gateway, olric\n") fmt.Printf(" Options:\n") diff --git a/pkg/client/client.go b/pkg/client/client.go index cdc92cc..3710063 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -19,6 +19,7 @@ import ( libp2ppubsub "github.com/libp2p/go-libp2p-pubsub" + "github.com/DeBrosOfficial/network/pkg/encryption" "github.com/DeBrosOfficial/network/pkg/pubsub" ) @@ -144,6 +145,30 @@ func (c *Client) Connect() error { libp2p.DefaultMuxers, ) opts = append(opts, 
libp2p.Transport(tcp.NewTCPTransport)) + + // Load or create persistent identity if IdentityPath is configured + if c.config.IdentityPath != "" { + identity, loadErr := encryption.LoadIdentity(c.config.IdentityPath) + if loadErr != nil { + // File doesn't exist yet — generate and save + identity, loadErr = encryption.GenerateIdentity() + if loadErr != nil { + return fmt.Errorf("failed to generate identity: %w", loadErr) + } + if saveErr := encryption.SaveIdentity(identity, c.config.IdentityPath); saveErr != nil { + return fmt.Errorf("failed to save identity: %w", saveErr) + } + c.logger.Info("Generated new persistent identity", + zap.String("peer_id", identity.PeerID.String()), + zap.String("path", c.config.IdentityPath)) + } else { + c.logger.Info("Loaded persistent identity", + zap.String("peer_id", identity.PeerID.String()), + zap.String("path", c.config.IdentityPath)) + } + opts = append(opts, libp2p.Identity(identity.PrivateKey)) + } + // Enable QUIC only when not proxying. When proxy is enabled, prefer TCP via SOCKS5. h, err := libp2p.New(opts...) 
if err != nil { diff --git a/pkg/client/config.go b/pkg/client/config.go index 48c36ae..acbbb44 100644 --- a/pkg/client/config.go +++ b/pkg/client/config.go @@ -11,13 +11,14 @@ type ClientConfig struct { DatabaseName string `json:"database_name"` BootstrapPeers []string `json:"peers"` DatabaseEndpoints []string `json:"database_endpoints"` - GatewayURL string `json:"gateway_url"` // Gateway URL for HTTP API access + GatewayURL string `json:"gateway_url"` // Gateway URL for HTTP API access ConnectTimeout time.Duration `json:"connect_timeout"` RetryAttempts int `json:"retry_attempts"` RetryDelay time.Duration `json:"retry_delay"` - QuietMode bool `json:"quiet_mode"` // Suppress debug/info logs - APIKey string `json:"api_key"` // API key for gateway auth - JWT string `json:"jwt"` // Optional JWT bearer token + QuietMode bool `json:"quiet_mode"` // Suppress debug/info logs + APIKey string `json:"api_key"` // API key for gateway auth + JWT string `json:"jwt"` // Optional JWT bearer token + IdentityPath string `json:"identity_path"` // Path to persistent LibP2P identity key file } // DefaultClientConfig returns a default client configuration diff --git a/pkg/client/identity_test.go b/pkg/client/identity_test.go new file mode 100644 index 0000000..e00789b --- /dev/null +++ b/pkg/client/identity_test.go @@ -0,0 +1,92 @@ +package client + +import ( + "os" + "path/filepath" + "testing" + + "github.com/DeBrosOfficial/network/pkg/encryption" +) + +func TestPersistentIdentity_NoPath(t *testing.T) { + // Without IdentityPath, Connect() generates a random ID each time. + // We can't easily test Connect() (needs network), so verify config defaults. + cfg := DefaultClientConfig("test-app") + if cfg.IdentityPath != "" { + t.Fatalf("expected empty IdentityPath by default, got %q", cfg.IdentityPath) + } +} + +func TestPersistentIdentity_GenerateAndReload(t *testing.T) { + dir := t.TempDir() + keyPath := filepath.Join(dir, "identity.key") + + // 1. 
No file exists — generate + save + id1, err := encryption.GenerateIdentity() + if err != nil { + t.Fatalf("GenerateIdentity: %v", err) + } + if err := encryption.SaveIdentity(id1, keyPath); err != nil { + t.Fatalf("SaveIdentity: %v", err) + } + + // File should exist + if _, err := os.Stat(keyPath); os.IsNotExist(err) { + t.Fatal("identity key file was not created") + } + + // 2. Load it back — same PeerID + id2, err := encryption.LoadIdentity(keyPath) + if err != nil { + t.Fatalf("LoadIdentity: %v", err) + } + if id1.PeerID != id2.PeerID { + t.Fatalf("PeerID mismatch: generated %s, loaded %s", id1.PeerID, id2.PeerID) + } + + // 3. Load again — still the same + id3, err := encryption.LoadIdentity(keyPath) + if err != nil { + t.Fatalf("LoadIdentity (second): %v", err) + } + if id2.PeerID != id3.PeerID { + t.Fatalf("PeerID changed across loads: %s vs %s", id2.PeerID, id3.PeerID) + } +} + +func TestPersistentIdentity_DifferentFromRandom(t *testing.T) { + // A persistent identity should be different from a freshly generated one + id1, err := encryption.GenerateIdentity() + if err != nil { + t.Fatalf("GenerateIdentity: %v", err) + } + id2, err := encryption.GenerateIdentity() + if err != nil { + t.Fatalf("GenerateIdentity: %v", err) + } + if id1.PeerID == id2.PeerID { + t.Fatal("two independently generated identities should have different PeerIDs") + } +} + +func TestPersistentIdentity_FilePermissions(t *testing.T) { + dir := t.TempDir() + keyPath := filepath.Join(dir, "subdir", "identity.key") + + id, err := encryption.GenerateIdentity() + if err != nil { + t.Fatalf("GenerateIdentity: %v", err) + } + if err := encryption.SaveIdentity(id, keyPath); err != nil { + t.Fatalf("SaveIdentity: %v", err) + } + + info, err := os.Stat(keyPath) + if err != nil { + t.Fatalf("Stat: %v", err) + } + perm := info.Mode().Perm() + if perm != 0600 { + t.Fatalf("expected file permissions 0600, got %o", perm) + } +} diff --git a/pkg/config/database_config.go b/pkg/config/database_config.go 
index 3898503..c8ea6e6 100644 --- a/pkg/config/database_config.go +++ b/pkg/config/database_config.go @@ -22,6 +22,13 @@ type DatabaseConfig struct { NodeCACert string `yaml:"node_ca_cert"` // Path to CA certificate (optional, uses system CA if not set) NodeNoVerify bool `yaml:"node_no_verify"` // Skip certificate verification (for testing/self-signed certs) + // Raft tuning (passed through to rqlited CLI flags). + // Higher defaults than rqlited's 1s suit WireGuard latency. + RaftElectionTimeout time.Duration `yaml:"raft_election_timeout"` // default: 5s + RaftHeartbeatTimeout time.Duration `yaml:"raft_heartbeat_timeout"` // default: 2s + RaftApplyTimeout time.Duration `yaml:"raft_apply_timeout"` // default: 30s + RaftLeaderLeaseTimeout time.Duration `yaml:"raft_leader_lease_timeout"` // default: 5s + // Dynamic discovery configuration (always enabled) ClusterSyncInterval time.Duration `yaml:"cluster_sync_interval"` // default: 30s PeerInactivityLimit time.Duration `yaml:"peer_inactivity_limit"` // default: 24h diff --git a/pkg/discovery/discovery.go b/pkg/discovery/discovery.go index e3e06b3..4eb4fa4 100644 --- a/pkg/discovery/discovery.go +++ b/pkg/discovery/discovery.go @@ -7,6 +7,7 @@ import ( "io" "strconv" "strings" + "sync" "time" "github.com/libp2p/go-libp2p/core/host" @@ -77,10 +78,14 @@ type PeerInfo struct { // interface{} to remain source-compatible with previous call sites that // passed a DHT instance. The value is ignored. type Manager struct { - host host.Host - logger *zap.Logger - cancel context.CancelFunc - failedPeerExchanges map[peer.ID]time.Time // Track failed peer exchange attempts to suppress repeated warnings + host host.Host + logger *zap.Logger + cancel context.CancelFunc + + // failedMu protects failedPeerExchanges from concurrent access during + // parallel peer exchange dials (H3 fix). 
+ failedMu sync.Mutex + failedPeerExchanges map[peer.ID]time.Time } // Config contains discovery configuration @@ -364,8 +369,8 @@ func (d *Manager) discoverViaPeerExchange(ctx context.Context, maxConnections in // Add to peerstore (only valid addresses with port 4001) d.host.Peerstore().AddAddrs(parsedID, addrs, time.Hour*24) - // Try to connect - connectCtx, cancel := context.WithTimeout(ctx, 20*time.Second) + // Try to connect (5s timeout — WireGuard peers respond fast) + connectCtx, cancel := context.WithTimeout(ctx, 5*time.Second) peerAddrInfo := peer.AddrInfo{ID: parsedID, Addrs: addrs} if err := d.host.Connect(connectCtx, peerAddrInfo); err != nil { @@ -401,15 +406,15 @@ func (d *Manager) requestPeersFromPeer(ctx context.Context, peerID peer.ID, limi // Open a stream to the peer stream, err := d.host.NewStream(ctx, peerID, PeerExchangeProtocol) if err != nil { - // Check if this is a "protocols not supported" error (expected for lightweight clients like gateway) + d.failedMu.Lock() if strings.Contains(err.Error(), "protocols not supported") { - // This is a lightweight client (gateway, etc.) that doesn't support peer exchange - expected behavior - // Track it to avoid repeated attempts, but don't log as it's not an error + // Lightweight client (gateway, etc.) 
— expected, track to suppress retries d.failedPeerExchanges[peerID] = time.Now() + d.failedMu.Unlock() return nil } - // For actual connection errors, log but suppress repeated warnings for the same peer + // Actual connection error — log but suppress repeated warnings lastFailure, seen := d.failedPeerExchanges[peerID] if !seen || time.Since(lastFailure) > time.Minute { d.logger.Debug("Failed to open peer exchange stream with node", @@ -418,12 +423,15 @@ func (d *Manager) requestPeersFromPeer(ctx context.Context, peerID peer.ID, limi zap.Error(err)) d.failedPeerExchanges[peerID] = time.Now() } + d.failedMu.Unlock() return nil } defer stream.Close() // Clear failure tracking on success + d.failedMu.Lock() delete(d.failedPeerExchanges, peerID) + d.failedMu.Unlock() // Send request req := PeerExchangeRequest{Limit: limit} @@ -433,8 +441,8 @@ func (d *Manager) requestPeersFromPeer(ctx context.Context, peerID peer.ID, limi return nil } - // Set read deadline - if err := stream.SetReadDeadline(time.Now().Add(10 * time.Second)); err != nil { + // Set read deadline (5s — small JSON payload) + if err := stream.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { d.logger.Debug("Failed to set read deadline", zap.Error(err)) return nil } @@ -451,10 +459,20 @@ func (d *Manager) requestPeersFromPeer(ctx context.Context, peerID peer.ID, limi // Store remote peer's RQLite metadata if available if resp.RQLiteMetadata != nil { + // Verify sender identity — prevent metadata spoofing (H2 fix). + // If the metadata contains a PeerID, it must match the stream sender. 
+ if resp.RQLiteMetadata.PeerID != "" && resp.RQLiteMetadata.PeerID != peerID.String() { + d.logger.Warn("Rejected metadata: PeerID mismatch", + zap.String("claimed", resp.RQLiteMetadata.PeerID[:8]+"..."), + zap.String("actual", peerID.String()[:8]+"...")) + return resp.Peers + } + // Stamp verified PeerID so downstream consumers can trust it + resp.RQLiteMetadata.PeerID = peerID.String() + metadataJSON, err := json.Marshal(resp.RQLiteMetadata) if err == nil { _ = d.host.Peerstore().Put(peerID, "rqlite_metadata", metadataJSON) - // Only log when new metadata is stored (useful for debugging) d.logger.Debug("Metadata stored", zap.String("peer", peerID.String()[:8]+"..."), zap.String("node", resp.RQLiteMetadata.NodeID)) diff --git a/pkg/discovery/metadata_publisher.go b/pkg/discovery/metadata_publisher.go new file mode 100644 index 0000000..9f8eaf6 --- /dev/null +++ b/pkg/discovery/metadata_publisher.go @@ -0,0 +1,81 @@ +package discovery + +import ( + "context" + "encoding/json" + "time" + + "github.com/libp2p/go-libp2p/core/host" + "go.uber.org/zap" +) + +// MetadataProvider is implemented by subsystems that can supply node metadata. +// The publisher calls Provide() every cycle and stores the result in the peerstore. +type MetadataProvider interface { + ProvideMetadata() *RQLiteNodeMetadata +} + +// MetadataPublisher periodically writes local node metadata to the peerstore so +// it is included in every peer exchange response. This decouples metadata +// production (lifecycle, RQLite status, service health) from the exchange +// protocol itself. +type MetadataPublisher struct { + host host.Host + provider MetadataProvider + interval time.Duration + logger *zap.Logger +} + +// NewMetadataPublisher creates a publisher that writes metadata every interval. 
+func NewMetadataPublisher(h host.Host, provider MetadataProvider, interval time.Duration, logger *zap.Logger) *MetadataPublisher { + if interval <= 0 { + interval = 10 * time.Second + } + return &MetadataPublisher{ + host: h, + provider: provider, + interval: interval, + logger: logger.With(zap.String("component", "metadata-publisher")), + } +} + +// Start begins the periodic publish loop. It blocks until ctx is cancelled. +func (p *MetadataPublisher) Start(ctx context.Context) { + // Publish immediately on start + p.publish() + + ticker := time.NewTicker(p.interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + p.publish() + } + } +} + +// PublishNow performs a single immediate metadata publish. +// Useful after lifecycle transitions or other state changes. +func (p *MetadataPublisher) PublishNow() { + p.publish() +} + +func (p *MetadataPublisher) publish() { + meta := p.provider.ProvideMetadata() + if meta == nil { + return + } + + data, err := json.Marshal(meta) + if err != nil { + p.logger.Error("Failed to marshal metadata", zap.Error(err)) + return + } + + if err := p.host.Peerstore().Put(p.host.ID(), "rqlite_metadata", data); err != nil { + p.logger.Error("Failed to store metadata in peerstore", zap.Error(err)) + } +} diff --git a/pkg/discovery/rqlite_metadata.go b/pkg/discovery/rqlite_metadata.go index e70d263..0f4d3bd 100644 --- a/pkg/discovery/rqlite_metadata.go +++ b/pkg/discovery/rqlite_metadata.go @@ -4,18 +4,101 @@ import ( "time" ) -// RQLiteNodeMetadata contains RQLite-specific information announced via LibP2P +// ServiceStatus represents the health of an individual service on a node. +type ServiceStatus struct { + Name string `json:"name"` // e.g. "rqlite", "gateway", "olric" + Running bool `json:"running"` // whether the process is up + Healthy bool `json:"healthy"` // whether it passed its health check + Message string `json:"message,omitempty"` // optional detail ("leader", "follower", etc.) 
+} + +// NamespaceStatus represents a namespace's status on a node. +type NamespaceStatus struct { + Name string `json:"name"` + Status string `json:"status"` // "healthy", "degraded", "recovering" +} + +// RQLiteNodeMetadata contains node information announced via LibP2P peer exchange. +// This struct is the single source of truth for node metadata propagated through +// the cluster. Go's json.Unmarshal silently ignores unknown fields, so old nodes +// reading metadata from new nodes simply skip the new fields — no protocol +// version change is needed. type RQLiteNodeMetadata struct { - NodeID string `json:"node_id"` // RQLite node ID (from config) - RaftAddress string `json:"raft_address"` // Raft port address (e.g., "51.83.128.181:7001") - HTTPAddress string `json:"http_address"` // HTTP API address (e.g., "51.83.128.181:5001") + // --- Existing fields (unchanged) --- + + NodeID string `json:"node_id"` // RQLite node ID (raft address) + RaftAddress string `json:"raft_address"` // Raft port address (e.g., "10.0.0.1:7001") + HTTPAddress string `json:"http_address"` // HTTP API address (e.g., "10.0.0.1:5001") NodeType string `json:"node_type"` // Node type identifier RaftLogIndex uint64 `json:"raft_log_index"` // Current Raft log index (for data comparison) LastSeen time.Time `json:"last_seen"` // Updated on every announcement ClusterVersion string `json:"cluster_version"` // For compatibility checking + + // --- New: Identity --- + + // PeerID is the LibP2P peer ID of the node. Used for metadata authentication: + // on receipt, the receiver verifies PeerID == stream sender to prevent spoofing. + PeerID string `json:"peer_id,omitempty"` + + // WireGuardIP is the node's WireGuard VPN address (e.g., "10.0.0.1"). + WireGuardIP string `json:"wireguard_ip,omitempty"` + + // --- New: Lifecycle --- + + // LifecycleState is the node's current lifecycle state: + // "joining", "active", "draining", or "maintenance". 
+ // Zero value (empty string) from old nodes is treated as "active". + LifecycleState string `json:"lifecycle_state,omitempty"` + + // MaintenanceTTL is the time at which maintenance mode expires. + // Only meaningful when LifecycleState == "maintenance". + MaintenanceTTL time.Time `json:"maintenance_ttl,omitempty"` + + // --- New: Services --- + + // Services reports the status of each service running on the node. + Services map[string]*ServiceStatus `json:"services,omitempty"` + + // Namespaces reports the status of each namespace on the node. + Namespaces map[string]*NamespaceStatus `json:"namespaces,omitempty"` + + // --- New: Version --- + + // BinaryVersion is the node's binary version string (e.g., "1.2.3"). + BinaryVersion string `json:"binary_version,omitempty"` } -// PeerExchangeResponseV2 extends the original response with RQLite metadata +// EffectiveLifecycleState returns the lifecycle state, defaulting to "active" +// for old nodes that don't populate the field. +func (m *RQLiteNodeMetadata) EffectiveLifecycleState() string { + if m.LifecycleState == "" { + return "active" + } + return m.LifecycleState +} + +// IsInMaintenance returns true if the node has announced maintenance mode. +func (m *RQLiteNodeMetadata) IsInMaintenance() bool { + return m.EffectiveLifecycleState() == "maintenance" +} + +// IsAvailable returns true if the node is in a state that can serve requests. +func (m *RQLiteNodeMetadata) IsAvailable() bool { + return m.EffectiveLifecycleState() == "active" +} + +// IsMaintenanceExpired returns true if the node is in maintenance and the +// TTL has passed. Used by the leader's health monitor to enforce expiry. +func (m *RQLiteNodeMetadata) IsMaintenanceExpired() bool { + if !m.IsInMaintenance() { + return false + } + return !m.MaintenanceTTL.IsZero() && time.Now().After(m.MaintenanceTTL) +} + +// PeerExchangeResponseV2 extends the original response with RQLite metadata. 
+// Kept for backward compatibility — the V1 PeerExchangeResponse in discovery.go +// already includes the same RQLiteMetadata field, so this is effectively unused. type PeerExchangeResponseV2 struct { Peers []PeerInfo `json:"peers"` RQLiteMetadata *RQLiteNodeMetadata `json:"rqlite_metadata,omitempty"` diff --git a/pkg/discovery/rqlite_metadata_test.go b/pkg/discovery/rqlite_metadata_test.go new file mode 100644 index 0000000..13f5e75 --- /dev/null +++ b/pkg/discovery/rqlite_metadata_test.go @@ -0,0 +1,235 @@ +package discovery + +import ( + "encoding/json" + "testing" + "time" +) + +func TestEffectiveLifecycleState(t *testing.T) { + tests := []struct { + name string + state string + want string + }{ + {"empty defaults to active", "", "active"}, + {"explicit active", "active", "active"}, + {"joining", "joining", "joining"}, + {"maintenance", "maintenance", "maintenance"}, + {"draining", "draining", "draining"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := &RQLiteNodeMetadata{LifecycleState: tt.state} + if got := m.EffectiveLifecycleState(); got != tt.want { + t.Fatalf("got %q, want %q", got, tt.want) + } + }) + } +} + +func TestIsInMaintenance(t *testing.T) { + m := &RQLiteNodeMetadata{LifecycleState: "maintenance"} + if !m.IsInMaintenance() { + t.Fatal("expected maintenance") + } + + m.LifecycleState = "active" + if m.IsInMaintenance() { + t.Fatal("expected not maintenance") + } + + // Empty state (old node) should not be maintenance + m.LifecycleState = "" + if m.IsInMaintenance() { + t.Fatal("empty state should not be maintenance") + } +} + +func TestIsAvailable(t *testing.T) { + m := &RQLiteNodeMetadata{LifecycleState: "active"} + if !m.IsAvailable() { + t.Fatal("expected available") + } + + // Empty state (old node) defaults to active → available + m.LifecycleState = "" + if !m.IsAvailable() { + t.Fatal("empty state should be available (backward compat)") + } + + m.LifecycleState = "maintenance" + if m.IsAvailable() { + 
t.Fatal("maintenance should not be available") + } +} + +func TestIsMaintenanceExpired(t *testing.T) { + // Expired + m := &RQLiteNodeMetadata{ + LifecycleState: "maintenance", + MaintenanceTTL: time.Now().Add(-1 * time.Minute), + } + if !m.IsMaintenanceExpired() { + t.Fatal("expected expired") + } + + // Not expired + m.MaintenanceTTL = time.Now().Add(5 * time.Minute) + if m.IsMaintenanceExpired() { + t.Fatal("expected not expired") + } + + // Zero TTL in maintenance + m.MaintenanceTTL = time.Time{} + if m.IsMaintenanceExpired() { + t.Fatal("zero TTL should not be considered expired") + } + + // Not in maintenance + m.LifecycleState = "active" + m.MaintenanceTTL = time.Now().Add(-1 * time.Minute) + if m.IsMaintenanceExpired() { + t.Fatal("active state should not report expired") + } +} + +// TestBackwardCompatibility verifies that old metadata (without new fields) +// unmarshals correctly — new fields get zero values, helpers return sane defaults. +func TestBackwardCompatibility(t *testing.T) { + oldJSON := `{ + "node_id": "10.0.0.1:7001", + "raft_address": "10.0.0.1:7001", + "http_address": "10.0.0.1:5001", + "node_type": "node", + "raft_log_index": 42, + "cluster_version": "1.0" + }` + + var m RQLiteNodeMetadata + if err := json.Unmarshal([]byte(oldJSON), &m); err != nil { + t.Fatalf("unmarshal old metadata: %v", err) + } + + // Existing fields preserved + if m.NodeID != "10.0.0.1:7001" { + t.Fatalf("expected node_id 10.0.0.1:7001, got %s", m.NodeID) + } + if m.RaftLogIndex != 42 { + t.Fatalf("expected raft_log_index 42, got %d", m.RaftLogIndex) + } + + // New fields default to zero values + if m.PeerID != "" { + t.Fatalf("expected empty PeerID, got %q", m.PeerID) + } + if m.LifecycleState != "" { + t.Fatalf("expected empty LifecycleState, got %q", m.LifecycleState) + } + if m.Services != nil { + t.Fatal("expected nil Services") + } + + // Helpers return correct defaults + if m.EffectiveLifecycleState() != "active" { + t.Fatalf("expected effective state 
'active', got %q", m.EffectiveLifecycleState()) + } + if !m.IsAvailable() { + t.Fatal("old metadata should be available") + } + if m.IsInMaintenance() { + t.Fatal("old metadata should not be in maintenance") + } +} + +// TestNewFieldsRoundTrip verifies that new fields marshal/unmarshal correctly. +func TestNewFieldsRoundTrip(t *testing.T) { + original := &RQLiteNodeMetadata{ + NodeID: "10.0.0.1:7001", + RaftAddress: "10.0.0.1:7001", + HTTPAddress: "10.0.0.1:5001", + NodeType: "node", + RaftLogIndex: 100, + ClusterVersion: "1.0", + PeerID: "QmPeerID123", + WireGuardIP: "10.0.0.1", + LifecycleState: "maintenance", + MaintenanceTTL: time.Now().Add(10 * time.Minute).Truncate(time.Millisecond), + BinaryVersion: "1.2.3", + Services: map[string]*ServiceStatus{ + "rqlite": {Name: "rqlite", Running: true, Healthy: true, Message: "leader"}, + }, + Namespaces: map[string]*NamespaceStatus{ + "myapp": {Name: "myapp", Status: "healthy"}, + }, + } + + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + var decoded RQLiteNodeMetadata + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if decoded.PeerID != original.PeerID { + t.Fatalf("PeerID: got %q, want %q", decoded.PeerID, original.PeerID) + } + if decoded.WireGuardIP != original.WireGuardIP { + t.Fatalf("WireGuardIP: got %q, want %q", decoded.WireGuardIP, original.WireGuardIP) + } + if decoded.LifecycleState != original.LifecycleState { + t.Fatalf("LifecycleState: got %q, want %q", decoded.LifecycleState, original.LifecycleState) + } + if decoded.BinaryVersion != original.BinaryVersion { + t.Fatalf("BinaryVersion: got %q, want %q", decoded.BinaryVersion, original.BinaryVersion) + } + if decoded.Services["rqlite"] == nil || !decoded.Services["rqlite"].Running { + t.Fatal("expected rqlite service to be running") + } + if decoded.Namespaces["myapp"] == nil || decoded.Namespaces["myapp"].Status != "healthy" { + t.Fatal("expected myapp namespace 
to be healthy") + } +} + +// TestOldNodeReadsNewMetadata simulates an old node (that doesn't know about new fields) +// reading metadata from a new node. Go's JSON unmarshalling silently ignores unknown fields. +func TestOldNodeReadsNewMetadata(t *testing.T) { + newJSON := `{ + "node_id": "10.0.0.1:7001", + "raft_address": "10.0.0.1:7001", + "http_address": "10.0.0.1:5001", + "node_type": "node", + "raft_log_index": 42, + "cluster_version": "1.0", + "peer_id": "QmSomePeerID", + "wireguard_ip": "10.0.0.1", + "lifecycle_state": "maintenance", + "maintenance_ttl": "2025-01-01T00:00:00Z", + "binary_version": "2.0.0", + "services": {"rqlite": {"name": "rqlite", "running": true, "healthy": true}}, + "namespaces": {"app": {"name": "app", "status": "healthy"}}, + "some_future_field": "unknown" + }` + + // Simulate "old" struct with only original fields + type OldMetadata struct { + NodeID string `json:"node_id"` + RaftAddress string `json:"raft_address"` + HTTPAddress string `json:"http_address"` + NodeType string `json:"node_type"` + RaftLogIndex uint64 `json:"raft_log_index"` + ClusterVersion string `json:"cluster_version"` + } + + var old OldMetadata + if err := json.Unmarshal([]byte(newJSON), &old); err != nil { + t.Fatalf("old node should unmarshal new metadata without error: %v", err) + } + + if old.NodeID != "10.0.0.1:7001" || old.RaftLogIndex != 42 { + t.Fatal("old fields should be preserved") + } +} diff --git a/pkg/encryption/wallet_keygen.go b/pkg/encryption/wallet_keygen.go new file mode 100644 index 0000000..d65a182 --- /dev/null +++ b/pkg/encryption/wallet_keygen.go @@ -0,0 +1,194 @@ +package encryption + +import ( + "crypto/ed25519" + "crypto/sha256" + "fmt" + "io" + "os" + "os/exec" + "strings" + + "golang.org/x/crypto/curve25519" + "golang.org/x/crypto/hkdf" +) + +// NodeKeys holds all cryptographic keys derived from a wallet's master key. 
+type NodeKeys struct { + LibP2PPrivateKey ed25519.PrivateKey // Ed25519 for LibP2P identity + LibP2PPublicKey ed25519.PublicKey + WireGuardKey [32]byte // Curve25519 private key (clamped) + WireGuardPubKey [32]byte // Curve25519 public key + IPFSPrivateKey ed25519.PrivateKey + IPFSPublicKey ed25519.PublicKey + ClusterPrivateKey ed25519.PrivateKey // IPFS Cluster identity + ClusterPublicKey ed25519.PublicKey + JWTPrivateKey ed25519.PrivateKey // EdDSA JWT signing key + JWTPublicKey ed25519.PublicKey +} + +// DeriveNodeKeysFromWallet calls `rw derive` to get a master key from the user's +// Root Wallet, then expands it into all node keys. The wallet's private key never +// leaves the `rw` process. +// +// vpsIP is used as the HKDF info parameter, so each VPS gets unique keys from the +// same wallet. Stdin is passed through so rw can prompt for the wallet password. +func DeriveNodeKeysFromWallet(vpsIP string) (*NodeKeys, error) { + if vpsIP == "" { + return nil, fmt.Errorf("VPS IP is required for key derivation") + } + + // Check rw is installed + if _, err := exec.LookPath("rw"); err != nil { + return nil, fmt.Errorf("Root Wallet (rw) not found in PATH — install it first") + } + + // Call rw derive to get master key bytes + cmd := exec.Command("rw", "derive", "--salt", "orama-node", "--info", vpsIP) + cmd.Stdin = os.Stdin // pass through for password prompts + cmd.Stderr = os.Stderr // rw UI messages go to terminal + out, err := cmd.Output() + if err != nil { + return nil, fmt.Errorf("rw derive failed: %w", err) + } + + masterHex := strings.TrimSpace(string(out)) + if len(masterHex) != 64 { // 32 bytes = 64 hex chars + return nil, fmt.Errorf("rw derive returned unexpected output length: %d (expected 64 hex chars)", len(masterHex)) + } + + masterKey, err := hexToBytes(masterHex) + if err != nil { + return nil, fmt.Errorf("rw derive returned invalid hex: %w", err) + } + defer zeroBytes(masterKey) + + return ExpandNodeKeys(masterKey) +} + +// ExpandNodeKeys expands a 
32-byte master key into all node keys using HKDF-SHA256. +// The master key should come from `rw derive --salt "orama-node" --info ""`. +// +// Each key type uses a different HKDF info string under the salt "orama-expand", +// ensuring cryptographic independence between key types. +func ExpandNodeKeys(masterKey []byte) (*NodeKeys, error) { + if len(masterKey) != 32 { + return nil, fmt.Errorf("master key must be 32 bytes, got %d", len(masterKey)) + } + + salt := []byte("orama-expand") + keys := &NodeKeys{} + + // Derive LibP2P Ed25519 key + seed, err := deriveBytes(masterKey, salt, []byte("libp2p-identity"), ed25519.SeedSize) + if err != nil { + return nil, fmt.Errorf("failed to derive libp2p key: %w", err) + } + priv := ed25519.NewKeyFromSeed(seed) + zeroBytes(seed) + keys.LibP2PPrivateKey = priv + keys.LibP2PPublicKey = priv.Public().(ed25519.PublicKey) + + // Derive WireGuard Curve25519 key + wgSeed, err := deriveBytes(masterKey, salt, []byte("wireguard-key"), 32) + if err != nil { + return nil, fmt.Errorf("failed to derive wireguard key: %w", err) + } + copy(keys.WireGuardKey[:], wgSeed) + zeroBytes(wgSeed) + clampCurve25519Key(&keys.WireGuardKey) + pubKey, err := curve25519.X25519(keys.WireGuardKey[:], curve25519.Basepoint) + if err != nil { + return nil, fmt.Errorf("failed to compute wireguard public key: %w", err) + } + copy(keys.WireGuardPubKey[:], pubKey) + + // Derive IPFS Ed25519 key + seed, err = deriveBytes(masterKey, salt, []byte("ipfs-identity"), ed25519.SeedSize) + if err != nil { + return nil, fmt.Errorf("failed to derive ipfs key: %w", err) + } + priv = ed25519.NewKeyFromSeed(seed) + zeroBytes(seed) + keys.IPFSPrivateKey = priv + keys.IPFSPublicKey = priv.Public().(ed25519.PublicKey) + + // Derive IPFS Cluster Ed25519 key + seed, err = deriveBytes(masterKey, salt, []byte("ipfs-cluster"), ed25519.SeedSize) + if err != nil { + return nil, fmt.Errorf("failed to derive cluster key: %w", err) + } + priv = ed25519.NewKeyFromSeed(seed) + zeroBytes(seed) + 
keys.ClusterPrivateKey = priv + keys.ClusterPublicKey = priv.Public().(ed25519.PublicKey) + + // Derive JWT EdDSA signing key + seed, err = deriveBytes(masterKey, salt, []byte("jwt-signing"), ed25519.SeedSize) + if err != nil { + return nil, fmt.Errorf("failed to derive jwt key: %w", err) + } + priv = ed25519.NewKeyFromSeed(seed) + zeroBytes(seed) + keys.JWTPrivateKey = priv + keys.JWTPublicKey = priv.Public().(ed25519.PublicKey) + + return keys, nil +} + +// deriveBytes uses HKDF-SHA256 to derive n bytes from the given IKM, salt, and info. +func deriveBytes(ikm, salt, info []byte, n int) ([]byte, error) { + hkdfReader := hkdf.New(sha256.New, ikm, salt, info) + out := make([]byte, n) + if _, err := io.ReadFull(hkdfReader, out); err != nil { + return nil, err + } + return out, nil +} + +// clampCurve25519Key applies the standard Curve25519 clamping to a private key. +func clampCurve25519Key(key *[32]byte) { + key[0] &= 248 + key[31] &= 127 + key[31] |= 64 +} + +// hexToBytes decodes a hex string to bytes. +func hexToBytes(hex string) ([]byte, error) { + if len(hex)%2 != 0 { + return nil, fmt.Errorf("odd-length hex string") + } + b := make([]byte, len(hex)/2) + for i := 0; i < len(hex); i += 2 { + var hi, lo byte + var err error + if hi, err = hexCharToByte(hex[i]); err != nil { + return nil, err + } + if lo, err = hexCharToByte(hex[i+1]); err != nil { + return nil, err + } + b[i/2] = hi<<4 | lo + } + return b, nil +} + +func hexCharToByte(c byte) (byte, error) { + switch { + case c >= '0' && c <= '9': + return c - '0', nil + case c >= 'a' && c <= 'f': + return c - 'a' + 10, nil + case c >= 'A' && c <= 'F': + return c - 'A' + 10, nil + default: + return 0, fmt.Errorf("invalid hex character: %c", c) + } +} + +// zeroBytes zeroes a byte slice to clear sensitive data from memory. 
+func zeroBytes(b []byte) { + for i := range b { + b[i] = 0 + } +} diff --git a/pkg/encryption/wallet_keygen_test.go b/pkg/encryption/wallet_keygen_test.go new file mode 100644 index 0000000..d06cd86 --- /dev/null +++ b/pkg/encryption/wallet_keygen_test.go @@ -0,0 +1,202 @@ +package encryption + +import ( + "bytes" + "crypto/ed25519" + "testing" +) + +// testMasterKey is a deterministic 32-byte key for testing ExpandNodeKeys. +// In production, this comes from `rw derive --salt "orama-node" --info ""`. +var testMasterKey = bytes.Repeat([]byte{0xab}, 32) +var testMasterKey2 = bytes.Repeat([]byte{0xcd}, 32) + +func TestExpandNodeKeys_Determinism(t *testing.T) { + keys1, err := ExpandNodeKeys(testMasterKey) + if err != nil { + t.Fatalf("ExpandNodeKeys: %v", err) + } + keys2, err := ExpandNodeKeys(testMasterKey) + if err != nil { + t.Fatalf("ExpandNodeKeys (second): %v", err) + } + + if !bytes.Equal(keys1.LibP2PPrivateKey, keys2.LibP2PPrivateKey) { + t.Error("LibP2P private keys differ for same input") + } + if !bytes.Equal(keys1.WireGuardKey[:], keys2.WireGuardKey[:]) { + t.Error("WireGuard keys differ for same input") + } + if !bytes.Equal(keys1.IPFSPrivateKey, keys2.IPFSPrivateKey) { + t.Error("IPFS private keys differ for same input") + } + if !bytes.Equal(keys1.ClusterPrivateKey, keys2.ClusterPrivateKey) { + t.Error("Cluster private keys differ for same input") + } + if !bytes.Equal(keys1.JWTPrivateKey, keys2.JWTPrivateKey) { + t.Error("JWT private keys differ for same input") + } +} + +func TestExpandNodeKeys_Uniqueness(t *testing.T) { + keys1, err := ExpandNodeKeys(testMasterKey) + if err != nil { + t.Fatalf("ExpandNodeKeys(master1): %v", err) + } + keys2, err := ExpandNodeKeys(testMasterKey2) + if err != nil { + t.Fatalf("ExpandNodeKeys(master2): %v", err) + } + + if bytes.Equal(keys1.LibP2PPrivateKey, keys2.LibP2PPrivateKey) { + t.Error("LibP2P keys should differ for different master keys") + } + if bytes.Equal(keys1.WireGuardKey[:], keys2.WireGuardKey[:]) { + 
t.Error("WireGuard keys should differ for different master keys") + } + if bytes.Equal(keys1.IPFSPrivateKey, keys2.IPFSPrivateKey) { + t.Error("IPFS keys should differ for different master keys") + } + if bytes.Equal(keys1.ClusterPrivateKey, keys2.ClusterPrivateKey) { + t.Error("Cluster keys should differ for different master keys") + } + if bytes.Equal(keys1.JWTPrivateKey, keys2.JWTPrivateKey) { + t.Error("JWT keys should differ for different master keys") + } +} + +func TestExpandNodeKeys_KeysAreMutuallyUnique(t *testing.T) { + keys, err := ExpandNodeKeys(testMasterKey) + if err != nil { + t.Fatalf("ExpandNodeKeys: %v", err) + } + + privKeys := [][]byte{ + keys.LibP2PPrivateKey.Seed(), + keys.IPFSPrivateKey.Seed(), + keys.ClusterPrivateKey.Seed(), + keys.JWTPrivateKey.Seed(), + keys.WireGuardKey[:], + } + labels := []string{"LibP2P", "IPFS", "Cluster", "JWT", "WireGuard"} + + for i := 0; i < len(privKeys); i++ { + for j := i + 1; j < len(privKeys); j++ { + if bytes.Equal(privKeys[i], privKeys[j]) { + t.Errorf("%s and %s keys should differ", labels[i], labels[j]) + } + } + } +} + +func TestExpandNodeKeys_Ed25519Validity(t *testing.T) { + keys, err := ExpandNodeKeys(testMasterKey) + if err != nil { + t.Fatalf("ExpandNodeKeys: %v", err) + } + + msg := []byte("test message for verification") + + pairs := []struct { + name string + priv ed25519.PrivateKey + pub ed25519.PublicKey + }{ + {"LibP2P", keys.LibP2PPrivateKey, keys.LibP2PPublicKey}, + {"IPFS", keys.IPFSPrivateKey, keys.IPFSPublicKey}, + {"Cluster", keys.ClusterPrivateKey, keys.ClusterPublicKey}, + {"JWT", keys.JWTPrivateKey, keys.JWTPublicKey}, + } + + for _, p := range pairs { + signature := ed25519.Sign(p.priv, msg) + if !ed25519.Verify(p.pub, msg, signature) { + t.Errorf("%s key pair: signature verification failed", p.name) + } + } +} + +func TestExpandNodeKeys_WireGuardClamping(t *testing.T) { + keys, err := ExpandNodeKeys(testMasterKey) + if err != nil { + t.Fatalf("ExpandNodeKeys: %v", err) + } + + if 
keys.WireGuardKey[0]&7 != 0 { + t.Errorf("WireGuard key not properly clamped: low 3 bits of first byte should be 0, got %08b", keys.WireGuardKey[0]) + } + if keys.WireGuardKey[31]&128 != 0 { + t.Errorf("WireGuard key not properly clamped: high bit of last byte should be 0, got %08b", keys.WireGuardKey[31]) + } + if keys.WireGuardKey[31]&64 != 64 { + t.Errorf("WireGuard key not properly clamped: second-high bit of last byte should be 1, got %08b", keys.WireGuardKey[31]) + } + + var zero [32]byte + if keys.WireGuardPubKey == zero { + t.Error("WireGuard public key is all zeros") + } +} + +func TestExpandNodeKeys_InvalidMasterKeyLength(t *testing.T) { + _, err := ExpandNodeKeys(nil) + if err == nil { + t.Error("expected error for nil master key") + } + + _, err = ExpandNodeKeys([]byte{}) + if err == nil { + t.Error("expected error for empty master key") + } + + _, err = ExpandNodeKeys(make([]byte, 16)) + if err == nil { + t.Error("expected error for 16-byte master key") + } + + _, err = ExpandNodeKeys(make([]byte, 64)) + if err == nil { + t.Error("expected error for 64-byte master key") + } +} + +func TestHexToBytes(t *testing.T) { + tests := []struct { + input string + expected []byte + wantErr bool + }{ + {"", []byte{}, false}, + {"00", []byte{0}, false}, + {"ff", []byte{255}, false}, + {"FF", []byte{255}, false}, + {"0a1b2c", []byte{10, 27, 44}, false}, + {"0", nil, true}, // odd length + {"zz", nil, true}, // invalid chars + {"gg", nil, true}, // invalid chars + } + + for _, tt := range tests { + got, err := hexToBytes(tt.input) + if tt.wantErr { + if err == nil { + t.Errorf("hexToBytes(%q): expected error", tt.input) + } + continue + } + if err != nil { + t.Errorf("hexToBytes(%q): unexpected error: %v", tt.input, err) + continue + } + if !bytes.Equal(got, tt.expected) { + t.Errorf("hexToBytes(%q) = %v, want %v", tt.input, got, tt.expected) + } + } +} + +func TestDeriveNodeKeysFromWallet_EmptyIP(t *testing.T) { + _, err := DeriveNodeKeysFromWallet("") + if err == 
nil { + t.Error("expected error for empty VPS IP") + } +} diff --git a/pkg/gateway/auth/jwt.go b/pkg/gateway/auth/jwt.go index 14a7fcd..b5da8c6 100644 --- a/pkg/gateway/auth/jwt.go +++ b/pkg/gateway/auth/jwt.go @@ -2,6 +2,7 @@ package auth import ( "crypto" + "crypto/ed25519" "crypto/rand" "crypto/rsa" "crypto/sha256" @@ -15,31 +16,46 @@ import ( func (s *Service) JWKSHandler(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - if s.signingKey == nil { - _ = json.NewEncoder(w).Encode(map[string]any{"keys": []any{}}) - return + + keys := make([]any, 0, 2) + + // RSA key (RS256) + if s.signingKey != nil { + pub := s.signingKey.Public().(*rsa.PublicKey) + n := pub.N.Bytes() + eVal := pub.E + eb := make([]byte, 0) + for eVal > 0 { + eb = append([]byte{byte(eVal & 0xff)}, eb...) + eVal >>= 8 + } + if len(eb) == 0 { + eb = []byte{0} + } + keys = append(keys, map[string]string{ + "kty": "RSA", + "use": "sig", + "alg": "RS256", + "kid": s.keyID, + "n": base64.RawURLEncoding.EncodeToString(n), + "e": base64.RawURLEncoding.EncodeToString(eb), + }) } - pub := s.signingKey.Public().(*rsa.PublicKey) - n := pub.N.Bytes() - // Encode exponent as big-endian bytes - eVal := pub.E - eb := make([]byte, 0) - for eVal > 0 { - eb = append([]byte{byte(eVal & 0xff)}, eb...) 
- eVal >>= 8 + + // Ed25519 key (EdDSA) + if s.edSigningKey != nil { + pubKey := s.edSigningKey.Public().(ed25519.PublicKey) + keys = append(keys, map[string]string{ + "kty": "OKP", + "use": "sig", + "alg": "EdDSA", + "kid": s.edKeyID, + "crv": "Ed25519", + "x": base64.RawURLEncoding.EncodeToString(pubKey), + }) } - if len(eb) == 0 { - eb = []byte{0} - } - jwk := map[string]string{ - "kty": "RSA", - "use": "sig", - "alg": "RS256", - "kid": s.keyID, - "n": base64.RawURLEncoding.EncodeToString(n), - "e": base64.RawURLEncoding.EncodeToString(eb), - } - _ = json.NewEncoder(w).Encode(map[string]any{"keys": []any{jwk}}) + + _ = json.NewEncoder(w).Encode(map[string]any{"keys": keys}) } // Internal types for JWT handling @@ -59,11 +75,12 @@ type JWTClaims struct { Namespace string `json:"namespace"` } -// ParseAndVerifyJWT verifies an RS256 JWT created by this gateway and returns claims +// ParseAndVerifyJWT verifies a JWT created by this gateway using kid-based key +// selection. It accepts both RS256 (legacy) and EdDSA (new) tokens. +// +// Security (C3 fix): The key is selected by kid, then cross-checked against alg +// to prevent algorithm confusion attacks. Only RS256 and EdDSA are accepted. 
func (s *Service) ParseAndVerifyJWT(token string) (*JWTClaims, error) { - if s.signingKey == nil { - return nil, errors.New("signing key unavailable") - } parts := strings.Split(token, ".") if len(parts) != 3 { return nil, errors.New("invalid token format") @@ -80,20 +97,60 @@ func (s *Service) ParseAndVerifyJWT(token string) (*JWTClaims, error) { if err != nil { return nil, errors.New("invalid signature encoding") } + var header jwtHeader if err := json.Unmarshal(hb, &header); err != nil { return nil, errors.New("invalid header json") } - if header.Alg != "RS256" { - return nil, errors.New("unsupported alg") + + // Explicit algorithm allowlist — reject everything else before verification + if header.Alg != "RS256" && header.Alg != "EdDSA" { + return nil, errors.New("unsupported algorithm") } - // Verify signature + signingInput := parts[0] + "." + parts[1] - sum := sha256.Sum256([]byte(signingInput)) - pub := s.signingKey.Public().(*rsa.PublicKey) - if err := rsa.VerifyPKCS1v15(pub, crypto.SHA256, sum[:], sb); err != nil { - return nil, errors.New("invalid signature") + + // Key selection by kid (not alg) — prevents algorithm confusion (C3 fix) + switch { + case header.Kid != "" && header.Kid == s.edKeyID && s.edSigningKey != nil: + // EdDSA key matched by kid — cross-check alg + if header.Alg != "EdDSA" { + return nil, errors.New("algorithm mismatch for key") + } + pubKey := s.edSigningKey.Public().(ed25519.PublicKey) + if !ed25519.Verify(pubKey, []byte(signingInput), sb) { + return nil, errors.New("invalid signature") + } + + case header.Kid != "" && header.Kid == s.keyID && s.signingKey != nil: + // RSA key matched by kid — cross-check alg + if header.Alg != "RS256" { + return nil, errors.New("algorithm mismatch for key") + } + sum := sha256.Sum256([]byte(signingInput)) + pub := s.signingKey.Public().(*rsa.PublicKey) + if err := rsa.VerifyPKCS1v15(pub, crypto.SHA256, sum[:], sb); err != nil { + return nil, errors.New("invalid signature") + } + + case header.Kid 
== "": + // Legacy token without kid — RS256 only (backward compat) + if header.Alg != "RS256" { + return nil, errors.New("legacy token must be RS256") + } + if s.signingKey == nil { + return nil, errors.New("signing key unavailable") + } + sum := sha256.Sum256([]byte(signingInput)) + pub := s.signingKey.Public().(*rsa.PublicKey) + if err := rsa.VerifyPKCS1v15(pub, crypto.SHA256, sum[:], sb); err != nil { + return nil, errors.New("invalid signature") + } + + default: + return nil, errors.New("unknown key ID") } + // Parse claims var claims JWTClaims if err := json.Unmarshal(pb, &claims); err != nil { @@ -105,8 +162,7 @@ func (s *Service) ParseAndVerifyJWT(token string) (*JWTClaims, error) { } // Validate registered claims now := time.Now().Unix() - // allow small clock skew ±60s - const skew = int64(60) + const skew = int64(60) // allow small clock skew ±60s if claims.Nbf != 0 && now+skew < claims.Nbf { return nil, errors.New("token not yet valid") } @@ -123,6 +179,44 @@ func (s *Service) ParseAndVerifyJWT(token string) (*JWTClaims, error) { } func (s *Service) GenerateJWT(ns, subject string, ttl time.Duration) (string, int64, error) { + // Prefer EdDSA when available + if s.preferEdDSA && s.edSigningKey != nil { + return s.generateEdDSAJWT(ns, subject, ttl) + } + return s.generateRSAJWT(ns, subject, ttl) +} + +func (s *Service) generateEdDSAJWT(ns, subject string, ttl time.Duration) (string, int64, error) { + if s.edSigningKey == nil { + return "", 0, errors.New("EdDSA signing key unavailable") + } + header := map[string]string{ + "alg": "EdDSA", + "typ": "JWT", + "kid": s.edKeyID, + } + hb, _ := json.Marshal(header) + now := time.Now().UTC() + exp := now.Add(ttl) + payload := map[string]any{ + "iss": "debros-gateway", + "sub": subject, + "aud": "gateway", + "iat": now.Unix(), + "nbf": now.Unix(), + "exp": exp.Unix(), + "namespace": ns, + } + pb, _ := json.Marshal(payload) + hb64 := base64.RawURLEncoding.EncodeToString(hb) + pb64 := 
base64.RawURLEncoding.EncodeToString(pb) + signingInput := hb64 + "." + pb64 + sig := ed25519.Sign(s.edSigningKey, []byte(signingInput)) + sb64 := base64.RawURLEncoding.EncodeToString(sig) + return signingInput + "." + sb64, exp.Unix(), nil +} + +func (s *Service) generateRSAJWT(ns, subject string, ttl time.Duration) (string, int64, error) { if s.signingKey == nil { return "", 0, errors.New("signing key unavailable") } diff --git a/pkg/gateway/auth/service.go b/pkg/gateway/auth/service.go index be8f40d..0fe7176 100644 --- a/pkg/gateway/auth/service.go +++ b/pkg/gateway/auth/service.go @@ -24,11 +24,14 @@ import ( // Service handles authentication business logic type Service struct { - logger *logging.ColoredLogger - orm client.NetworkClient - signingKey *rsa.PrivateKey - keyID string - defaultNS string + logger *logging.ColoredLogger + orm client.NetworkClient + signingKey *rsa.PrivateKey + keyID string + edSigningKey ed25519.PrivateKey + edKeyID string + preferEdDSA bool + defaultNS string } func NewService(logger *logging.ColoredLogger, orm client.NetworkClient, signingKeyPEM string, defaultNS string) (*Service, error) { @@ -58,6 +61,16 @@ func NewService(logger *logging.ColoredLogger, orm client.NetworkClient, signing return s, nil } +// SetEdDSAKey configures an Ed25519 signing key for EdDSA JWT support. +// When set, new tokens are signed with EdDSA; RS256 is still accepted for verification. 
+func (s *Service) SetEdDSAKey(privKey ed25519.PrivateKey) { + s.edSigningKey = privKey + pubBytes := []byte(privKey.Public().(ed25519.PublicKey)) + sum := sha256.Sum256(pubBytes) + s.edKeyID = "ed_" + hex.EncodeToString(sum[:8]) + s.preferEdDSA = true +} + // CreateNonce generates a new nonce and stores it in the database func (s *Service) CreateNonce(ctx context.Context, wallet, purpose, namespace string) (string, error) { // Generate a URL-safe random nonce (32 bytes) diff --git a/pkg/gateway/auth/service_test.go b/pkg/gateway/auth/service_test.go index 61dcf5f..55b418f 100644 --- a/pkg/gateway/auth/service_test.go +++ b/pkg/gateway/auth/service_test.go @@ -2,11 +2,18 @@ package auth import ( "context" + "crypto/ed25519" "crypto/rand" "crypto/rsa" + "crypto/sha256" "crypto/x509" + "encoding/base64" "encoding/hex" + "encoding/json" "encoding/pem" + "net/http" + "net/http/httptest" + "strings" "testing" "time" @@ -164,3 +171,248 @@ func TestVerifySolSignature(t *testing.T) { t.Error("VerifySignature should have failed for invalid base64 signature") } } + +// createDualKeyService creates a service with both RSA and EdDSA keys configured +func createDualKeyService(t *testing.T) *Service { + t.Helper() + s := createTestService(t) // has RSA + _, edPriv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("failed to generate ed25519 key: %v", err) + } + s.SetEdDSAKey(edPriv) + return s +} + +func TestEdDSAJWTFlow(t *testing.T) { + s := createDualKeyService(t) + + ns := "test-ns" + sub := "0xabcdef1234567890abcdef1234567890abcdef12" + ttl := 15 * time.Minute + + // With EdDSA preferred, GenerateJWT should produce an EdDSA token + token, exp, err := s.GenerateJWT(ns, sub, ttl) + if err != nil { + t.Fatalf("GenerateJWT (EdDSA) failed: %v", err) + } + if token == "" { + t.Fatal("generated EdDSA token is empty") + } + if exp <= time.Now().Unix() { + t.Errorf("expiration time %d is in the past", exp) + } + + // Verify the header contains EdDSA + parts := 
strings.Split(token, ".") + hb, _ := base64.RawURLEncoding.DecodeString(parts[0]) + var header map[string]string + json.Unmarshal(hb, &header) + if header["alg"] != "EdDSA" { + t.Errorf("expected alg EdDSA, got %s", header["alg"]) + } + if header["kid"] != s.edKeyID { + t.Errorf("expected kid %s, got %s", s.edKeyID, header["kid"]) + } + + // Verify the token + claims, err := s.ParseAndVerifyJWT(token) + if err != nil { + t.Fatalf("ParseAndVerifyJWT (EdDSA) failed: %v", err) + } + if claims.Sub != sub { + t.Errorf("expected subject %s, got %s", sub, claims.Sub) + } + if claims.Namespace != ns { + t.Errorf("expected namespace %s, got %s", ns, claims.Namespace) + } +} + +func TestRS256BackwardCompat(t *testing.T) { + s := createDualKeyService(t) + + // Generate an RS256 token directly (simulating a legacy token) + s.preferEdDSA = false + token, _, err := s.GenerateJWT("test-ns", "user1", 15*time.Minute) + if err != nil { + t.Fatalf("GenerateJWT (RS256) failed: %v", err) + } + s.preferEdDSA = true // re-enable EdDSA preference + + // Verify the RS256 token still works with dual-key service + claims, err := s.ParseAndVerifyJWT(token) + if err != nil { + t.Fatalf("ParseAndVerifyJWT should accept RS256 token: %v", err) + } + if claims.Sub != "user1" { + t.Errorf("expected subject user1, got %s", claims.Sub) + } +} + +func TestAlgorithmConfusion_Rejected(t *testing.T) { + s := createDualKeyService(t) + + t.Run("none_algorithm", func(t *testing.T) { + // Craft a token with alg=none + header := map[string]string{"alg": "none", "typ": "JWT"} + hb, _ := json.Marshal(header) + payload := map[string]any{ + "iss": "debros-gateway", "sub": "attacker", "aud": "gateway", + "iat": time.Now().Unix(), "nbf": time.Now().Unix(), + "exp": time.Now().Add(time.Hour).Unix(), "namespace": "test-ns", + } + pb, _ := json.Marshal(payload) + token := base64.RawURLEncoding.EncodeToString(hb) + "." + + base64.RawURLEncoding.EncodeToString(pb) + "." 
+ + _, err := s.ParseAndVerifyJWT(token) + if err == nil { + t.Error("should reject alg=none") + } + }) + + t.Run("HS256_algorithm", func(t *testing.T) { + header := map[string]string{"alg": "HS256", "typ": "JWT", "kid": s.keyID} + hb, _ := json.Marshal(header) + payload := map[string]any{ + "iss": "debros-gateway", "sub": "attacker", "aud": "gateway", + "iat": time.Now().Unix(), "nbf": time.Now().Unix(), + "exp": time.Now().Add(time.Hour).Unix(), "namespace": "test-ns", + } + pb, _ := json.Marshal(payload) + token := base64.RawURLEncoding.EncodeToString(hb) + "." + + base64.RawURLEncoding.EncodeToString(pb) + "." + + base64.RawURLEncoding.EncodeToString([]byte("fake-sig")) + + _, err := s.ParseAndVerifyJWT(token) + if err == nil { + t.Error("should reject alg=HS256") + } + }) + + t.Run("kid_alg_mismatch_EdDSA_kid_RS256_alg", func(t *testing.T) { + // Use EdDSA kid but claim RS256 alg + header := map[string]string{"alg": "RS256", "typ": "JWT", "kid": s.edKeyID} + hb, _ := json.Marshal(header) + payload := map[string]any{ + "iss": "debros-gateway", "sub": "attacker", "aud": "gateway", + "iat": time.Now().Unix(), "nbf": time.Now().Unix(), + "exp": time.Now().Add(time.Hour).Unix(), "namespace": "test-ns", + } + pb, _ := json.Marshal(payload) + // Sign with RSA (trying to confuse the verifier into using RSA on EdDSA kid) + hb64 := base64.RawURLEncoding.EncodeToString(hb) + pb64 := base64.RawURLEncoding.EncodeToString(pb) + signingInput := hb64 + "." + pb64 + sum := sha256.Sum256([]byte(signingInput)) + rsaSig, _ := rsa.SignPKCS1v15(rand.Reader, s.signingKey, 5, sum[:]) // crypto.SHA256 = 5 (4 is crypto.SHA224, which errors on a 32-byte digest) + token := signingInput + "." 
+ base64.RawURLEncoding.EncodeToString(rsaSig) + + _, err := s.ParseAndVerifyJWT(token) + if err == nil { + t.Error("should reject kid/alg mismatch (EdDSA kid with RS256 alg)") + } + if err != nil && !strings.Contains(err.Error(), "algorithm mismatch") { + t.Errorf("expected 'algorithm mismatch' error, got: %v", err) + } + }) + + t.Run("unknown_kid", func(t *testing.T) { + header := map[string]string{"alg": "RS256", "typ": "JWT", "kid": "unknown-kid-123"} + hb, _ := json.Marshal(header) + payload := map[string]any{ + "iss": "debros-gateway", "sub": "attacker", "aud": "gateway", + "iat": time.Now().Unix(), "nbf": time.Now().Unix(), + "exp": time.Now().Add(time.Hour).Unix(), "namespace": "test-ns", + } + pb, _ := json.Marshal(payload) + token := base64.RawURLEncoding.EncodeToString(hb) + "." + + base64.RawURLEncoding.EncodeToString(pb) + "." + + base64.RawURLEncoding.EncodeToString([]byte("fake-sig")) + + _, err := s.ParseAndVerifyJWT(token) + if err == nil { + t.Error("should reject unknown kid") + } + }) + + t.Run("legacy_token_EdDSA_rejected", func(t *testing.T) { + // Token with no kid and alg=EdDSA — should be rejected (legacy must be RS256) + header := map[string]string{"alg": "EdDSA", "typ": "JWT"} + hb, _ := json.Marshal(header) + payload := map[string]any{ + "iss": "debros-gateway", "sub": "attacker", "aud": "gateway", + "iat": time.Now().Unix(), "nbf": time.Now().Unix(), + "exp": time.Now().Add(time.Hour).Unix(), "namespace": "test-ns", + } + pb, _ := json.Marshal(payload) + hb64 := base64.RawURLEncoding.EncodeToString(hb) + pb64 := base64.RawURLEncoding.EncodeToString(pb) + signingInput := hb64 + "." + pb64 + sig := ed25519.Sign(s.edSigningKey, []byte(signingInput)) + token := signingInput + "." 
+ base64.RawURLEncoding.EncodeToString(sig) + + _, err := s.ParseAndVerifyJWT(token) + if err == nil { + t.Error("should reject legacy token (no kid) with EdDSA alg") + } + }) +} + +func TestJWKSHandler_DualKey(t *testing.T) { + s := createDualKeyService(t) + + req := httptest.NewRequest(http.MethodGet, "/.well-known/jwks.json", nil) + w := httptest.NewRecorder() + s.JWKSHandler(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + + var result struct { + Keys []map[string]string `json:"keys"` + } + if err := json.NewDecoder(w.Body).Decode(&result); err != nil { + t.Fatalf("failed to decode JWKS response: %v", err) + } + + if len(result.Keys) != 2 { + t.Fatalf("expected 2 keys in JWKS, got %d", len(result.Keys)) + } + + // Verify we have both RSA and OKP keys + algSet := map[string]bool{} + for _, k := range result.Keys { + algSet[k["alg"]] = true + if k["kid"] == "" { + t.Error("key missing kid") + } + } + if !algSet["RS256"] { + t.Error("JWKS missing RS256 key") + } + if !algSet["EdDSA"] { + t.Error("JWKS missing EdDSA key") + } +} + +func TestJWKSHandler_RSAOnly(t *testing.T) { + s := createTestService(t) // RSA only, no EdDSA + + req := httptest.NewRequest(http.MethodGet, "/.well-known/jwks.json", nil) + w := httptest.NewRecorder() + s.JWKSHandler(w, req) + + var result struct { + Keys []map[string]string `json:"keys"` + } + json.NewDecoder(w.Body).Decode(&result) + + if len(result.Keys) != 1 { + t.Fatalf("expected 1 key in JWKS (RSA only), got %d", len(result.Keys)) + } + if result.Keys[0]["alg"] != "RS256" { + t.Errorf("expected RS256, got %s", result.Keys[0]["alg"]) + } +} diff --git a/pkg/gateway/dependencies.go b/pkg/gateway/dependencies.go index 237795d..4fafb24 100644 --- a/pkg/gateway/dependencies.go +++ b/pkg/gateway/dependencies.go @@ -429,7 +429,7 @@ func initializeServerless(logger *logging.ColoredLogger, cfg *Config, deps *Depe logger.Logger, ) - // Initialize auth service with persistent signing key + // 
Initialize auth service with persistent signing keys (RSA + EdDSA) keyPEM, err := loadOrCreateSigningKey(cfg.DataDir, logger) if err != nil { return fmt.Errorf("failed to load or create JWT signing key: %w", err) @@ -438,6 +438,17 @@ func initializeServerless(logger *logging.ColoredLogger, cfg *Config, deps *Depe if err != nil { return fmt.Errorf("failed to initialize auth service: %w", err) } + + // Load or create EdDSA key for new JWT tokens + edKey, err := loadOrCreateEdSigningKey(cfg.DataDir, logger) + if err != nil { + logger.ComponentWarn(logging.ComponentGeneral, "Failed to load EdDSA signing key; new JWTs will use RS256", + zap.Error(err)) + } else { + authService.SetEdDSAKey(edKey) + logger.ComponentInfo(logging.ComponentGeneral, "EdDSA signing key loaded; new JWTs will use EdDSA") + } + deps.AuthService = authService logger.ComponentInfo(logging.ComponentGeneral, "Serverless function engine ready", diff --git a/pkg/gateway/namespace_health.go b/pkg/gateway/namespace_health.go index 75b508f..c54b7b5 100644 --- a/pkg/gateway/namespace_health.go +++ b/pkg/gateway/namespace_health.go @@ -42,7 +42,7 @@ func (g *Gateway) startNamespaceHealthLoop(ctx context.Context) { } probeTicker := time.NewTicker(30 * time.Second) - reconcileTicker := time.NewTicker(1 * time.Hour) + reconcileTicker := time.NewTicker(5 * time.Minute) defer probeTicker.Stop() defer reconcileTicker.Stop() diff --git a/pkg/gateway/signing_key.go b/pkg/gateway/signing_key.go index 8c77521..30e8ba2 100644 --- a/pkg/gateway/signing_key.go +++ b/pkg/gateway/signing_key.go @@ -1,6 +1,7 @@ package gateway import ( + "crypto/ed25519" "crypto/rand" "crypto/rsa" "crypto/x509" @@ -14,6 +15,7 @@ import ( ) const jwtKeyFileName = "jwt-signing-key.pem" +const eddsaKeyFileName = "jwt-eddsa-key.pem" // loadOrCreateSigningKey loads the JWT signing key from disk, or generates a new one // if none exists. This ensures JWTs survive gateway restarts. 
@@ -61,3 +63,56 @@ func loadOrCreateSigningKey(dataDir string, logger *logging.ColoredLogger) ([]by zap.String("path", keyPath)) return keyPEM, nil } + +// loadOrCreateEdSigningKey loads or generates an Ed25519 private key for EdDSA JWT signing. +// The key is stored as a PKCS8-encoded PEM file alongside the RSA key. +func loadOrCreateEdSigningKey(dataDir string, logger *logging.ColoredLogger) (ed25519.PrivateKey, error) { + keyPath := filepath.Join(dataDir, "secrets", eddsaKeyFileName) + + // Try to load existing key + if keyPEM, err := os.ReadFile(keyPath); err == nil && len(keyPEM) > 0 { + block, _ := pem.Decode(keyPEM) + if block != nil { + parsed, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err == nil { + if edKey, ok := parsed.(ed25519.PrivateKey); ok { + logger.ComponentInfo(logging.ComponentGeneral, "Loaded existing EdDSA signing key", + zap.String("path", keyPath)) + return edKey, nil + } + } + } + logger.ComponentWarn(logging.ComponentGeneral, "Existing EdDSA signing key is invalid, generating new one", + zap.String("path", keyPath)) + } + + // Generate new Ed25519 key + _, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, fmt.Errorf("generate Ed25519 key: %w", err) + } + + pkcs8Bytes, err := x509.MarshalPKCS8PrivateKey(priv) + if err != nil { + return nil, fmt.Errorf("marshal Ed25519 key: %w", err) + } + + keyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "PRIVATE KEY", + Bytes: pkcs8Bytes, + }) + + // Ensure secrets directory exists + secretsDir := filepath.Dir(keyPath) + if err := os.MkdirAll(secretsDir, 0700); err != nil { + return nil, fmt.Errorf("create secrets directory: %w", err) + } + + if err := os.WriteFile(keyPath, keyPEM, 0600); err != nil { + return nil, fmt.Errorf("write EdDSA signing key: %w", err) + } + + logger.ComponentInfo(logging.ComponentGeneral, "Generated and saved new EdDSA signing key", + zap.String("path", keyPath)) + return priv, nil +} diff --git a/pkg/namespace/cluster_recovery.go 
b/pkg/namespace/cluster_recovery.go index e934ac9..35607b3 100644 --- a/pkg/namespace/cluster_recovery.go +++ b/pkg/namespace/cluster_recovery.go @@ -130,11 +130,14 @@ func (cm *ClusterManager) HandleRecoveredNode(ctx context.Context, nodeID string } if len(results) > 0 && results[0].Count > 0 { - // Node still has legitimate assignments — just mark active and return + // Node still has legitimate assignments — mark active and repair degraded clusters cm.logger.Info("Recovered node still has cluster assignments, marking active", zap.String("node_id", nodeID), zap.Int("assignments", results[0].Count)) cm.markNodeActive(ctx, nodeID) + + // Trigger repair for any degraded clusters this node belongs to + cm.repairDegradedClusters(ctx, nodeID) return } @@ -522,6 +525,37 @@ func (cm *ClusterManager) markNodeActive(ctx context.Context, nodeID string) { } } +// repairDegradedClusters finds degraded clusters that the recovered node +// belongs to and triggers RepairCluster for each one. +func (cm *ClusterManager) repairDegradedClusters(ctx context.Context, nodeID string) { + type clusterRef struct { + NamespaceName string `db:"namespace_name"` + } + var refs []clusterRef + query := ` + SELECT DISTINCT c.namespace_name + FROM namespace_cluster_nodes cn + JOIN namespace_clusters c ON cn.namespace_cluster_id = c.id + WHERE cn.node_id = ? 
AND c.status = 'degraded' + ` + if err := cm.db.Query(ctx, &refs, query, nodeID); err != nil { + cm.logger.Warn("Failed to query degraded clusters for recovered node", + zap.String("node_id", nodeID), zap.Error(err)) + return + } + + for _, ref := range refs { + cm.logger.Info("Triggering repair for degraded cluster after node recovery", + zap.String("namespace", ref.NamespaceName), + zap.String("recovered_node", nodeID)) + if err := cm.RepairCluster(ctx, ref.NamespaceName); err != nil { + cm.logger.Warn("Failed to repair degraded cluster", + zap.String("namespace", ref.NamespaceName), + zap.Error(err)) + } + } +} + // removeDeadNodeFromRaft sends a DELETE request to a surviving RQLite node // to remove the dead node from the Raft voter set. func (cm *ClusterManager) removeDeadNodeFromRaft(ctx context.Context, deadRaftAddr string, survivingNodes []survivingNodePorts) { diff --git a/pkg/node/health/monitor.go b/pkg/node/health/monitor.go index ee7ea53..c756fa9 100644 --- a/pkg/node/health/monitor.go +++ b/pkg/node/health/monitor.go @@ -25,8 +25,19 @@ const ( DefaultDeadAfter = 12 // consecutive misses → dead DefaultQuorumWindow = 5 * time.Minute DefaultMinQuorum = 2 // out of K observers must agree + + // DefaultStartupGracePeriod prevents false dead declarations after + // cluster-wide restart. During this period, no nodes are declared dead. + DefaultStartupGracePeriod = 5 * time.Minute ) +// MetadataReader provides lifecycle metadata for peers. Implemented by +// ClusterDiscoveryService. The health monitor uses this to check maintenance +// status and LastSeen before falling through to HTTP probes. +type MetadataReader interface { + GetPeerLifecycleState(nodeID string) (state string, lastSeen time.Time, found bool) +} + // Config holds the configuration for a Monitor. 
type Config struct { NodeID string // this node's ID (dns_nodes.id / peer ID) @@ -35,6 +46,16 @@ type Config struct { ProbeInterval time.Duration // how often to probe (default 10s) ProbeTimeout time.Duration // per-probe HTTP timeout (default 3s) Neighbors int // K — how many ring neighbors to monitor (default 3) + + // MetadataReader provides LibP2P lifecycle metadata for peers. + // When set, the monitor checks peer maintenance state and LastSeen + // before falling through to HTTP probes. + MetadataReader MetadataReader + + // StartupGracePeriod prevents false dead declarations after cluster-wide + // restart. During this period, nodes can be marked suspect but never dead. + // Default: 5 minutes. + StartupGracePeriod time.Duration } // nodeInfo is a row from dns_nodes used for probing. @@ -56,6 +77,7 @@ type Monitor struct { cfg Config httpClient *http.Client logger *zap.Logger + startTime time.Time // when the monitor was created mu sync.Mutex peers map[string]*peerState // nodeID → state @@ -75,6 +97,9 @@ func NewMonitor(cfg Config) *Monitor { if cfg.Neighbors == 0 { cfg.Neighbors = DefaultNeighbors } + if cfg.StartupGracePeriod == 0 { + cfg.StartupGracePeriod = DefaultStartupGracePeriod + } if cfg.Logger == nil { cfg.Logger = zap.NewNop() } @@ -84,19 +109,20 @@ func NewMonitor(cfg Config) *Monitor { httpClient: &http.Client{ Timeout: cfg.ProbeTimeout, }, - logger: cfg.Logger.With(zap.String("component", "health-monitor")), - peers: make(map[string]*peerState), + logger: cfg.Logger.With(zap.String("component", "health-monitor")), + startTime: time.Now(), + peers: make(map[string]*peerState), } } // OnNodeDead registers a callback invoked when a node is confirmed dead by -// quorum. The callback runs synchronously in the monitor goroutine. +// quorum. The callback runs with the monitor lock released. 
func (m *Monitor) OnNodeDead(fn func(nodeID string)) { m.onDeadFn = fn } // OnNodeRecovered registers a callback invoked when a previously dead node -// transitions back to healthy. Used to clean up orphaned services. +// transitions back to healthy. The callback runs with the monitor lock released. func (m *Monitor) OnNodeRecovered(fn func(nodeID string)) { m.onRecoveredFn = fn } @@ -107,6 +133,7 @@ func (m *Monitor) Start(ctx context.Context) { zap.String("node_id", m.cfg.NodeID), zap.Duration("probe_interval", m.cfg.ProbeInterval), zap.Int("neighbors", m.cfg.Neighbors), + zap.Duration("startup_grace", m.cfg.StartupGracePeriod), ) ticker := time.NewTicker(m.cfg.ProbeInterval) @@ -123,6 +150,11 @@ func (m *Monitor) Start(ctx context.Context) { } } +// isInStartupGrace returns true if the startup grace period is still active. +func (m *Monitor) isInStartupGrace() bool { + return time.Since(m.startTime) < m.cfg.StartupGracePeriod +} + // probeRound runs a single round of probing our ring neighbors. func (m *Monitor) probeRound(ctx context.Context) { neighbors, err := m.getRingNeighbors(ctx) @@ -140,7 +172,7 @@ func (m *Monitor) probeRound(ctx context.Context) { wg.Add(1) go func(node nodeInfo) { defer wg.Done() - ok := m.probe(ctx, node) + ok := m.probeNode(ctx, node) m.updateState(ctx, node.ID, ok) }(n) } @@ -150,6 +182,28 @@ func (m *Monitor) probeRound(ctx context.Context) { m.pruneStaleState(neighbors) } +// probeNode checks a node's health. It first checks LibP2P metadata (if +// available) to avoid unnecessary HTTP probes, then falls through to HTTP. 
+func (m *Monitor) probeNode(ctx context.Context, node nodeInfo) bool { + if m.cfg.MetadataReader != nil { + state, lastSeen, found := m.cfg.MetadataReader.GetPeerLifecycleState(node.ID) + if found { + // Maintenance node with recent LastSeen → count as healthy + if state == "maintenance" && time.Since(lastSeen) < 2*time.Minute { + return true + } + + // Recently seen active node → count as healthy (no HTTP needed) + if state == "active" && time.Since(lastSeen) < 30*time.Second { + return true + } + } + } + + // Fall through to HTTP probe + return m.probe(ctx, node) +} + // probe sends an HTTP ping to a single node. Returns true if healthy. func (m *Monitor) probe(ctx context.Context, node nodeInfo) bool { url := fmt.Sprintf("http://%s:6001/v1/internal/ping", node.InternalIP) @@ -167,9 +221,9 @@ func (m *Monitor) probe(ctx context.Context, node nodeInfo) bool { } // updateState updates the in-memory state for a peer after a probe. +// Callbacks are invoked with the lock released to prevent deadlocks (C2 fix). 
func (m *Monitor) updateState(ctx context.Context, nodeID string, healthy bool) { m.mu.Lock() - defer m.mu.Unlock() ps, exists := m.peers[nodeID] if !exists { @@ -178,23 +232,26 @@ func (m *Monitor) updateState(ctx context.Context, nodeID string, healthy bool) } if healthy { - // Recovered - if ps.status != "healthy" { - wasDead := ps.status == "dead" - m.logger.Info("Node recovered", zap.String("target", nodeID), - zap.String("previous_status", ps.status)) - m.writeEvent(ctx, nodeID, "recovered") + wasDead := ps.status == "dead" + shouldCallback := wasDead && m.onRecoveredFn != nil + prevStatus := ps.status - // Fire recovery callback for nodes that were confirmed dead - if wasDead && m.onRecoveredFn != nil { - m.mu.Unlock() - m.onRecoveredFn(nodeID) - m.mu.Lock() - } - } + // Update state BEFORE releasing lock (C2 fix) ps.missCount = 0 ps.status = "healthy" ps.reportedDead = false + m.mu.Unlock() + + if prevStatus != "healthy" { + m.logger.Info("Node recovered", zap.String("target", nodeID), + zap.String("previous_status", prevStatus)) + m.writeEvent(ctx, nodeID, "recovered") + } + + // Fire recovery callback without holding the lock (C2 fix) + if shouldCallback { + m.onRecoveredFn(nodeID) + } return } @@ -203,6 +260,23 @@ func (m *Monitor) updateState(ctx context.Context, nodeID string, healthy bool) switch { case ps.missCount >= DefaultDeadAfter && !ps.reportedDead: + // During startup grace period, don't declare dead — only suspect + if m.isInStartupGrace() { + if ps.status != "suspect" { + ps.status = "suspect" + ps.suspectAt = time.Now() + m.mu.Unlock() + m.logger.Warn("Node SUSPECT (startup grace — deferring dead)", + zap.String("target", nodeID), + zap.Int("misses", ps.missCount), + ) + m.writeEvent(ctx, nodeID, "suspect") + return + } + m.mu.Unlock() + return + } + if ps.status != "dead" { m.logger.Error("Node declared DEAD", zap.String("target", nodeID), @@ -211,22 +285,34 @@ func (m *Monitor) updateState(ctx context.Context, nodeID string, healthy bool) 
} ps.status = "dead" ps.reportedDead = true + + // Copy what we need before releasing lock + shouldCheckQuorum := m.cfg.DB != nil && m.onDeadFn != nil + m.mu.Unlock() + m.writeEvent(ctx, nodeID, "dead") - m.checkQuorum(ctx, nodeID) + if shouldCheckQuorum { + m.checkQuorum(ctx, nodeID) + } + return case ps.missCount >= DefaultSuspectAfter && ps.status == "healthy": ps.status = "suspect" ps.suspectAt = time.Now() + m.mu.Unlock() + m.logger.Warn("Node SUSPECT", zap.String("target", nodeID), zap.Int("misses", ps.missCount), ) m.writeEvent(ctx, nodeID, "suspect") + return } + + m.mu.Unlock() } -// writeEvent inserts a health event into node_health_events. Must be called -// with m.mu held. +// writeEvent inserts a health event into node_health_events. func (m *Monitor) writeEvent(ctx context.Context, targetID, status string) { if m.cfg.DB == nil { return @@ -242,7 +328,8 @@ func (m *Monitor) writeEvent(ctx context.Context, targetID, status string) { } // checkQuorum queries the events table to see if enough observers agree the -// target is dead, then fires the onDead callback. Must be called with m.mu held. +// target is dead, then fires the onDead callback. Called WITHOUT the lock held +// (C2 fix — previously called with lock held, causing deadlocks in callbacks). func (m *Monitor) checkQuorum(ctx context.Context, targetID string) { if m.cfg.DB == nil || m.onDeadFn == nil { return @@ -287,10 +374,7 @@ func (m *Monitor) checkQuorum(ctx context.Context, targetID string) { zap.String("target", targetID), zap.Int("observers", count), ) - // Release the lock before calling the callback to avoid deadlocks. 
- m.mu.Unlock() m.onDeadFn(targetID) - m.mu.Lock() } // getRingNeighbors queries dns_nodes for active nodes, sorts them, and diff --git a/pkg/node/health/monitor_test.go b/pkg/node/health/monitor_test.go index 18030ee..b09ffff 100644 --- a/pkg/node/health/monitor_test.go +++ b/pkg/node/health/monitor_test.go @@ -158,10 +158,12 @@ func TestRingNeighbors_KLargerThanRing(t *testing.T) { func TestStateTransitions(t *testing.T) { m := NewMonitor(Config{ - NodeID: "self", - ProbeInterval: time.Second, - Neighbors: 3, + NodeID: "self", + ProbeInterval: time.Second, + Neighbors: 3, + StartupGracePeriod: 1 * time.Millisecond, // disable grace for this test }) + time.Sleep(2 * time.Millisecond) // ensure grace period expired ctx := context.Background() @@ -298,9 +300,11 @@ func TestOnNodeDead_Callback(t *testing.T) { var called atomic.Int32 m := NewMonitor(Config{ - NodeID: "self", - Neighbors: 3, + NodeID: "self", + Neighbors: 3, + StartupGracePeriod: 1 * time.Millisecond, }) + time.Sleep(2 * time.Millisecond) m.OnNodeDead(func(nodeID string) { called.Add(1) }) @@ -316,3 +320,204 @@ func TestOnNodeDead_Callback(t *testing.T) { t.Fatalf("expected dead, got %s", m.peers["victim"].status) } } + +// --------------------------------------------------------------- +// Startup grace period +// --------------------------------------------------------------- + +func TestStartupGrace_PreventsDead(t *testing.T) { + m := NewMonitor(Config{ + NodeID: "self", + Neighbors: 3, + StartupGracePeriod: 1 * time.Hour, // very long grace + }) + + ctx := context.Background() + + // Accumulate enough misses for dead (12) + for i := 0; i < DefaultDeadAfter+5; i++ { + m.updateState(ctx, "peer1", false) + } + + m.mu.Lock() + status := m.peers["peer1"].status + m.mu.Unlock() + + // During grace, should be suspect, NOT dead + if status != "suspect" { + t.Fatalf("expected suspect during startup grace, got %s", status) + } +} + +func TestStartupGrace_AllowsDeadAfterExpiry(t *testing.T) { + m := 
NewMonitor(Config{ + NodeID: "self", + Neighbors: 3, + StartupGracePeriod: 1 * time.Millisecond, + }) + time.Sleep(2 * time.Millisecond) // grace expired + + ctx := context.Background() + for i := 0; i < DefaultDeadAfter; i++ { + m.updateState(ctx, "peer1", false) + } + + m.mu.Lock() + status := m.peers["peer1"].status + m.mu.Unlock() + + if status != "dead" { + t.Fatalf("expected dead after grace expired, got %s", status) + } +} + +// --------------------------------------------------------------- +// MetadataReader integration +// --------------------------------------------------------------- + +type mockMetadataReader struct { + state string + lastSeen time.Time + found bool +} + +func (m *mockMetadataReader) GetPeerLifecycleState(nodeID string) (string, time.Time, bool) { + return m.state, m.lastSeen, m.found +} + +func TestProbeNode_MaintenanceCountsHealthy(t *testing.T) { + m := NewMonitor(Config{ + NodeID: "self", + MetadataReader: &mockMetadataReader{ + state: "maintenance", + lastSeen: time.Now(), + found: true, + }, + }) + + node := nodeInfo{ID: "peer1", InternalIP: "192.0.2.1"} // unreachable + ok := m.probeNode(context.Background(), node) + if !ok { + t.Fatal("maintenance node with recent LastSeen should count as healthy") + } +} + +func TestProbeNode_RecentActiveSkipsHTTP(t *testing.T) { + m := NewMonitor(Config{ + NodeID: "self", + MetadataReader: &mockMetadataReader{ + state: "active", + lastSeen: time.Now(), + found: true, + }, + }) + + // Use unreachable IP — if HTTP were attempted, it would fail + node := nodeInfo{ID: "peer1", InternalIP: "192.0.2.1"} + ok := m.probeNode(context.Background(), node) + if !ok { + t.Fatal("recently seen active node should skip HTTP and count as healthy") + } +} + +func TestProbeNode_StaleMetadataFallsToHTTP(t *testing.T) { + m := NewMonitor(Config{ + NodeID: "self", + ProbeTimeout: 100 * time.Millisecond, + MetadataReader: &mockMetadataReader{ + state: "active", + lastSeen: time.Now().Add(-5 * time.Minute), // stale 
+ found: true, + }, + }) + + node := nodeInfo{ID: "peer1", InternalIP: "192.0.2.1"} // unreachable + ok := m.probeNode(context.Background(), node) + if ok { + t.Fatal("stale metadata should fall through to HTTP probe, which should fail") + } +} + +func TestProbeNode_UnknownPeerFallsToHTTP(t *testing.T) { + m := NewMonitor(Config{ + NodeID: "self", + ProbeTimeout: 100 * time.Millisecond, + MetadataReader: &mockMetadataReader{ + found: false, // peer not found in metadata + }, + }) + + node := nodeInfo{ID: "unknown", InternalIP: "192.0.2.1"} + ok := m.probeNode(context.Background(), node) + if ok { + t.Fatal("unknown peer should fall through to HTTP probe, which should fail") + } +} + +func TestProbeNode_NoMetadataReader(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + m := NewMonitor(Config{ + NodeID: "self", + ProbeTimeout: 2 * time.Second, + MetadataReader: nil, // no metadata reader + }) + + // Without MetadataReader, should go straight to HTTP + addr := strings.TrimPrefix(srv.URL, "http://") + node := nodeInfo{ID: "peer1", InternalIP: addr} + // Note: probe() hardcodes port 6001, so this won't hit our server. + // But we verify it doesn't panic and falls through correctly. 
// State represents a node's lifecycle state.
type State string

const (
	StateJoining     State = "joining"
	StateActive      State = "active"
	StateDraining    State = "draining"
	StateMaintenance State = "maintenance"
)

// MaxMaintenanceTTL is the maximum duration a node can remain in maintenance
// mode. The leader's health monitor enforces this limit — nodes that exceed
// it are treated as unreachable so they can't hide in maintenance forever.
const MaxMaintenanceTTL = 15 * time.Minute

// validTransitions defines the allowed state machine transitions.
// Each entry maps from-state → set of valid to-states.
var validTransitions = map[State]map[State]bool{
	StateJoining:     {StateActive: true},
	StateActive:      {StateDraining: true, StateMaintenance: true},
	StateDraining:    {StateMaintenance: true},
	StateMaintenance: {StateActive: true},
}

// StateChangeCallback is called when the lifecycle state changes.
// (Parameters renamed from old/new to avoid shadowing the `new` builtin.)
type StateChangeCallback func(from, to State)

// Manager manages a node's lifecycle state machine.
// It has no external dependencies (no LibP2P, no discovery imports)
// and is fully testable in isolation.
type Manager struct {
	mu             sync.RWMutex
	state          State
	maintenanceTTL time.Time
	enterTime      time.Time // when the current state was entered
	onStateChange  []StateChangeCallback
}

// NewManager creates a new lifecycle manager in the joining state.
func NewManager() *Manager {
	return &Manager{
		state:     StateJoining,
		enterTime: time.Now(),
	}
}

// State returns the current lifecycle state.
func (m *Manager) State() State {
	m.mu.RLock()
	defer m.mu.RUnlock()
	return m.state
}

// MaintenanceTTL returns the maintenance mode expiration time.
// Returns zero value if not in maintenance.
func (m *Manager) MaintenanceTTL() time.Time {
	m.mu.RLock()
	defer m.mu.RUnlock()
	return m.maintenanceTTL
}

// StateEnteredAt returns when the current state was entered.
func (m *Manager) StateEnteredAt() time.Time {
	m.mu.RLock()
	defer m.mu.RUnlock()
	return m.enterTime
}

// OnStateChange registers a callback invoked on state transitions.
// Callbacks are called with the lock released to avoid deadlocks.
func (m *Manager) OnStateChange(cb StateChangeCallback) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.onStateChange = append(m.onStateChange, cb)
}

// clampMaintenanceTTL normalizes a requested maintenance TTL:
// non-positive or oversized requests collapse to MaxMaintenanceTTL.
func clampMaintenanceTTL(ttl time.Duration) time.Duration {
	if ttl <= 0 || ttl > MaxMaintenanceTTL {
		return MaxMaintenanceTTL
	}
	return ttl
}

// transition is the shared core of TransitionTo and EnterMaintenance.
// It validates the transition, updates state/TTL/enterTime under the lock,
// then fires callbacks with the lock released. maintTTL is only consulted
// when entering maintenance.
func (m *Manager) transition(to State, maintTTL time.Duration) error {
	m.mu.Lock()
	from := m.state

	// Indexing a nil inner map yields false, so a single lookup covers
	// both "unknown from-state" and "disallowed to-state".
	if !validTransitions[from][to] {
		m.mu.Unlock()
		return fmt.Errorf("invalid lifecycle transition: %s → %s", from, to)
	}

	m.state = to
	m.enterTime = time.Now()

	if to == StateMaintenance {
		// BUG FIX: every entry into maintenance now gets a TTL. Previously
		// TransitionTo(StateMaintenance) set no TTL, so IsMaintenanceExpired
		// stayed false forever and the node could hide in maintenance —
		// exactly what MaxMaintenanceTTL is documented to prevent.
		m.maintenanceTTL = time.Now().Add(maintTTL)
	} else {
		// Clear the TTL when leaving (or never entering) maintenance.
		m.maintenanceTTL = time.Time{}
	}

	// Copy callbacks before releasing the lock so a concurrent
	// OnStateChange cannot race the iteration.
	callbacks := make([]StateChangeCallback, len(m.onStateChange))
	copy(callbacks, m.onStateChange)
	m.mu.Unlock()

	// Invoke callbacks without holding the lock.
	for _, cb := range callbacks {
		cb(from, to)
	}

	return nil
}

// TransitionTo moves the node to a new lifecycle state.
// Returns an error if the transition is not valid. Entering maintenance via
// this method applies the default (maximum) maintenance TTL.
func (m *Manager) TransitionTo(newState State) error {
	return m.transition(newState, clampMaintenanceTTL(0))
}

// EnterMaintenance transitions to maintenance with a TTL.
// The TTL is capped at MaxMaintenanceTTL; zero or negative requests
// default to MaxMaintenanceTTL.
func (m *Manager) EnterMaintenance(ttl time.Duration) error {
	return m.transition(StateMaintenance, clampMaintenanceTTL(ttl))
}

// IsMaintenanceExpired returns true if the node is in maintenance and the TTL
// has expired. Used by the leader's health monitor to enforce the max TTL.
func (m *Manager) IsMaintenanceExpired() bool {
	m.mu.RLock()
	defer m.mu.RUnlock()
	if m.state != StateMaintenance {
		return false
	}
	return !m.maintenanceTTL.IsZero() && time.Now().After(m.maintenanceTTL)
}

// IsAvailable returns true if the node is in a state that can serve requests.
func (m *Manager) IsAvailable() bool {
	m.mu.RLock()
	defer m.mu.RUnlock()
	return m.state == StateActive
}

// IsInMaintenance returns true if the node is in maintenance mode.
func (m *Manager) IsInMaintenance() bool {
	m.mu.RLock()
	defer m.mu.RUnlock()
	return m.state == StateMaintenance
}

// Snapshot returns a point-in-time copy of the lifecycle state for
// embedding in metadata without holding the lock.
func (m *Manager) Snapshot() (state State, ttl time.Time) {
	m.mu.RLock()
	defer m.mu.RUnlock()
	return m.state, m.maintenanceTTL
}
tt.from, enterTime: time.Now()} + err := m.TransitionTo(tt.to) + if (err != nil) != tt.wantErr { + t.Fatalf("TransitionTo(%q): err=%v, wantErr=%v", tt.to, err, tt.wantErr) + } + if err == nil && m.State() != tt.to { + t.Fatalf("expected state %q, got %q", tt.to, m.State()) + } + }) + } +} + +func TestInvalidTransitions(t *testing.T) { + tests := []struct { + name string + from State + to State + }{ + {"joining→draining", StateJoining, StateDraining}, + {"joining→maintenance", StateJoining, StateMaintenance}, + {"joining→joining", StateJoining, StateJoining}, + {"active→active", StateActive, StateActive}, + {"active→joining", StateActive, StateJoining}, + {"draining→active", StateDraining, StateActive}, + {"draining→joining", StateDraining, StateJoining}, + {"maintenance→draining", StateMaintenance, StateDraining}, + {"maintenance→joining", StateMaintenance, StateJoining}, + {"maintenance→maintenance", StateMaintenance, StateMaintenance}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := &Manager{state: tt.from, enterTime: time.Now()} + err := m.TransitionTo(tt.to) + if err == nil { + t.Fatalf("expected error for transition %s → %s", tt.from, tt.to) + } + }) + } +} + +func TestEnterMaintenance(t *testing.T) { + m := NewManager() + _ = m.TransitionTo(StateActive) + + err := m.EnterMaintenance(5 * time.Minute) + if err != nil { + t.Fatalf("EnterMaintenance: %v", err) + } + + if !m.IsInMaintenance() { + t.Fatal("expected maintenance state") + } + + ttl := m.MaintenanceTTL() + if ttl.IsZero() { + t.Fatal("expected non-zero maintenance TTL") + } + + // TTL should be roughly 5 minutes from now + remaining := time.Until(ttl) + if remaining < 4*time.Minute || remaining > 6*time.Minute { + t.Fatalf("expected TTL ~5min from now, got %v", remaining) + } +} + +func TestEnterMaintenanceTTLCapped(t *testing.T) { + m := NewManager() + _ = m.TransitionTo(StateActive) + + // Request 1 hour, should be capped at MaxMaintenanceTTL + err := 
m.EnterMaintenance(1 * time.Hour) + if err != nil { + t.Fatalf("EnterMaintenance: %v", err) + } + + ttl := m.MaintenanceTTL() + remaining := time.Until(ttl) + if remaining > MaxMaintenanceTTL+time.Second { + t.Fatalf("TTL should be capped at %v, got %v remaining", MaxMaintenanceTTL, remaining) + } +} + +func TestEnterMaintenanceZeroTTL(t *testing.T) { + m := NewManager() + _ = m.TransitionTo(StateActive) + + // Zero TTL should default to MaxMaintenanceTTL + err := m.EnterMaintenance(0) + if err != nil { + t.Fatalf("EnterMaintenance: %v", err) + } + + ttl := m.MaintenanceTTL() + remaining := time.Until(ttl) + if remaining < MaxMaintenanceTTL-time.Second { + t.Fatalf("zero TTL should default to MaxMaintenanceTTL, got %v remaining", remaining) + } +} + +func TestMaintenanceTTLClearedOnExit(t *testing.T) { + m := NewManager() + _ = m.TransitionTo(StateActive) + _ = m.EnterMaintenance(5 * time.Minute) + + if m.MaintenanceTTL().IsZero() { + t.Fatal("expected non-zero TTL in maintenance") + } + + _ = m.TransitionTo(StateActive) + + if !m.MaintenanceTTL().IsZero() { + t.Fatal("expected zero TTL after leaving maintenance") + } +} + +func TestIsMaintenanceExpired(t *testing.T) { + m := &Manager{ + state: StateMaintenance, + maintenanceTTL: time.Now().Add(-1 * time.Minute), // expired 1 minute ago + enterTime: time.Now().Add(-20 * time.Minute), + } + + if !m.IsMaintenanceExpired() { + t.Fatal("expected maintenance to be expired") + } + + // Not expired + m.maintenanceTTL = time.Now().Add(5 * time.Minute) + if m.IsMaintenanceExpired() { + t.Fatal("expected maintenance to not be expired") + } + + // Not in maintenance + m.state = StateActive + if m.IsMaintenanceExpired() { + t.Fatal("expected non-maintenance state to not report expired") + } +} + +func TestStateChangeCallback(t *testing.T) { + m := NewManager() + + var callbackOld, callbackNew State + called := false + m.OnStateChange(func(old, new State) { + callbackOld = old + callbackNew = new + called = true + }) + + _ = 
m.TransitionTo(StateActive) + + if !called { + t.Fatal("callback was not called") + } + if callbackOld != StateJoining || callbackNew != StateActive { + t.Fatalf("callback got old=%q new=%q, want old=%q new=%q", + callbackOld, callbackNew, StateJoining, StateActive) + } +} + +func TestMultipleCallbacks(t *testing.T) { + m := NewManager() + + count := 0 + m.OnStateChange(func(_, _ State) { count++ }) + m.OnStateChange(func(_, _ State) { count++ }) + + _ = m.TransitionTo(StateActive) + + if count != 2 { + t.Fatalf("expected 2 callbacks, got %d", count) + } +} + +func TestSnapshot(t *testing.T) { + m := NewManager() + _ = m.TransitionTo(StateActive) + _ = m.EnterMaintenance(10 * time.Minute) + + state, ttl := m.Snapshot() + if state != StateMaintenance { + t.Fatalf("expected maintenance, got %q", state) + } + if ttl.IsZero() { + t.Fatal("expected non-zero TTL in snapshot") + } +} + +func TestConcurrentAccess(t *testing.T) { + m := NewManager() + _ = m.TransitionTo(StateActive) + + var wg sync.WaitGroup + // Concurrent reads + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _ = m.State() + _ = m.IsAvailable() + _ = m.IsInMaintenance() + _ = m.IsMaintenanceExpired() + _, _ = m.Snapshot() + }() + } + + // Concurrent maintenance enter/exit cycles + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _ = m.EnterMaintenance(1 * time.Minute) + _ = m.TransitionTo(StateActive) + }() + } + + wg.Wait() +} + +func TestStateEnteredAt(t *testing.T) { + before := time.Now() + m := NewManager() + after := time.Now() + + entered := m.StateEnteredAt() + if entered.Before(before) || entered.After(after) { + t.Fatalf("StateEnteredAt %v not between %v and %v", entered, before, after) + } + + time.Sleep(10 * time.Millisecond) + _ = m.TransitionTo(StateActive) + + newEntered := m.StateEnteredAt() + if !newEntered.After(entered) { + t.Fatal("expected StateEnteredAt to update after transition") + } +} + +func TestEnterMaintenanceFromInvalidState(t 
*testing.T) { + m := NewManager() // joining state + err := m.EnterMaintenance(5 * time.Minute) + if err == nil { + t.Fatal("expected error entering maintenance from joining state") + } +} + +func TestFullLifecycle(t *testing.T) { + m := NewManager() + + // joining → active + if err := m.TransitionTo(StateActive); err != nil { + t.Fatalf("joining→active: %v", err) + } + if !m.IsAvailable() { + t.Fatal("active node should be available") + } + + // active → draining + if err := m.TransitionTo(StateDraining); err != nil { + t.Fatalf("active→draining: %v", err) + } + if m.IsAvailable() { + t.Fatal("draining node should not be available") + } + + // draining → maintenance + if err := m.EnterMaintenance(10 * time.Minute); err != nil { + t.Fatalf("draining→maintenance: %v", err) + } + if !m.IsInMaintenance() { + t.Fatal("should be in maintenance") + } + + // maintenance → active + if err := m.TransitionTo(StateActive); err != nil { + t.Fatalf("maintenance→active: %v", err) + } + if !m.IsAvailable() { + t.Fatal("should be available after maintenance") + } +} diff --git a/pkg/node/node.go b/pkg/node/node.go index 2686284..978a040 100644 --- a/pkg/node/node.go +++ b/pkg/node/node.go @@ -12,6 +12,7 @@ import ( "github.com/DeBrosOfficial/network/pkg/gateway" "github.com/DeBrosOfficial/network/pkg/ipfs" "github.com/DeBrosOfficial/network/pkg/logging" + "github.com/DeBrosOfficial/network/pkg/node/lifecycle" "github.com/DeBrosOfficial/network/pkg/pubsub" database "github.com/DeBrosOfficial/network/pkg/rqlite" "github.com/libp2p/go-libp2p/core/host" @@ -24,6 +25,9 @@ type Node struct { logger *logging.ColoredLogger host host.Host + // Lifecycle state machine (joining → active ⇄ maintenance) + lifecycle *lifecycle.Manager + rqliteManager *database.RQLiteManager rqliteAdapter *database.RQLiteAdapter clusterDiscovery *database.ClusterDiscoveryService @@ -54,8 +58,9 @@ func NewNode(cfg *config.Config) (*Node, error) { } return &Node{ - config: cfg, - logger: logger, + config: cfg, + 
logger: logger, + lifecycle: lifecycle.NewManager(), }, nil } @@ -124,9 +129,20 @@ func (n *Node) Start(ctx context.Context) error { } } + // All services started — transition lifecycle: joining → active + if err := n.lifecycle.TransitionTo(lifecycle.StateActive); err != nil { + n.logger.ComponentWarn(logging.ComponentNode, "Failed to transition lifecycle to active", zap.Error(err)) + } + + // Publish updated metadata with active lifecycle state + if n.clusterDiscovery != nil { + n.clusterDiscovery.UpdateOwnMetadata() + } + n.logger.ComponentInfo(logging.ComponentNode, "Network node started successfully", zap.String("peer_id", n.GetPeerID()), zap.Strings("listen_addrs", listenAddrs), + zap.String("lifecycle", string(n.lifecycle.State())), ) n.startConnectionMonitoring() @@ -138,6 +154,17 @@ func (n *Node) Start(ctx context.Context) error { func (n *Node) Stop() error { n.logger.ComponentInfo(logging.ComponentNode, "Stopping network node") + // Enter maintenance so peers know we're shutting down + if n.lifecycle.IsAvailable() { + if err := n.lifecycle.EnterMaintenance(5 * time.Minute); err != nil { + n.logger.ComponentWarn(logging.ComponentNode, "Failed to enter maintenance on shutdown", zap.Error(err)) + } + // Publish maintenance state before tearing down services + if n.clusterDiscovery != nil { + n.clusterDiscovery.UpdateOwnMetadata() + } + } + // Stop HTTP Gateway server if n.apiGatewayServer != nil { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) diff --git a/pkg/node/rqlite.go b/pkg/node/rqlite.go index 359e235..8b5f4e8 100644 --- a/pkg/node/rqlite.go +++ b/pkg/node/rqlite.go @@ -34,6 +34,7 @@ func (n *Node) startRQLite(ctx context.Context) error { n.config.Discovery.RaftAdvAddress, n.config.Discovery.HttpAdvAddress, n.config.Node.DataDir, + n.lifecycle, n.logger.Logger, ) diff --git a/pkg/rqlite/cluster.go b/pkg/rqlite/cluster.go index bbdc296..af8c308 100644 --- a/pkg/rqlite/cluster.go +++ b/pkg/rqlite/cluster.go @@ -71,9 +71,10 @@ 
func (r *RQLiteManager) waitForMinClusterSizeBeforeStart(ctx context.Context, rq } if remotePeerCount >= requiredRemotePeers { - peersPath := filepath.Join(rqliteDataDir, "raft", "peers.json") + // Check discovery-peers.json (safe location outside raft dir) + peersPath := filepath.Join(rqliteDataDir, "discovery-peers.json") r.discoveryService.TriggerSync() - time.Sleep(2 * time.Second) + time.Sleep(500 * time.Millisecond) if info, err := os.Stat(peersPath); err == nil && info.Size() > 10 { data, err := os.ReadFile(peersPath) @@ -97,13 +98,11 @@ func (r *RQLiteManager) performPreStartClusterDiscovery(ctx context.Context, rql if err := r.discoveryService.TriggerPeerExchange(ctx); err != nil { r.logger.Warn("Failed to trigger peer exchange during pre-start discovery", zap.Error(err)) } - time.Sleep(1 * time.Second) r.discoveryService.TriggerSync() - time.Sleep(2 * time.Second) + time.Sleep(500 * time.Millisecond) - // Wait up to 2 minutes for peer discovery - LibP2P DHT can take 60+ seconds - // to re-establish connections after simultaneous restart - discoveryDeadline := time.Now().Add(2 * time.Minute) + // Wait up to 45s for peer discovery — parallel dials compensate for the shorter deadline + discoveryDeadline := time.Now().Add(45 * time.Second) var discoveredPeers int for time.Now().Before(discoveryDeadline) { @@ -151,7 +150,7 @@ func (r *RQLiteManager) performPreStartClusterDiscovery(ctx context.Context, rql } r.discoveryService.TriggerSync() - time.Sleep(2 * time.Second) + time.Sleep(500 * time.Millisecond) return nil } @@ -182,13 +181,12 @@ func (r *RQLiteManager) recoverFromSplitBrain(ctx context.Context) error { } r.discoveryService.TriggerPeerExchange(ctx) - time.Sleep(2 * time.Second) r.discoveryService.TriggerSync() - time.Sleep(2 * time.Second) + time.Sleep(500 * time.Millisecond) rqliteDataDir, _ := r.rqliteDataDirPath() ourIndex := r.getRaftLogIndex() - + maxPeerIndex := uint64(0) for _, peer := range r.discoveryService.GetAllPeers() { if peer.NodeID != 
r.discoverConfig.RaftAdvAddress && peer.RaftLogIndex > maxPeerIndex { @@ -201,7 +199,7 @@ func (r *RQLiteManager) recoverFromSplitBrain(ctx context.Context) error { r.logger.Warn("Failed to clear raft state during split-brain recovery", zap.Error(err)) } r.discoveryService.TriggerPeerExchange(ctx) - time.Sleep(1 * time.Second) + time.Sleep(500 * time.Millisecond) if err := r.discoveryService.ForceWritePeersJSON(); err != nil { r.logger.Warn("Failed to write peers.json during split-brain recovery", zap.Error(err)) } @@ -326,14 +324,15 @@ func (r *RQLiteManager) hasExistingRaftState(rqliteDataDir string) bool { if info, err := os.Stat(raftLogPath); err == nil && info.Size() > 1024 { return true } - peersPath := filepath.Join(rqliteDataDir, "raft", "peers.json") - _, err := os.Stat(peersPath) - return err == nil + // Don't check peers.json — discovery-peers.json is now written outside + // the raft dir and should not be treated as existing Raft state. + return false } func (r *RQLiteManager) clearRaftState(rqliteDataDir string) error { _ = os.Remove(filepath.Join(rqliteDataDir, "raft.db")) _ = os.Remove(filepath.Join(rqliteDataDir, "raft", "peers.json")) + _ = os.Remove(filepath.Join(rqliteDataDir, "discovery-peers.json")) return nil } diff --git a/pkg/rqlite/cluster_discovery.go b/pkg/rqlite/cluster_discovery.go index d411a5b..c291513 100644 --- a/pkg/rqlite/cluster_discovery.go +++ b/pkg/rqlite/cluster_discovery.go @@ -3,10 +3,12 @@ package rqlite import ( "context" "fmt" + "net" "sync" "time" "github.com/DeBrosOfficial/network/pkg/discovery" + "github.com/DeBrosOfficial/network/pkg/node/lifecycle" "github.com/libp2p/go-libp2p/core/host" "go.uber.org/zap" ) @@ -20,9 +22,13 @@ type ClusterDiscoveryService struct { nodeType string raftAddress string httpAddress string + wireGuardIP string // extracted from raftAddress (IP component) dataDir string minClusterSize int // Minimum cluster size required + // Lifecycle manager for this node's state machine + lifecycle 
*lifecycle.Manager + knownPeers map[string]*discovery.RQLiteNodeMetadata // NodeID -> Metadata peerHealth map[string]*PeerHealth // NodeID -> Health lastUpdate time.Time @@ -45,6 +51,7 @@ func NewClusterDiscoveryService( raftAddress string, httpAddress string, dataDir string, + lm *lifecycle.Manager, logger *zap.Logger, ) *ClusterDiscoveryService { minClusterSize := 1 @@ -52,6 +59,12 @@ func NewClusterDiscoveryService( minClusterSize = rqliteManager.config.MinClusterSize } + // Extract WireGuard IP from the raft address (e.g., "10.0.0.1" from "10.0.0.1:7001") + wgIP := "" + if host, _, err := net.SplitHostPort(raftAddress); err == nil { + wgIP = host + } + return &ClusterDiscoveryService{ host: h, discoveryMgr: discoveryMgr, @@ -60,8 +73,10 @@ func NewClusterDiscoveryService( nodeType: nodeType, raftAddress: raftAddress, httpAddress: httpAddress, + wireGuardIP: wgIP, dataDir: dataDir, minClusterSize: minClusterSize, + lifecycle: lm, knownPeers: make(map[string]*discovery.RQLiteNodeMetadata), peerHealth: make(map[string]*PeerHealth), updateInterval: 30 * time.Second, @@ -119,6 +134,25 @@ func (c *ClusterDiscoveryService) Stop() { c.logger.Info("Cluster discovery service stopped") } +// Lifecycle returns the node's lifecycle manager. +func (c *ClusterDiscoveryService) Lifecycle() *lifecycle.Manager { + return c.lifecycle +} + +// GetPeerLifecycleState returns the lifecycle state and last-seen time for a +// peer identified by its RQLite node ID (raft address). This method implements +// the MetadataReader interface used by the health monitor. 
+func (c *ClusterDiscoveryService) GetPeerLifecycleState(nodeID string) (state string, lastSeen time.Time, found bool) { + c.mu.RLock() + defer c.mu.RUnlock() + + peer, ok := c.knownPeers[nodeID] + if !ok { + return "", time.Time{}, false + } + return peer.EffectiveLifecycleState(), peer.LastSeen, true +} + // IsVoter returns true if the given raft address should be a voter // in the default cluster based on the current known peers. func (c *ClusterDiscoveryService) IsVoter(raftAddress string) bool { diff --git a/pkg/rqlite/cluster_discovery_membership.go b/pkg/rqlite/cluster_discovery_membership.go index 7f2ff83..f1260b4 100644 --- a/pkg/rqlite/cluster_discovery_membership.go +++ b/pkg/rqlite/cluster_discovery_membership.go @@ -39,6 +39,17 @@ func (c *ClusterDiscoveryService) collectPeerMetadata() []*discovery.RQLiteNodeM RaftLogIndex: c.rqliteManager.getRaftLogIndex(), LastSeen: time.Now(), ClusterVersion: "1.0", + PeerID: c.host.ID().String(), + WireGuardIP: c.wireGuardIP, + } + + // Populate lifecycle state + if c.lifecycle != nil { + state, ttl := c.lifecycle.Snapshot() + ourMetadata.LifecycleState = string(state) + if state == "maintenance" { + ourMetadata.MaintenanceTTL = ttl + } } if c.adjustSelfAdvertisedAddresses(ourMetadata) { @@ -272,7 +283,7 @@ func (c *ClusterDiscoveryService) getPeersJSONUnlocked() []map[string]interface{ } // computeVoterSet returns the set of raft addresses that should be voters. -// It sorts addresses by their IP component and selects the first maxVoters. +// It sorts addresses by their numeric IP and selects the first maxVoters. // This is deterministic — all nodes compute the same voter set from the same peer list. 
func computeVoterSet(raftAddrs []string, maxVoters int) map[string]struct{} { sorted := make([]string, len(raftAddrs)) @@ -281,7 +292,7 @@ func computeVoterSet(raftAddrs []string, maxVoters int) map[string]struct{} { sort.Slice(sorted, func(i, j int) bool { ipI := extractIPForSort(sorted[i]) ipJ := extractIPForSort(sorted[j]) - return ipI < ipJ + return compareIPs(ipI, ipJ) }) voters := make(map[string]struct{}) @@ -303,6 +314,31 @@ func extractIPForSort(raftAddr string) string { return host } +// compareIPs compares two IP strings numerically (not alphabetically). +// Alphabetical sort gives wrong results: "10.0.0.10" < "10.0.0.2" alphabetically, +// but numerically 10.0.0.2 < 10.0.0.10. This was causing wrong nodes to be +// selected as voters (e.g., 10.0.0.1, 10.0.0.10, 10.0.0.11 instead of 10.0.0.1-5). +func compareIPs(a, b string) bool { + ipA := net.ParseIP(a) + ipB := net.ParseIP(b) + + // Fallback to string comparison if parsing fails + if ipA == nil || ipB == nil { + return a < b + } + + // Normalize to 16-byte representation for consistent comparison + ipA = ipA.To16() + ipB = ipB.To16() + + for i := range ipA { + if ipA[i] != ipB[i] { + return ipA[i] < ipB[i] + } + } + return false +} + // IsVoter returns true if the given raft address is in the voter set // based on the current known peers. Must be called with c.mu held. func (c *ClusterDiscoveryService) IsVoterLocked(raftAddress string) bool { @@ -328,6 +364,14 @@ func (c *ClusterDiscoveryService) writePeersJSON() error { return c.writePeersJSONWithData(peers) } +// writePeersJSONWithData writes the discovery peers file to a SAFE location +// outside the raft directory. This is critical: rqlite v8 treats any +// peers.json inside /raft/ as a recovery signal and RESETS +// the Raft configuration on startup. Writing there on every periodic sync +// caused split-brain on every node restart. 
+// +// Safe location: /rqlite/discovery-peers.json +// Dangerous location: /rqlite/raft/peers.json (only for explicit recovery) func (c *ClusterDiscoveryService) writePeersJSONWithData(peers []map[string]interface{}) error { dataDir := os.ExpandEnv(c.dataDir) if strings.HasPrefix(dataDir, "~") { @@ -338,30 +382,25 @@ func (c *ClusterDiscoveryService) writePeersJSONWithData(peers []map[string]inte dataDir = filepath.Join(home, dataDir[1:]) } - rqliteDir := filepath.Join(dataDir, "rqlite", "raft") + // Write to /rqlite/ — NOT inside raft/ subdirectory. + // rqlite v8 auto-recovers from raft/peers.json on every startup, + // which resets the Raft config and causes split-brain. + rqliteDir := filepath.Join(dataDir, "rqlite") if err := os.MkdirAll(rqliteDir, 0755); err != nil { - return fmt.Errorf("failed to create raft directory %s: %w", rqliteDir, err) + return fmt.Errorf("failed to create rqlite directory %s: %w", rqliteDir, err) } - peersFile := filepath.Join(rqliteDir, "peers.json") - backupFile := filepath.Join(rqliteDir, "peers.json.backup") - - if _, err := os.Stat(peersFile); err == nil { - data, err := os.ReadFile(peersFile) - if err == nil { - _ = os.WriteFile(backupFile, data, 0644) - } - } + peersFile := filepath.Join(rqliteDir, "discovery-peers.json") data, err := json.MarshalIndent(peers, "", " ") if err != nil { - return fmt.Errorf("failed to marshal peers.json: %w", err) + return fmt.Errorf("failed to marshal discovery-peers.json: %w", err) } tempFile := peersFile + ".tmp" if err := os.WriteFile(tempFile, data, 0644); err != nil { - return fmt.Errorf("failed to write temp peers.json %s: %w", tempFile, err) + return fmt.Errorf("failed to write temp discovery-peers.json %s: %w", tempFile, err) } if err := os.Rename(tempFile, peersFile); err != nil { @@ -375,7 +414,57 @@ func (c *ClusterDiscoveryService) writePeersJSONWithData(peers []map[string]inte } } - c.logger.Info("peers.json written", + c.logger.Debug("discovery-peers.json written", + 
zap.Int("peers", len(peers)), + zap.Strings("nodes", nodeIDs)) + + return nil +} + +// writeRecoveryPeersJSON writes peers.json to the raft directory for +// INTENTIONAL cluster recovery only. rqlite v8 will read this file on +// startup and reset the Raft configuration accordingly. Only call this +// when you explicitly want to trigger Raft recovery. +func (c *ClusterDiscoveryService) writeRecoveryPeersJSON(peers []map[string]interface{}) error { + dataDir := os.ExpandEnv(c.dataDir) + if strings.HasPrefix(dataDir, "~") { + home, err := os.UserHomeDir() + if err != nil { + return fmt.Errorf("failed to determine home directory: %w", err) + } + dataDir = filepath.Join(home, dataDir[1:]) + } + + raftDir := filepath.Join(dataDir, "rqlite", "raft") + + if err := os.MkdirAll(raftDir, 0755); err != nil { + return fmt.Errorf("failed to create raft directory %s: %w", raftDir, err) + } + + peersFile := filepath.Join(raftDir, "peers.json") + + data, err := json.MarshalIndent(peers, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal recovery peers.json: %w", err) + } + + tempFile := peersFile + ".tmp" + if err := os.WriteFile(tempFile, data, 0644); err != nil { + return fmt.Errorf("failed to write temp recovery peers.json %s: %w", tempFile, err) + } + + if err := os.Rename(tempFile, peersFile); err != nil { + return fmt.Errorf("failed to rename %s to %s: %w", tempFile, peersFile, err) + } + + nodeIDs := make([]string, 0, len(peers)) + for _, p := range peers { + if id, ok := p["id"].(string); ok { + nodeIDs = append(nodeIDs, id) + } + } + + c.logger.Warn("RECOVERY peers.json written to raft directory — rqlited will reset Raft config on next startup", zap.Int("peers", len(peers)), zap.Strings("nodes", nodeIDs)) diff --git a/pkg/rqlite/cluster_discovery_queries.go b/pkg/rqlite/cluster_discovery_queries.go index 3d0960f..a45b9a2 100644 --- a/pkg/rqlite/cluster_discovery_queries.go +++ b/pkg/rqlite/cluster_discovery_queries.go @@ -128,9 +128,12 @@ func (c 
*ClusterDiscoveryService) TriggerSync() { c.updateClusterMembership() } -// ForceWritePeersJSON forces writing peers.json regardless of membership changes +// ForceWritePeersJSON writes peers.json to the RAFT directory for intentional +// cluster recovery. rqlite v8 will read this on startup and reset its Raft +// configuration. Only call this when you explicitly want Raft recovery +// (e.g., after clearing raft state or during split-brain recovery). func (c *ClusterDiscoveryService) ForceWritePeersJSON() error { - c.logger.Info("Force writing peers.json") + c.logger.Info("Force writing recovery peers.json to raft directory") metadata := c.collectPeerMetadata() @@ -153,16 +156,17 @@ func (c *ClusterDiscoveryService) ForceWritePeersJSON() error { peers := c.getPeersJSONUnlocked() c.mu.Unlock() - if err := c.writePeersJSONWithData(peers); err != nil { - c.logger.Error("Failed to force write peers.json", + // Write to RAFT directory — this is intentional recovery + if err := c.writeRecoveryPeersJSON(peers); err != nil { + c.logger.Error("Failed to force write recovery peers.json", zap.Error(err), zap.String("data_dir", c.dataDir), zap.Int("peers", len(peers))) return err } - c.logger.Info("peers.json written", - zap.Int("peers", len(peers))) + // Also update discovery location + _ = c.writePeersJSONWithData(peers) return nil } @@ -179,7 +183,9 @@ func (c *ClusterDiscoveryService) TriggerPeerExchange(ctx context.Context) error return nil } -// UpdateOwnMetadata updates our own RQLite metadata in the peerstore +// UpdateOwnMetadata updates our own RQLite metadata in the peerstore. +// This is called periodically and after significant state changes (lifecycle +// transitions, service status updates) to ensure peers have current info. 
func (c *ClusterDiscoveryService) UpdateOwnMetadata() { c.mu.RLock() currentRaftAddr := c.raftAddress @@ -194,6 +200,17 @@ func (c *ClusterDiscoveryService) UpdateOwnMetadata() { RaftLogIndex: c.rqliteManager.getRaftLogIndex(), LastSeen: time.Now(), ClusterVersion: "1.0", + PeerID: c.host.ID().String(), + WireGuardIP: c.wireGuardIP, + } + + // Populate lifecycle state from the lifecycle manager + if c.lifecycle != nil { + state, ttl := c.lifecycle.Snapshot() + metadata.LifecycleState = string(state) + if state == "maintenance" { + metadata.MaintenanceTTL = ttl + } } if c.adjustSelfAdvertisedAddresses(metadata) { @@ -215,7 +232,41 @@ func (c *ClusterDiscoveryService) UpdateOwnMetadata() { c.logger.Debug("Metadata updated", zap.String("node", metadata.NodeID), - zap.Uint64("log_index", metadata.RaftLogIndex)) + zap.Uint64("log_index", metadata.RaftLogIndex), + zap.String("lifecycle", metadata.LifecycleState)) +} + +// ProvideMetadata builds and returns the current node metadata without storing it. +// Implements discovery.MetadataProvider so the MetadataPublisher can call this +// on a regular interval and store the result in the peerstore. 
+func (c *ClusterDiscoveryService) ProvideMetadata() *discovery.RQLiteNodeMetadata { + c.mu.RLock() + currentRaftAddr := c.raftAddress + currentHTTPAddr := c.httpAddress + c.mu.RUnlock() + + metadata := &discovery.RQLiteNodeMetadata{ + NodeID: currentRaftAddr, + RaftAddress: currentRaftAddr, + HTTPAddress: currentHTTPAddr, + NodeType: c.nodeType, + RaftLogIndex: c.rqliteManager.getRaftLogIndex(), + LastSeen: time.Now(), + ClusterVersion: "1.0", + PeerID: c.host.ID().String(), + WireGuardIP: c.wireGuardIP, + } + + if c.lifecycle != nil { + state, ttl := c.lifecycle.Snapshot() + metadata.LifecycleState = string(state) + if state == "maintenance" { + metadata.MaintenanceTTL = ttl + } + } + + c.adjustSelfAdvertisedAddresses(metadata) + return metadata } // StoreRemotePeerMetadata stores metadata received from a remote peer diff --git a/pkg/rqlite/instance_spawner.go b/pkg/rqlite/instance_spawner.go index f34348e..afb8f85 100644 --- a/pkg/rqlite/instance_spawner.go +++ b/pkg/rqlite/instance_spawner.go @@ -83,6 +83,14 @@ func (is *InstanceSpawner) SpawnInstance(ctx context.Context, cfg InstanceConfig "-raft-adv-addr", cfg.RaftAdvAddress, } + // Raft tuning — match the global node's tuning for consistency + args = append(args, + "-raft-election-timeout", "5s", + "-raft-heartbeat-timeout", "2s", + "-raft-apply-timeout", "30s", + "-raft-leader-lease-timeout", "5s", + ) + // Add join addresses if not the leader (must be before data directory) if !cfg.IsLeader && len(cfg.JoinAddresses) > 0 { for _, addr := range cfg.JoinAddresses { diff --git a/pkg/rqlite/leadership.go b/pkg/rqlite/leadership.go new file mode 100644 index 0000000..b78a143 --- /dev/null +++ b/pkg/rqlite/leadership.go @@ -0,0 +1,131 @@ +package rqlite + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "go.uber.org/zap" +) + +// TransferLeadership attempts to transfer Raft leadership to another voter. +// Used by both the RQLiteManager (on Stop) and the CLI (pre-upgrade). 
+// Returns nil if this node is not the leader or if transfer succeeds. +func TransferLeadership(port int, logger *zap.Logger) error { + client := &http.Client{Timeout: 5 * time.Second} + + // 1. Check if we're the leader + statusURL := fmt.Sprintf("http://localhost:%d/status", port) + resp, err := client.Get(statusURL) + if err != nil { + return fmt.Errorf("failed to query status: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read status: %w", err) + } + + var status RQLiteStatus + if err := json.Unmarshal(body, &status); err != nil { + return fmt.Errorf("failed to parse status: %w", err) + } + + if status.Store.Raft.State != "Leader" { + logger.Debug("Not the leader, skipping transfer", zap.Int("port", port)) + return nil + } + + logger.Info("This node is the Raft leader, attempting leadership transfer", + zap.Int("port", port), + zap.String("leader_id", status.Store.Raft.LeaderID)) + + // 2. Find an eligible voter to transfer to + nodesURL := fmt.Sprintf("http://localhost:%d/nodes?nonvoters&ver=2&timeout=5s", port) + nodesResp, err := client.Get(nodesURL) + if err != nil { + return fmt.Errorf("failed to query nodes: %w", err) + } + defer nodesResp.Body.Close() + + nodesBody, err := io.ReadAll(nodesResp.Body) + if err != nil { + return fmt.Errorf("failed to read nodes: %w", err) + } + + // Try ver=2 wrapped format, fall back to plain array + var nodes RQLiteNodes + var wrapped struct { + Nodes RQLiteNodes `json:"nodes"` + } + if err := json.Unmarshal(nodesBody, &wrapped); err == nil && wrapped.Nodes != nil { + nodes = wrapped.Nodes + } else { + _ = json.Unmarshal(nodesBody, &nodes) + } + + // Find a reachable voter that is NOT us + var targetID string + for _, n := range nodes { + if n.Voter && n.Reachable && n.ID != status.Store.Raft.LeaderID { + targetID = n.ID + break + } + } + + if targetID == "" { + logger.Warn("No eligible voter found for leadership transfer — will rely on 
SIGTERM graceful step-down", + zap.Int("port", port)) + return nil + } + + // 3. Attempt transfer via rqlite v8+ API + // POST /nodes//transfer-leadership + // If the API doesn't exist (404), fall back to relying on SIGTERM. + transferURL := fmt.Sprintf("http://localhost:%d/nodes/%s/transfer-leadership", port, targetID) + transferResp, err := client.Post(transferURL, "application/json", nil) + if err != nil { + logger.Warn("Leadership transfer request failed, relying on SIGTERM", + zap.Error(err)) + return nil + } + transferResp.Body.Close() + + if transferResp.StatusCode == http.StatusNotFound { + logger.Info("Leadership transfer API not available (rqlite version), relying on SIGTERM") + return nil + } + + if transferResp.StatusCode != http.StatusOK { + logger.Warn("Leadership transfer returned unexpected status", + zap.Int("status", transferResp.StatusCode)) + return nil + } + + // 4. Verify transfer + time.Sleep(2 * time.Second) + verifyResp, err := client.Get(statusURL) + if err != nil { + logger.Info("Could not verify transfer (node may have already stepped down)") + return nil + } + defer verifyResp.Body.Close() + + verifyBody, _ := io.ReadAll(verifyResp.Body) + var newStatus RQLiteStatus + if err := json.Unmarshal(verifyBody, &newStatus); err == nil { + if newStatus.Store.Raft.State != "Leader" { + logger.Info("Leadership transferred successfully", + zap.String("new_leader", newStatus.Store.Raft.LeaderID), + zap.Int("port", port)) + } else { + logger.Warn("Still leader after transfer attempt — will rely on SIGTERM", + zap.Int("port", port)) + } + } + + return nil +} diff --git a/pkg/rqlite/process.go b/pkg/rqlite/process.go index 894fbe4..8df7bed 100644 --- a/pkg/rqlite/process.go +++ b/pkg/rqlite/process.go @@ -66,6 +66,29 @@ func (r *RQLiteManager) launchProcess(ctx context.Context, rqliteDataDir string) // Kill any orphaned rqlited from a previous crash r.killOrphanedRQLite() + // Remove stale peers.json from the raft directory to prevent rqlite v8 + // 
from triggering automatic Raft recovery on normal restarts. + // + // Only delete when raft.db EXISTS (normal restart). If raft.db does NOT + // exist, peers.json was likely placed intentionally by ForceWritePeersJSON() + // as part of a recovery flow (clearRaftState + ForceWritePeersJSON + launch). + stalePeersPath := filepath.Join(rqliteDataDir, "raft", "peers.json") + raftDBPath := filepath.Join(rqliteDataDir, "raft.db") + if _, err := os.Stat(stalePeersPath); err == nil { + if _, err := os.Stat(raftDBPath); err == nil { + // raft.db exists → this is a normal restart, peers.json is stale + r.logger.Warn("Removing stale peers.json from raft directory to prevent accidental recovery", + zap.String("path", stalePeersPath)) + _ = os.Remove(stalePeersPath) + _ = os.Remove(stalePeersPath + ".backup") + _ = os.Remove(stalePeersPath + ".tmp") + } else { + // raft.db missing → intentional recovery, keep peers.json for rqlited + r.logger.Info("Keeping peers.json in raft directory for intentional cluster recovery", + zap.String("path", stalePeersPath)) + } + } + // Build RQLite command args := []string{ "-http-addr", fmt.Sprintf("0.0.0.0:%d", r.config.RQLitePort), @@ -90,6 +113,30 @@ func (r *RQLiteManager) launchProcess(ctx context.Context, rqliteDataDir string) } } + // Raft tuning — higher timeouts suit WireGuard latency + raftElection := r.config.RaftElectionTimeout + if raftElection == 0 { + raftElection = 5 * time.Second + } + raftHeartbeat := r.config.RaftHeartbeatTimeout + if raftHeartbeat == 0 { + raftHeartbeat = 2 * time.Second + } + raftApply := r.config.RaftApplyTimeout + if raftApply == 0 { + raftApply = 30 * time.Second + } + raftLeaderLease := r.config.RaftLeaderLeaseTimeout + if raftLeaderLease == 0 { + raftLeaderLease = 5 * time.Second + } + args = append(args, + "-raft-election-timeout", raftElection.String(), + "-raft-heartbeat-timeout", raftHeartbeat.String(), + "-raft-apply-timeout", raftApply.String(), + "-raft-leader-lease-timeout", 
raftLeaderLease.String(), + ) + if r.config.RQLiteJoinAddress != "" && !r.hasExistingState(rqliteDataDir) { r.logger.Info("First-time join to RQLite cluster", zap.String("join_address", r.config.RQLiteJoinAddress)) diff --git a/pkg/rqlite/rqlite.go b/pkg/rqlite/rqlite.go index 30c034a..eda0c44 100644 --- a/pkg/rqlite/rqlite.go +++ b/pkg/rqlite/rqlite.go @@ -70,6 +70,7 @@ func (r *RQLiteManager) Start(ctx context.Context) error { if r.discoveryService != nil { go r.startHealthMonitoring(ctx) go r.startVoterReconciliation(ctx) + go r.startOrphanedNodeRecovery(ctx) // C1 fix: recover nodes orphaned by failed voter changes } // Start child process watchdog to detect and recover from crashes @@ -138,21 +139,9 @@ func (r *RQLiteManager) Stop() error { // transferLeadershipIfLeader checks if this node is the Raft leader and // requests a leadership transfer to minimize election disruption. func (r *RQLiteManager) transferLeadershipIfLeader() { - status, err := r.getRQLiteStatus() - if err != nil { - return + if err := TransferLeadership(r.config.RQLitePort, r.logger); err != nil { + r.logger.Warn("Leadership transfer failed, relying on SIGTERM", zap.Error(err)) } - if status.Store.Raft.State != "Leader" { - return - } - - r.logger.Info("This node is the Raft leader, requesting leadership transfer before shutdown") - - // RQLite doesn't have a direct leadership transfer API, but we can - // signal readiness to step down. The fastest approach is to let the - // SIGTERM handler in rqlited handle this — rqlite v8 gracefully - // steps down on SIGTERM when possible. We log the state for visibility. 
- r.logger.Info("Leader will transfer on SIGTERM (rqlite built-in behavior)") } // cleanupPIDFile removes the PID file on shutdown diff --git a/pkg/rqlite/voter_reconciliation.go b/pkg/rqlite/voter_reconciliation.go index d98254d..747ea18 100644 --- a/pkg/rqlite/voter_reconciliation.go +++ b/pkg/rqlite/voter_reconciliation.go @@ -7,15 +7,32 @@ import ( "fmt" "io" "net/http" + "sync" "time" "go.uber.org/zap" ) +const ( + // voterChangeCooldown is how long to wait after a failed voter change + // before retrying the same node. + voterChangeCooldown = 10 * time.Minute +) + +// voterReconciler holds voter change cooldown state. +type voterReconciler struct { + mu sync.Mutex + cooldowns map[string]time.Time // nodeID → earliest next attempt +} + // startVoterReconciliation periodically checks and corrects voter/non-voter // assignments. Only takes effect on the leader node. Corrects at most one // node per cycle to minimize disruption. func (r *RQLiteManager) startVoterReconciliation(ctx context.Context) { + reconciler := &voterReconciler{ + cooldowns: make(map[string]time.Time), + } + // Wait for cluster to stabilize after startup time.Sleep(3 * time.Minute) @@ -27,21 +44,104 @@ func (r *RQLiteManager) startVoterReconciliation(ctx context.Context) { case <-ctx.Done(): return case <-ticker.C: - if err := r.reconcileVoters(); err != nil { + if err := r.reconcileVoters(reconciler); err != nil { r.logger.Debug("Voter reconciliation skipped", zap.Error(err)) } } } } +// startOrphanedNodeRecovery runs every 5 minutes on the leader. It scans for +// nodes that appear in the discovery peer list but NOT in the Raft cluster +// (orphaned by a failed remove+rejoin during voter reconciliation). For each +// orphaned node, it re-adds them via POST /join. 
(C1 fix) +func (r *RQLiteManager) startOrphanedNodeRecovery(ctx context.Context) { + // Wait for cluster to stabilize + time.Sleep(5 * time.Minute) + + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + r.recoverOrphanedNodes() + } + } +} + +// recoverOrphanedNodes finds nodes known to discovery but missing from the +// Raft cluster and re-adds them. +func (r *RQLiteManager) recoverOrphanedNodes() { + if r.discoveryService == nil { + return + } + + // Only the leader runs orphan recovery + status, err := r.getRQLiteStatus() + if err != nil || status.Store.Raft.State != "Leader" { + return + } + + // Get all Raft cluster members + raftNodes, err := r.getAllClusterNodes() + if err != nil { + return + } + raftNodeSet := make(map[string]bool, len(raftNodes)) + for _, n := range raftNodes { + raftNodeSet[n.ID] = true + } + + // Get all discovery peers + discoveryPeers := r.discoveryService.GetAllPeers() + + for _, peer := range discoveryPeers { + if peer.RaftAddress == r.discoverConfig.RaftAdvAddress { + continue // skip self + } + if raftNodeSet[peer.RaftAddress] { + continue // already in cluster + } + + // This peer is in discovery but not in Raft — it's orphaned + r.logger.Warn("Found orphaned node (in discovery but not in Raft cluster), re-adding", + zap.String("node_raft_addr", peer.RaftAddress), + zap.String("node_id", peer.NodeID)) + + // Determine voter status + raftAddrs := make([]string, 0, len(discoveryPeers)) + for _, p := range discoveryPeers { + raftAddrs = append(raftAddrs, p.RaftAddress) + } + voters := computeVoterSet(raftAddrs, MaxDefaultVoters) + _, shouldBeVoter := voters[peer.RaftAddress] + + if err := r.joinClusterNode(peer.RaftAddress, peer.RaftAddress, shouldBeVoter); err != nil { + r.logger.Error("Failed to re-add orphaned node", + zap.String("node", peer.RaftAddress), + zap.Bool("voter", shouldBeVoter), + zap.Error(err)) + } else { + 
r.logger.Info("Successfully re-added orphaned node to Raft cluster", + zap.String("node", peer.RaftAddress), + zap.Bool("voter", shouldBeVoter)) + } + } +} + // reconcileVoters compares actual cluster voter assignments (from RQLite's // /nodes endpoint) against the deterministic desired set (computeVoterSet) -// and corrects mismatches. Uses remove + re-join since RQLite's /join -// ignores voter flag changes for existing members. +// and corrects mismatches. // -// Safety: only runs on the leader, only when all nodes are reachable, -// never demotes the leader, and fixes at most one node per cycle. -func (r *RQLiteManager) reconcileVoters() error { +// Improvements over original: +// - Promotion: tries direct POST /join with voter=true first (no remove needed) +// - Leader stability: verifies leader is stable before demotion +// - Cooldown: skips nodes that recently failed a voter change +// - Fixes at most one node per cycle +func (r *RQLiteManager) reconcileVoters(reconciler *voterReconciler) error { // 1. Only the leader reconciles status, err := r.getRQLiteStatus() if err != nil { @@ -68,14 +168,21 @@ func (r *RQLiteManager) reconcileVoters() error { } } - // 4. Compute desired voter set from raft addresses + // 4. Leader stability: verify term hasn't changed recently + // (Re-check status to confirm we're still the stable leader) + status2, err := r.getRQLiteStatus() + if err != nil || status2.Store.Raft.State != "Leader" || status2.Store.Raft.Term != status.Store.Raft.Term { + return fmt.Errorf("leader state changed during reconciliation check") + } + + // 5. Compute desired voter set from raft addresses raftAddrs := make([]string, 0, len(nodes)) for _, n := range nodes { raftAddrs = append(raftAddrs, n.ID) } desiredVoters := computeVoterSet(raftAddrs, MaxDefaultVoters) - // 5. Safety: never demote ourselves (the current leader) + // 6. 
Safety: never demote ourselves (the current leader) myRaftAddr := status.Store.Raft.LeaderID if _, shouldBeVoter := desiredVoters[myRaftAddr]; !shouldBeVoter { r.logger.Warn("Leader is not in computed voter set — skipping reconciliation", @@ -83,10 +190,19 @@ func (r *RQLiteManager) reconcileVoters() error { return nil } - // 6. Find one mismatch to fix (one change per cycle) + // 7. Find one mismatch to fix (one change per cycle) for _, n := range nodes { _, shouldBeVoter := desiredVoters[n.ID] + // Check cooldown + reconciler.mu.Lock() + cooldownUntil, hasCooldown := reconciler.cooldowns[n.ID] + if hasCooldown && time.Now().Before(cooldownUntil) { + reconciler.mu.Unlock() + continue + } + reconciler.mu.Unlock() + if n.Voter && !shouldBeVoter { // Skip if this is the leader if n.ID == myRaftAddr { @@ -100,6 +216,9 @@ func (r *RQLiteManager) reconcileVoters() error { r.logger.Warn("Failed to demote voter", zap.String("node_id", n.ID), zap.Error(err)) + reconciler.mu.Lock() + reconciler.cooldowns[n.ID] = time.Now().Add(voterChangeCooldown) + reconciler.mu.Unlock() return err } @@ -112,10 +231,24 @@ func (r *RQLiteManager) reconcileVoters() error { r.logger.Info("Promoting non-voter to voter", zap.String("node_id", n.ID)) + // Try direct promotion first (POST /join with voter=true) + if err := r.joinClusterNode(n.ID, n.ID, true); err == nil { + r.logger.Info("Successfully promoted non-voter to voter (direct join)", + zap.String("node_id", n.ID)) + return nil + } + + // Direct join didn't change voter status, fall back to remove+rejoin + r.logger.Info("Direct promotion didn't work, trying remove+rejoin", + zap.String("node_id", n.ID)) + if err := r.changeNodeVoterStatus(n.ID, true); err != nil { r.logger.Warn("Failed to promote non-voter", zap.String("node_id", n.ID), zap.Error(err)) + reconciler.mu.Lock() + reconciler.cooldowns[n.ID] = time.Now().Add(voterChangeCooldown) + reconciler.mu.Unlock() return err } @@ -130,13 +263,12 @@ func (r *RQLiteManager) 
reconcileVoters() error { // changeNodeVoterStatus changes a node's voter status by removing it from the // cluster and immediately re-adding it with the desired voter flag. -// This is necessary because RQLite's /join endpoint ignores voter flag changes -// for nodes that are already cluster members with the same ID and address. // // Safety improvements: -// - Pre-check: verify quorum would survive the temporary removal -// - Rollback: if rejoin fails, attempt to re-add with original status -// - Retry: attempt rejoin up to 3 times with backoff +// - Pre-check: verify quorum would survive the temporary removal +// - Pre-check: verify target node is still reachable +// - Rollback: if rejoin fails, attempt to re-add with original status +// - Retry: 5 attempts with exponential backoff (2s, 4s, 8s, 15s, 30s) func (r *RQLiteManager) changeNodeVoterStatus(nodeID string, voter bool) error { // Pre-check: if demoting a voter, verify quorum safety if !voter { @@ -145,34 +277,53 @@ func (r *RQLiteManager) changeNodeVoterStatus(nodeID string, voter bool) error { return fmt.Errorf("quorum pre-check: %w", err) } voterCount := 0 + targetReachable := false for _, n := range nodes { if n.Voter && n.Reachable { voterCount++ } + if n.ID == nodeID && n.Reachable { + targetReachable = true + } + } + if !targetReachable { + return fmt.Errorf("target node %s is not reachable, skipping voter change", nodeID) } // After removing this voter, we need (voterCount-1)/2 + 1 for quorum - // which means voterCount-1 > (voterCount-1)/2, i.e., voterCount >= 3 if voterCount <= 2 { return fmt.Errorf("cannot remove voter: only %d reachable voters, quorum would be lost", voterCount) } } + // Fresh quorum check immediately before removal + nodes, err := r.getAllClusterNodes() + if err != nil { + return fmt.Errorf("fresh quorum check: %w", err) + } + for _, n := range nodes { + if !n.Reachable { + return fmt.Errorf("node %s is unreachable, aborting voter change", n.ID) + } + } + // Step 1: Remove the 
node from the cluster if err := r.removeClusterNode(nodeID); err != nil { return fmt.Errorf("remove node: %w", err) } // Wait for Raft to commit the configuration change, then rejoin with retries + // Exponential backoff: 2s, 4s, 8s, 15s, 30s + backoffs := []time.Duration{2 * time.Second, 4 * time.Second, 8 * time.Second, 15 * time.Second, 30 * time.Second} var lastErr error - for attempt := 0; attempt < 3; attempt++ { - waitTime := time.Duration(2+attempt*2) * time.Second // 2s, 4s, 6s - time.Sleep(waitTime) + for attempt, wait := range backoffs { + time.Sleep(wait) if err := r.joinClusterNode(nodeID, nodeID, voter); err != nil { lastErr = err r.logger.Warn("Rejoin attempt failed, retrying", zap.String("node_id", nodeID), zap.Int("attempt", attempt+1), + zap.Int("max_attempts", len(backoffs)), zap.Error(err)) continue } @@ -187,12 +338,12 @@ func (r *RQLiteManager) changeNodeVoterStatus(nodeID string, voter bool) error { originalVoter := !voter if err := r.joinClusterNode(nodeID, nodeID, originalVoter); err != nil { - r.logger.Error("Rollback also failed — node may be orphaned from cluster", + r.logger.Error("Rollback also failed — node may be orphaned (orphan recovery will re-add it)", zap.String("node_id", nodeID), zap.Error(err)) } - return fmt.Errorf("rejoin node after 3 attempts: %w", lastErr) + return fmt.Errorf("rejoin node after %d attempts: %w", len(backoffs), lastErr) } // getAllClusterNodes queries /nodes?nonvoters&ver=2 to get all cluster members diff --git a/pkg/rqlite/watchdog.go b/pkg/rqlite/watchdog.go index 9c35d4a..7669fd2 100644 --- a/pkg/rqlite/watchdog.go +++ b/pkg/rqlite/watchdog.go @@ -10,13 +10,34 @@ import ( ) const ( - watchdogInterval = 30 * time.Second + // watchdogInterval is how often we check if rqlited is alive. + watchdogInterval = 30 * time.Second + + // watchdogMaxRestart is the maximum number of restart attempts before giving up. 
watchdogMaxRestart = 3 + + // watchdogGracePeriod is how long to wait after a restart before + // the watchdog starts checking. This gives rqlited time to rejoin + // the Raft cluster — Raft election timeouts + log replay can take + // 60-120 seconds after a restart. + watchdogGracePeriod = 120 * time.Second ) // startProcessWatchdog monitors the RQLite child process and restarts it if it crashes. -// It checks both process liveness and HTTP responsiveness. +// It only restarts when the process has actually DIED (exited). It does NOT kill +// rqlited for being slow to find a leader — that's normal during cluster rejoin. func (r *RQLiteManager) startProcessWatchdog(ctx context.Context) { + // Wait for the grace period before starting to monitor. + // rqlited needs time to: + // 1. Open the raft log and snapshots + // 2. Reconnect to existing Raft peers + // 3. Either rejoin as follower or participate in a new election + select { + case <-ctx.Done(): + return + case <-time.After(watchdogGracePeriod): + } + ticker := time.NewTicker(watchdogInterval) defer ticker.Stop() @@ -46,10 +67,21 @@ func (r *RQLiteManager) startProcessWatchdog(ctx context.Context) { restartCount++ r.logger.Info("RQLite process restarted by watchdog", zap.Int("restart_count", restartCount)) + + // Give the restarted process time to stabilize before checking again + select { + case <-ctx.Done(): + return + case <-time.After(watchdogGracePeriod): + } } else { - // Process is alive — check HTTP responsiveness - if !r.isHTTPResponsive() { - r.logger.Warn("RQLite process is alive but not responding to HTTP") + // Process is alive — reset restart counter on sustained health + if r.isHTTPResponsive() { + if restartCount > 0 { + r.logger.Info("RQLite process has stabilized, resetting restart counter", + zap.Int("previous_restart_count", restartCount)) + restartCount = 0 + } } } } diff --git a/scripts/clean-testnet.sh b/scripts/clean-testnet.sh new file mode 100755 index 0000000..3c6c6b3 --- /dev/null +++ 
b/scripts/clean-testnet.sh @@ -0,0 +1,208 @@ +#!/usr/bin/env bash +# +# Clean all testnet nodes for fresh reinstall. +# Preserves Anyone relay keys (/var/lib/anon/) for --anyone-migrate. +# DOES NOT TOUCH DEVNET NODES. +# +# Usage: scripts/clean-testnet.sh [--nuclear] +# --nuclear Also remove shared binaries (rqlited, ipfs, coredns, caddy, etc.) +# +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +CONF="$ROOT_DIR/scripts/remote-nodes.conf" + +[[ -f "$CONF" ]] || { echo "ERROR: Missing $CONF"; exit 1; } +command -v sshpass >/dev/null 2>&1 || { echo "ERROR: sshpass not installed (brew install sshpass / apt install sshpass)"; exit 1; } + +NUCLEAR=false +[[ "${1:-}" == "--nuclear" ]] && NUCLEAR=true + +SSH_OPTS=(-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -o ConnectTimeout=10 -o LogLevel=ERROR -o PubkeyAuthentication=no) + +# ── Cleanup script (runs as root on each remote node) ───────────────────── +# Uses a quoted heredoc so NO local variable expansion happens. +# This script is uploaded to /tmp/orama-clean.sh and executed remotely. +CLEANUP_SCRIPT=$(cat <<'SCRIPT_END' +#!/bin/bash +set -e +export DEBIAN_FRONTEND=noninteractive + +echo " Stopping services..." +systemctl stop debros-node debros-gateway debros-ipfs debros-ipfs-cluster debros-olric debros-anyone-relay debros-anyone-client coredns caddy 2>/dev/null || true +systemctl disable debros-node debros-gateway debros-ipfs debros-ipfs-cluster debros-olric debros-anyone-relay debros-anyone-client coredns caddy 2>/dev/null || true + +echo " Removing systemd service files..." +rm -f /etc/systemd/system/debros-*.service +rm -f /etc/systemd/system/coredns.service +rm -f /etc/systemd/system/caddy.service +rm -f /etc/systemd/system/orama-deploy-*.service +systemctl daemon-reload + +echo " Tearing down WireGuard..." 
+systemctl stop wg-quick@wg0 2>/dev/null || true +wg-quick down wg0 2>/dev/null || true +systemctl disable wg-quick@wg0 2>/dev/null || true +rm -f /etc/wireguard/wg0.conf + +echo " Resetting UFW firewall..." +ufw --force reset +ufw allow 22/tcp +ufw --force enable + +echo " Killing debros processes..." +pkill -u debros 2>/dev/null || true +sleep 1 + +echo " Removing debros user and data..." +userdel -r debros 2>/dev/null || true +rm -rf /home/debros + +echo " Removing sudoers files..." +rm -f /etc/sudoers.d/debros-access +rm -f /etc/sudoers.d/debros-deployments +rm -f /etc/sudoers.d/debros-wireguard + +echo " Removing CoreDNS and Caddy configs..." +rm -rf /etc/coredns +rm -rf /etc/caddy +rm -rf /var/lib/caddy + +echo " Cleaning temp files..." +rm -f /tmp/orama /tmp/network-source.tar.gz /tmp/network-source.zip +rm -rf /tmp/network-extract /tmp/coredns-build /tmp/caddy-build + +# Nuclear: also remove shared binaries +if [ "${1:-}" = "--nuclear" ]; then + echo " Removing shared binaries (nuclear)..." 
+ rm -f /usr/local/bin/rqlited + rm -f /usr/local/bin/ipfs + rm -f /usr/local/bin/ipfs-cluster-service + rm -f /usr/local/bin/olric-server + rm -f /usr/local/bin/coredns + rm -f /usr/local/bin/xcaddy + rm -f /usr/bin/caddy + rm -f /usr/local/bin/orama +fi + +# Verify Anyone relay keys are preserved +if [ -d /var/lib/anon/keys ]; then + echo " Anyone relay keys PRESERVED at /var/lib/anon/keys" + if [ -f /var/lib/anon/fingerprint ]; then + fp=$(cat /var/lib/anon/fingerprint 2>/dev/null || true) + echo " Relay fingerprint: $fp" + fi + if [ -f /var/lib/anon/wallet ]; then + wallet=$(cat /var/lib/anon/wallet 2>/dev/null || true) + echo " Relay wallet: $wallet" + fi +else + echo " WARNING: No Anyone relay keys found at /var/lib/anon/" +fi + +echo " DONE" +SCRIPT_END +) + +# ── Parse testnet nodes only ────────────────────────────────────────────── +hosts=() +passes=() +users=() + +while IFS='|' read -r env hostspec pass role key; do + [[ -z "$env" || "$env" == \#* ]] && continue + env="${env%%#*}" + env="$(echo "$env" | xargs)" + [[ "$env" != "testnet" ]] && continue + + hosts+=("$hostspec") + passes+=("$pass") + users+=("${hostspec%%@*}") +done < "$CONF" + +if [[ ${#hosts[@]} -eq 0 ]]; then + echo "ERROR: No testnet nodes found in $CONF" + exit 1 +fi + +echo "== clean-testnet.sh — ${#hosts[@]} testnet nodes ==" +for i in "${!hosts[@]}"; do + echo " [$((i+1))] ${hosts[$i]}" +done +echo "" +echo "This will CLEAN all testnet nodes (stop services, remove data)." +echo "Anyone relay keys (/var/lib/anon/) will be PRESERVED." +echo "Devnet nodes will NOT be touched." +$NUCLEAR && echo "Nuclear mode: shared binaries will also be removed." +echo "" +read -rp "Type 'yes' to continue: " confirm +if [[ "$confirm" != "yes" ]]; then + echo "Aborted." 
+ exit 0 +fi + +# ── Execute cleanup on each node ────────────────────────────────────────── +failed=() +succeeded=0 +NUCLEAR_FLAG="" +$NUCLEAR && NUCLEAR_FLAG="--nuclear" + +for i in "${!hosts[@]}"; do + h="${hosts[$i]}" + p="${passes[$i]}" + u="${users[$i]}" + echo "" + echo "== [$((i+1))/${#hosts[@]}] Cleaning $h ==" + + # Step 1: Upload cleanup script + # No -n flag here — we're piping the script content via stdin + if ! echo "$CLEANUP_SCRIPT" | sshpass -p "$p" ssh "${SSH_OPTS[@]}" "$h" \ + "cat > /tmp/orama-clean.sh && chmod +x /tmp/orama-clean.sh" 2>&1; then + echo " !! FAILED to upload script to $h" + failed+=("$h") + continue + fi + + # Step 2: Execute the cleanup script as root + if [[ "$u" == "root" ]]; then + # Root: run directly + if ! sshpass -p "$p" ssh -n "${SSH_OPTS[@]}" "$h" \ + "bash /tmp/orama-clean.sh $NUCLEAR_FLAG; rm -f /tmp/orama-clean.sh" 2>&1; then + echo " !! FAILED: $h" + failed+=("$h") + continue + fi + else + # Non-root: escape password for single-quote embedding, pipe to sudo -S + escaped_p=$(printf '%s' "$p" | sed "s/'/'\\\\''/g") + if ! sshpass -p "$p" ssh -n "${SSH_OPTS[@]}" "$h" \ + "printf '%s\n' '${escaped_p}' | sudo -S bash /tmp/orama-clean.sh $NUCLEAR_FLAG; rm -f /tmp/orama-clean.sh" 2>&1; then + echo " !! FAILED: $h" + failed+=("$h") + continue + fi + fi + + echo " OK: $h cleaned" + ((succeeded++)) || true +done + +echo "" +echo "========================================" +echo "Cleanup complete: $succeeded succeeded, ${#failed[@]} failed" +if [[ ${#failed[@]} -gt 0 ]]; then + echo "" + echo "Failed nodes:" + for f in "${failed[@]}"; do + echo " - $f" + done + echo "" + echo "Troubleshooting:" + echo " 1. Check connectivity: ssh @" + echo " 2. Check password in remote-nodes.conf" + echo " 3. Try cleaning manually: docs/CLEAN_NODE.md" +fi +echo "" +echo "Anyone relay keys preserved at /var/lib/anon/ on all nodes." +echo "Use --anyone-migrate during install to reuse existing relay identity." 
+echo "========================================" diff --git a/scripts/recover-rqlite.sh b/scripts/recover-rqlite.sh new file mode 100644 index 0000000..15b3be3 --- /dev/null +++ b/scripts/recover-rqlite.sh @@ -0,0 +1,289 @@ +#!/usr/bin/env bash +# +# Recover RQLite cluster from split-brain. +# +# Strategy: +# 1. Stop debros-node on ALL nodes simultaneously +# 2. Keep raft/ data ONLY on the node with the highest commit index (leader candidate) +# 3. Delete raft/ on all other nodes (they'll join fresh via -join) +# 4. Start the leader candidate first, wait for it to become Leader +# 5. Start all other nodes — they discover the leader via LibP2P and join +# 6. Verify cluster health +# +# Usage: +# scripts/recover-rqlite.sh --devnet --leader 57.129.7.232 +# scripts/recover-rqlite.sh --testnet --leader +# +set -euo pipefail + +# ── Parse flags ────────────────────────────────────────────────────────────── +ENV="" +LEADER_HOST="" + +for arg in "$@"; do + case "$arg" in + --devnet) ENV="devnet" ;; + --testnet) ENV="testnet" ;; + --leader=*) LEADER_HOST="${arg#--leader=}" ;; + -h|--help) + echo "Usage: scripts/recover-rqlite.sh --devnet|--testnet --leader=" + exit 0 + ;; + *) + echo "Unknown flag: $arg" >&2 + exit 1 + ;; + esac +done + +if [[ -z "$ENV" ]]; then + echo "ERROR: specify --devnet or --testnet" >&2 + exit 1 +fi +if [[ -z "$LEADER_HOST" ]]; then + echo "ERROR: specify --leader= (the node with highest commit index)" >&2 + exit 1 +fi + +# ── Paths ──────────────────────────────────────────────────────────────────── +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." 
&& pwd)" +CONF="$ROOT_DIR/scripts/remote-nodes.conf" + +die() { echo "ERROR: $*" >&2; exit 1; } +[[ -f "$CONF" ]] || die "Missing $CONF" + +# ── Load nodes from conf ──────────────────────────────────────────────────── +HOSTS=() +PASSES=() +ROLES=() +SSH_KEYS=() + +while IFS='|' read -r env host pass role key; do + [[ -z "$env" || "$env" == \#* ]] && continue + env="${env%%#*}" + env="$(echo "$env" | xargs)" + [[ "$env" != "$ENV" ]] && continue + + HOSTS+=("$host") + PASSES+=("$pass") + ROLES+=("${role:-node}") + SSH_KEYS+=("${key:-}") +done < "$CONF" + +if [[ ${#HOSTS[@]} -eq 0 ]]; then + die "No nodes found for environment '$ENV' in $CONF" +fi + +echo "== recover-rqlite.sh ($ENV) — ${#HOSTS[@]} nodes ==" +echo "Leader candidate: $LEADER_HOST" +echo "" + +# Find leader index +LEADER_IDX=-1 +for i in "${!HOSTS[@]}"; do + if [[ "${HOSTS[$i]}" == *"$LEADER_HOST"* ]]; then + LEADER_IDX=$i + break + fi +done + +if [[ $LEADER_IDX -eq -1 ]]; then + die "Leader host '$LEADER_HOST' not found in node list" +fi + +echo "Nodes:" +for i in "${!HOSTS[@]}"; do + marker="" + [[ $i -eq $LEADER_IDX ]] && marker=" ← LEADER (keep data)" + echo " [$i] ${HOSTS[$i]} (${ROLES[$i]})$marker" +done +echo "" + +# ── SSH helpers ────────────────────────────────────────────────────────────── +SSH_OPTS=(-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -o ConnectTimeout=10) + +node_ssh() { + local idx="$1" + shift + local h="${HOSTS[$idx]}" + local p="${PASSES[$idx]}" + local k="${SSH_KEYS[$idx]:-}" + + if [[ -n "$k" ]]; then + local expanded_key="${k/#\~/$HOME}" + if [[ -f "$expanded_key" ]]; then + ssh -i "$expanded_key" "${SSH_OPTS[@]}" "$h" "$@" 2>/dev/null + return $? + fi + fi + sshpass -p "$p" ssh -n "${SSH_OPTS[@]}" "$h" "$@" 2>/dev/null +} + +# ── Confirmation ───────────────────────────────────────────────────────────── +echo "⚠️ THIS WILL:" +echo " 1. Stop debros-node on ALL ${#HOSTS[@]} nodes" +echo " 2. 
DELETE raft/ data on ${#HOSTS[@]}-1 nodes (backup to /tmp/rqlite-raft-backup/)" +echo " 3. Keep raft/ data ONLY on ${HOSTS[$LEADER_IDX]} (leader candidate)" +echo " 4. Restart all nodes to reform the cluster" +echo "" +read -r -p "Continue? [y/N] " confirm +if [[ "$confirm" != "y" && "$confirm" != "Y" ]]; then + echo "Aborted." + exit 0 +fi +echo "" + +RAFT_DIR="/home/debros/.orama/data/rqlite/raft" +BACKUP_DIR="/tmp/rqlite-raft-backup" + +# ── Phase 1: Stop debros-node on ALL nodes ─────────────────────────────────── +echo "== Phase 1: Stopping debros-node on all ${#HOSTS[@]} nodes ==" +failed=() +for i in "${!HOSTS[@]}"; do + h="${HOSTS[$i]}" + p="${PASSES[$i]}" + echo -n " Stopping $h ... " + if node_ssh "$i" "printf '%s\n' '$p' | sudo -S systemctl stop debros-node 2>&1 && echo STOPPED"; then + echo "" + else + echo "FAILED" + failed+=("$h") + fi +done + +if [[ ${#failed[@]} -gt 0 ]]; then + echo "" + echo "⚠️ ${#failed[@]} nodes failed to stop. Attempting kill..." + for i in "${!HOSTS[@]}"; do + h="${HOSTS[$i]}" + p="${PASSES[$i]}" + for fh in "${failed[@]}"; do + if [[ "$h" == "$fh" ]]; then + node_ssh "$i" "printf '%s\n' '$p' | sudo -S killall -9 orama-node rqlited 2>/dev/null; echo KILLED" || true + fi + done + done +fi + +echo "" +echo "Waiting 5s for processes to fully stop..." +sleep 5 + +# ── Phase 2: Backup and delete raft/ on non-leader nodes ──────────────────── +echo "== Phase 2: Clearing raft state on non-leader nodes ==" +for i in "${!HOSTS[@]}"; do + [[ $i -eq $LEADER_IDX ]] && continue + + h="${HOSTS[$i]}" + p="${PASSES[$i]}" + echo -n " Clearing $h ... 
" + if node_ssh "$i" " + printf '%s\n' '$p' | sudo -S bash -c ' + rm -rf $BACKUP_DIR + if [ -d $RAFT_DIR ]; then + cp -r $RAFT_DIR $BACKUP_DIR 2>/dev/null || true + rm -rf $RAFT_DIR + echo \"CLEARED (backup at $BACKUP_DIR)\" + else + echo \"NO_RAFT_DIR (nothing to clear)\" + fi + ' + "; then + true + else + echo "FAILED" + fi +done + +echo "" +echo "Leader node ${HOSTS[$LEADER_IDX]} raft/ data preserved." + +# ── Phase 3: Start leader node ────────────────────────────────────────────── +echo "" +echo "== Phase 3: Starting leader node (${HOSTS[$LEADER_IDX]}) ==" +lp="${PASSES[$LEADER_IDX]}" +node_ssh "$LEADER_IDX" "printf '%s\n' '$lp' | sudo -S systemctl start debros-node" || die "Failed to start leader node" + +echo " Waiting for leader to become Leader..." +max_wait=120 +elapsed=0 +while [[ $elapsed -lt $max_wait ]]; do + state=$(node_ssh "$LEADER_IDX" "curl -s --max-time 3 http://localhost:5001/status 2>/dev/null | python3 -c \"import sys,json; d=json.load(sys.stdin); print(d.get('store',{}).get('raft',{}).get('state',''))\" 2>/dev/null" || echo "") + if [[ "$state" == "Leader" ]]; then + echo " ✓ Leader node is Leader after ${elapsed}s" + break + fi + echo " ... state=$state (${elapsed}s / ${max_wait}s)" + sleep 5 + ((elapsed+=5)) +done + +if [[ "$state" != "Leader" ]]; then + echo " ⚠️ Leader did not become Leader within ${max_wait}s (state=$state)" + echo " The node may need more time. Continuing anyway..." +fi + +# ── Phase 4: Start all other nodes ────────────────────────────────────────── +echo "" +echo "== Phase 4: Starting remaining nodes ==" + +# Start non-leader nodes in batches of 3 with 15s between batches +batch_size=3 +batch_count=0 +for i in "${!HOSTS[@]}"; do + [[ $i -eq $LEADER_IDX ]] && continue + + h="${HOSTS[$i]}" + p="${PASSES[$i]}" + echo -n " Starting $h ... 
" + if node_ssh "$i" "printf '%s\n' '$p' | sudo -S systemctl start debros-node && echo STARTED"; then + true + else + echo "FAILED" + fi + + ((batch_count++)) + if [[ $((batch_count % batch_size)) -eq 0 ]]; then + echo " (waiting 15s between batches for cluster stability)" + sleep 15 + fi +done + +# ── Phase 5: Wait and verify ──────────────────────────────────────────────── +echo "" +echo "== Phase 5: Waiting for cluster to form (120s) ==" +sleep 30 +echo " ... 30s" +sleep 30 +echo " ... 60s" +sleep 30 +echo " ... 90s" +sleep 30 +echo " ... 120s" + +echo "" +echo "== Cluster status ==" +for i in "${!HOSTS[@]}"; do + h="${HOSTS[$i]}" + result=$(node_ssh "$i" "curl -s --max-time 5 http://localhost:5001/status 2>/dev/null | python3 -c \" +import sys,json +try: + d=json.load(sys.stdin) + r=d.get('store',{}).get('raft',{}) + n=d.get('store',{}).get('num_nodes','?') + print(f'state={r.get(\"state\",\"?\")} commit={r.get(\"commit_index\",\"?\")} leader={r.get(\"leader\",{}).get(\"node_id\",\"?\")} nodes={n}') +except: + print('NO_RESPONSE') +\" 2>/dev/null" || echo "SSH_FAILED") + marker="" + [[ $i -eq $LEADER_IDX ]] && marker=" ← LEADER" + echo " ${HOSTS[$i]}: $result$marker" +done + +echo "" +echo "== Recovery complete ==" +echo "" +echo "Next steps:" +echo " 1. Run 'scripts/inspect.sh --devnet' to verify full cluster health" +echo " 2. If some nodes show Candidate state, give them more time (up to 5 min)" +echo " 3. If nodes fail to join, check /home/debros/.orama/logs/rqlite-node.log on the node" diff --git a/scripts/redeploy.sh b/scripts/redeploy.sh index 5ce4d16..ca34fc5 100755 --- a/scripts/redeploy.sh +++ b/scripts/redeploy.sh @@ -3,18 +3,19 @@ # Redeploy to all nodes in a given environment (devnet or testnet). # Reads node credentials from scripts/remote-nodes.conf. 
# -# Flow (per docs/DEV_DEPLOY.md): -# 1) make build-linux +# Flow: +# 1) make build-linux-all # 2) scripts/generate-source-archive.sh -> /tmp/network-source.tar.gz # 3) scp archive + extract-deploy.sh + conf to hub node # 4) from hub: sshpass scp to all other nodes + sudo bash /tmp/extract-deploy.sh -# 5) rolling: upgrade followers one-by-one, leader last +# 5) rolling upgrade: followers first, leader last +# per node: pre-upgrade -> stop -> extract binary -> post-upgrade # # Usage: # scripts/redeploy.sh --devnet # scripts/redeploy.sh --testnet # scripts/redeploy.sh --devnet --no-build -# scripts/redeploy.sh --testnet --no-build +# scripts/redeploy.sh --devnet --skip-build # set -euo pipefail @@ -26,14 +27,14 @@ for arg in "$@"; do case "$arg" in --devnet) ENV="devnet" ;; --testnet) ENV="testnet" ;; - --no-build) NO_BUILD=1 ;; + --no-build|--skip-build) NO_BUILD=1 ;; -h|--help) - echo "Usage: scripts/redeploy.sh --devnet|--testnet [--no-build]" + echo "Usage: scripts/redeploy.sh --devnet|--testnet [--no-build|--skip-build]" exit 0 ;; *) echo "Unknown flag: $arg" >&2 - echo "Usage: scripts/redeploy.sh --devnet|--testnet [--no-build]" >&2 + echo "Usage: scripts/redeploy.sh --devnet|--testnet [--no-build|--skip-build]" >&2 exit 1 ;; esac @@ -106,9 +107,9 @@ echo "Hub: $HUB_HOST (idx=$HUB_IDX, key=${HUB_KEY:-none})" # ── Build ──────────────────────────────────────────────────────────────────── if [[ "$NO_BUILD" -eq 0 ]]; then - echo "== build-linux ==" - (cd "$ROOT_DIR" && make build-linux) || { - echo "WARN: make build-linux failed; continuing if existing bin-linux is acceptable." + echo "== build-linux-all ==" + (cd "$ROOT_DIR" && make build-linux-all) || { + echo "WARN: make build-linux-all failed; continuing if existing bin-linux is acceptable." } else echo "== skipping build (--no-build) ==" @@ -192,12 +193,16 @@ if [[ ${#hosts[@]} -gt 0 ]] && ! 
command -v sshpass >/dev/null 2>&1; then fi echo "== fan-out: upload to ${#hosts[@]} nodes ==" +upload_failed=() for i in "${!hosts[@]}"; do h="${hosts[$i]}" p="${passes[$i]}" echo " -> $h" - sshpass -p "$p" scp -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \ - "$TAR" "$EX" "$h":/tmp/ + if ! sshpass -p "$p" scp -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \ + "$TAR" "$EX" "$h":/tmp/; then + echo " !! UPLOAD FAILED: $h" + upload_failed+=("$h") + fi done echo "== extract on all fan-out nodes ==" @@ -205,10 +210,22 @@ for i in "${!hosts[@]}"; do h="${hosts[$i]}" p="${passes[$i]}" echo " -> $h" - sshpass -p "$p" ssh -n -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \ - "$h" "printf '%s\n' '$p' | sudo -S bash /tmp/extract-deploy.sh >/tmp/extract.log 2>&1 && echo OK" + if ! sshpass -p "$p" ssh -n -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \ + "$h" "printf '%s\n' '$p' | sudo -S bash /tmp/extract-deploy.sh >/tmp/extract.log 2>&1 && echo OK"; then + echo " !! EXTRACT FAILED: $h" + upload_failed+=("$h") + fi done +if [[ ${#upload_failed[@]} -gt 0 ]]; then + echo "" + echo "WARNING: ${#upload_failed[@]} nodes had upload/extract failures:" + for uf in "${upload_failed[@]}"; do + echo " - $uf" + done + echo "Continuing with rolling restart..." +fi + echo "== extract on hub ==" printf '%s\n' "$hub_pass" | sudo -S bash "$EX" >/tmp/extract.log 2>&1 @@ -253,44 +270,131 @@ if [[ -z "$leader" ]]; then fi echo "Leader: $leader" +failed_nodes=() + +# ── Per-node upgrade flow ── +# Uses pre-upgrade (maintenance + leadership transfer + propagation wait) +# then stops, deploys binary, and post-upgrade (start + health verification). upgrade_one() { local h="$1" p="$2" echo "== upgrade $h ==" - sshpass -p "$p" ssh -n -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \ - "$h" "printf '%s\n' '$p' | sudo -S orama prod stop && printf '%s\n' '$p' | sudo -S orama upgrade --no-pull --pre-built --restart" + + # 1. 
Pre-upgrade: enter maintenance, transfer leadership, wait for propagation + echo " [1/4] pre-upgrade..." + if ! sshpass -p "$p" ssh -n -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \ + "$h" "printf '%s\n' '$p' | sudo -S orama prod pre-upgrade" 2>&1; then + echo " !! pre-upgrade failed on $h (continuing with stop)" + fi + + # 2. Stop all services + echo " [2/4] stopping services..." + if ! sshpass -p "$p" ssh -n -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \ + "$h" "printf '%s\n' '$p' | sudo -S systemctl stop 'debros-*'" 2>&1; then + echo " !! stop failed on $h" + failed_nodes+=("$h") + return 1 + fi + + # 3. Deploy new binary + echo " [3/4] deploying binary..." + if ! sshpass -p "$p" ssh -n -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \ + "$h" "printf '%s\n' '$p' | sudo -S bash /tmp/extract-deploy.sh >/tmp/extract.log 2>&1 && echo OK" 2>&1; then + echo " !! extract failed on $h" + failed_nodes+=("$h") + return 1 + fi + + # 4. Post-upgrade: start services, verify health, exit maintenance + echo " [4/4] post-upgrade..." + if ! sshpass -p "$p" ssh -n -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \ + "$h" "printf '%s\n' '$p' | sudo -S orama prod post-upgrade" 2>&1; then + echo " !! post-upgrade failed on $h" + failed_nodes+=("$h") + return 1 + fi + + echo " OK: $h" } upgrade_hub() { echo "== upgrade hub (localhost) ==" - printf '%s\n' "$hub_pass" | sudo -S orama prod stop - printf '%s\n' "$hub_pass" | sudo -S orama upgrade --no-pull --pre-built --restart + + # 1. Pre-upgrade + echo " [1/4] pre-upgrade..." + if ! (printf '%s\n' "$hub_pass" | sudo -S orama prod pre-upgrade) 2>&1; then + echo " !! pre-upgrade failed on hub (continuing with stop)" + fi + + # 2. Stop all services + echo " [2/4] stopping services..." + if ! (printf '%s\n' "$hub_pass" | sudo -S systemctl stop 'debros-*') 2>&1; then + echo " !! stop failed on hub ($hub_host)" + failed_nodes+=("$hub_host (hub)") + return 1 + fi + + # 3. 
Deploy new binary + echo " [3/4] deploying binary..." + if ! (printf '%s\n' "$hub_pass" | sudo -S bash "$EX" >/tmp/extract.log 2>&1); then + echo " !! extract failed on hub ($hub_host)" + failed_nodes+=("$hub_host (hub)") + return 1 + fi + + # 4. Post-upgrade + echo " [4/4] post-upgrade..." + if ! (printf '%s\n' "$hub_pass" | sudo -S orama prod post-upgrade) 2>&1; then + echo " !! post-upgrade failed on hub ($hub_host)" + failed_nodes+=("$hub_host (hub)") + return 1 + fi + + echo " OK: hub ($hub_host)" } -echo "== rolling upgrade (followers first) ==" +echo "== rolling upgrade (followers first, leader last) ==" for i in "${!hosts[@]}"; do h="${hosts[$i]}" p="${passes[$i]}" [[ "$h" == "$leader" ]] && continue - upgrade_one "$h" "$p" + upgrade_one "$h" "$p" || true done # Upgrade hub if not the leader if [[ "$leader" != "HUB" ]]; then - upgrade_hub + upgrade_hub || true fi # Upgrade leader last echo "== upgrade leader last ==" if [[ "$leader" == "HUB" ]]; then - upgrade_hub + upgrade_hub || true else - upgrade_one "$leader" "$leader_pass" + upgrade_one "$leader" "$leader_pass" || true fi # Clean up conf from hub rm -f "$CONF" -echo "Done." +# ── Report results ── +echo "" +echo "========================================" +if [[ ${#failed_nodes[@]} -gt 0 ]]; then + echo "UPGRADE COMPLETED WITH FAILURES (${#failed_nodes[@]} nodes failed):" + for fn in "${failed_nodes[@]}"; do + echo " FAILED: $fn" + done + echo "" + echo "Recommended actions:" + echo " 1. SSH into the failed node(s)" + echo " 2. Check logs: sudo orama prod logs node --follow" + echo " 3. Manually run: sudo orama prod post-upgrade" + echo "========================================" + exit 1 +else + echo "All nodes upgraded successfully." + echo "========================================" +fi REMOTE echo "== complete =="