Bro, I did so many things to fix the problematic discovery and redeployment that I don't even remember what I did.

This commit is contained in:
anonpenguin23 2026-02-14 10:56:26 +02:00
parent afbb7d4ede
commit 749d5ed5e7
37 changed files with 3559 additions and 201 deletions

View File

@ -43,6 +43,10 @@ func HandleCommand(args []string) {
case "restart":
force := hasFlag(subargs, "--force")
lifecycle.HandleRestartWithFlags(force)
case "pre-upgrade":
lifecycle.HandlePreUpgrade()
case "post-upgrade":
lifecycle.HandlePostUpgrade()
case "logs":
logs.Handle(subargs)
case "uninstall":
@ -105,6 +109,10 @@ func ShowHelp() {
fmt.Printf(" restart - Restart all production services (requires root/sudo)\n")
fmt.Printf(" Options:\n")
fmt.Printf(" --force - Bypass quorum safety check\n")
fmt.Printf(" pre-upgrade - Prepare node for safe restart (requires root/sudo)\n")
fmt.Printf(" Transfers leadership, enters maintenance mode, waits for propagation\n")
fmt.Printf(" post-upgrade - Bring node back online after restart (requires root/sudo)\n")
fmt.Printf(" Starts services, verifies RQLite health, exits maintenance\n")
fmt.Printf(" logs <service> - View production service logs\n")
fmt.Printf(" Service aliases: node, ipfs, cluster, gateway, olric\n")
fmt.Printf(" Options:\n")

View File

@ -19,6 +19,7 @@ import (
libp2ppubsub "github.com/libp2p/go-libp2p-pubsub"
"github.com/DeBrosOfficial/network/pkg/encryption"
"github.com/DeBrosOfficial/network/pkg/pubsub"
)
@ -144,6 +145,30 @@ func (c *Client) Connect() error {
libp2p.DefaultMuxers,
)
opts = append(opts, libp2p.Transport(tcp.NewTCPTransport))
// Load or create persistent identity if IdentityPath is configured
if c.config.IdentityPath != "" {
identity, loadErr := encryption.LoadIdentity(c.config.IdentityPath)
if loadErr != nil {
// File doesn't exist yet — generate and save
identity, loadErr = encryption.GenerateIdentity()
if loadErr != nil {
return fmt.Errorf("failed to generate identity: %w", loadErr)
}
if saveErr := encryption.SaveIdentity(identity, c.config.IdentityPath); saveErr != nil {
return fmt.Errorf("failed to save identity: %w", saveErr)
}
c.logger.Info("Generated new persistent identity",
zap.String("peer_id", identity.PeerID.String()),
zap.String("path", c.config.IdentityPath))
} else {
c.logger.Info("Loaded persistent identity",
zap.String("peer_id", identity.PeerID.String()),
zap.String("path", c.config.IdentityPath))
}
opts = append(opts, libp2p.Identity(identity.PrivateKey))
}
// Enable QUIC only when not proxying. When proxy is enabled, prefer TCP via SOCKS5.
h, err := libp2p.New(opts...)
if err != nil {

View File

@ -18,6 +18,7 @@ type ClientConfig struct {
QuietMode bool `json:"quiet_mode"` // Suppress debug/info logs
APIKey string `json:"api_key"` // API key for gateway auth
JWT string `json:"jwt"` // Optional JWT bearer token
IdentityPath string `json:"identity_path"` // Path to persistent LibP2P identity key file
}
// DefaultClientConfig returns a default client configuration

View File

@ -0,0 +1,92 @@
package client
import (
"os"
"path/filepath"
"testing"
"github.com/DeBrosOfficial/network/pkg/encryption"
)
// TestPersistentIdentity_NoPath checks that a default client config has no
// identity path set, so Connect() would fall back to an ephemeral identity.
func TestPersistentIdentity_NoPath(t *testing.T) {
	// Connect() itself needs a live network, so we only assert the config default.
	cfg := DefaultClientConfig("test-app")
	if got := cfg.IdentityPath; got != "" {
		t.Fatalf("expected empty IdentityPath by default, got %q", got)
	}
}
// TestPersistentIdentity_GenerateAndReload verifies that a saved identity
// round-trips through disk with a stable PeerID across repeated loads.
func TestPersistentIdentity_GenerateAndReload(t *testing.T) {
	keyPath := filepath.Join(t.TempDir(), "identity.key")

	// Step 1: no file exists yet — generate a fresh identity and persist it.
	generated, err := encryption.GenerateIdentity()
	if err != nil {
		t.Fatalf("GenerateIdentity: %v", err)
	}
	if err := encryption.SaveIdentity(generated, keyPath); err != nil {
		t.Fatalf("SaveIdentity: %v", err)
	}
	if _, err := os.Stat(keyPath); os.IsNotExist(err) {
		t.Fatal("identity key file was not created")
	}

	// Step 2: reload from disk — PeerID must match what was generated.
	loaded, err := encryption.LoadIdentity(keyPath)
	if err != nil {
		t.Fatalf("LoadIdentity: %v", err)
	}
	if generated.PeerID != loaded.PeerID {
		t.Fatalf("PeerID mismatch: generated %s, loaded %s", generated.PeerID, loaded.PeerID)
	}

	// Step 3: a second reload must agree with the first.
	reloaded, err := encryption.LoadIdentity(keyPath)
	if err != nil {
		t.Fatalf("LoadIdentity (second): %v", err)
	}
	if loaded.PeerID != reloaded.PeerID {
		t.Fatalf("PeerID changed across loads: %s vs %s", loaded.PeerID, reloaded.PeerID)
	}
}
// TestPersistentIdentity_DifferentFromRandom ensures two independent
// generations do not collide on PeerID.
func TestPersistentIdentity_DifferentFromRandom(t *testing.T) {
	first, err := encryption.GenerateIdentity()
	if err != nil {
		t.Fatalf("GenerateIdentity: %v", err)
	}
	second, err := encryption.GenerateIdentity()
	if err != nil {
		t.Fatalf("GenerateIdentity: %v", err)
	}
	if first.PeerID == second.PeerID {
		t.Fatal("two independently generated identities should have different PeerIDs")
	}
}
// TestPersistentIdentity_FilePermissions verifies SaveIdentity creates any
// missing parent directories and writes the key with owner-only (0600) mode.
func TestPersistentIdentity_FilePermissions(t *testing.T) {
	// Use a nested path so directory creation is exercised too.
	keyPath := filepath.Join(t.TempDir(), "subdir", "identity.key")

	identity, err := encryption.GenerateIdentity()
	if err != nil {
		t.Fatalf("GenerateIdentity: %v", err)
	}
	if err := encryption.SaveIdentity(identity, keyPath); err != nil {
		t.Fatalf("SaveIdentity: %v", err)
	}

	info, err := os.Stat(keyPath)
	if err != nil {
		t.Fatalf("Stat: %v", err)
	}
	if mode := info.Mode().Perm(); mode != 0600 {
		t.Fatalf("expected file permissions 0600, got %o", mode)
	}
}

View File

@ -22,6 +22,13 @@ type DatabaseConfig struct {
NodeCACert string `yaml:"node_ca_cert"` // Path to CA certificate (optional, uses system CA if not set)
NodeNoVerify bool `yaml:"node_no_verify"` // Skip certificate verification (for testing/self-signed certs)
// Raft tuning (passed through to rqlited CLI flags).
// Higher defaults than rqlited's 1s suit WireGuard latency.
RaftElectionTimeout time.Duration `yaml:"raft_election_timeout"` // default: 5s
RaftHeartbeatTimeout time.Duration `yaml:"raft_heartbeat_timeout"` // default: 2s
RaftApplyTimeout time.Duration `yaml:"raft_apply_timeout"` // default: 30s
RaftLeaderLeaseTimeout time.Duration `yaml:"raft_leader_lease_timeout"` // default: 5s
// Dynamic discovery configuration (always enabled)
ClusterSyncInterval time.Duration `yaml:"cluster_sync_interval"` // default: 30s
PeerInactivityLimit time.Duration `yaml:"peer_inactivity_limit"` // default: 24h

View File

@ -7,6 +7,7 @@ import (
"io"
"strconv"
"strings"
"sync"
"time"
"github.com/libp2p/go-libp2p/core/host"
@ -80,7 +81,11 @@ type Manager struct {
host host.Host
logger *zap.Logger
cancel context.CancelFunc
failedPeerExchanges map[peer.ID]time.Time // Track failed peer exchange attempts to suppress repeated warnings
// failedMu protects failedPeerExchanges from concurrent access during
// parallel peer exchange dials (H3 fix).
failedMu sync.Mutex
failedPeerExchanges map[peer.ID]time.Time
}
// Config contains discovery configuration
@ -364,8 +369,8 @@ func (d *Manager) discoverViaPeerExchange(ctx context.Context, maxConnections in
// Add to peerstore (only valid addresses with port 4001)
d.host.Peerstore().AddAddrs(parsedID, addrs, time.Hour*24)
// Try to connect
connectCtx, cancel := context.WithTimeout(ctx, 20*time.Second)
// Try to connect (5s timeout — WireGuard peers respond fast)
connectCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
peerAddrInfo := peer.AddrInfo{ID: parsedID, Addrs: addrs}
if err := d.host.Connect(connectCtx, peerAddrInfo); err != nil {
@ -401,15 +406,15 @@ func (d *Manager) requestPeersFromPeer(ctx context.Context, peerID peer.ID, limi
// Open a stream to the peer
stream, err := d.host.NewStream(ctx, peerID, PeerExchangeProtocol)
if err != nil {
// Check if this is a "protocols not supported" error (expected for lightweight clients like gateway)
d.failedMu.Lock()
if strings.Contains(err.Error(), "protocols not supported") {
// This is a lightweight client (gateway, etc.) that doesn't support peer exchange - expected behavior
// Track it to avoid repeated attempts, but don't log as it's not an error
// Lightweight client (gateway, etc.) — expected, track to suppress retries
d.failedPeerExchanges[peerID] = time.Now()
d.failedMu.Unlock()
return nil
}
// For actual connection errors, log but suppress repeated warnings for the same peer
// Actual connection error — log but suppress repeated warnings
lastFailure, seen := d.failedPeerExchanges[peerID]
if !seen || time.Since(lastFailure) > time.Minute {
d.logger.Debug("Failed to open peer exchange stream with node",
@ -418,12 +423,15 @@ func (d *Manager) requestPeersFromPeer(ctx context.Context, peerID peer.ID, limi
zap.Error(err))
d.failedPeerExchanges[peerID] = time.Now()
}
d.failedMu.Unlock()
return nil
}
defer stream.Close()
// Clear failure tracking on success
d.failedMu.Lock()
delete(d.failedPeerExchanges, peerID)
d.failedMu.Unlock()
// Send request
req := PeerExchangeRequest{Limit: limit}
@ -433,8 +441,8 @@ func (d *Manager) requestPeersFromPeer(ctx context.Context, peerID peer.ID, limi
return nil
}
// Set read deadline
if err := stream.SetReadDeadline(time.Now().Add(10 * time.Second)); err != nil {
// Set read deadline (5s — small JSON payload)
if err := stream.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
d.logger.Debug("Failed to set read deadline", zap.Error(err))
return nil
}
@ -451,10 +459,20 @@ func (d *Manager) requestPeersFromPeer(ctx context.Context, peerID peer.ID, limi
// Store remote peer's RQLite metadata if available
if resp.RQLiteMetadata != nil {
// Verify sender identity — prevent metadata spoofing (H2 fix).
// If the metadata contains a PeerID, it must match the stream sender.
if resp.RQLiteMetadata.PeerID != "" && resp.RQLiteMetadata.PeerID != peerID.String() {
d.logger.Warn("Rejected metadata: PeerID mismatch",
zap.String("claimed", resp.RQLiteMetadata.PeerID[:8]+"..."),
zap.String("actual", peerID.String()[:8]+"..."))
return resp.Peers
}
// Stamp verified PeerID so downstream consumers can trust it
resp.RQLiteMetadata.PeerID = peerID.String()
metadataJSON, err := json.Marshal(resp.RQLiteMetadata)
if err == nil {
_ = d.host.Peerstore().Put(peerID, "rqlite_metadata", metadataJSON)
// Only log when new metadata is stored (useful for debugging)
d.logger.Debug("Metadata stored",
zap.String("peer", peerID.String()[:8]+"..."),
zap.String("node", resp.RQLiteMetadata.NodeID))

View File

@ -0,0 +1,81 @@
package discovery
import (
"context"
"encoding/json"
"time"
"github.com/libp2p/go-libp2p/core/host"
"go.uber.org/zap"
)
// MetadataProvider is implemented by subsystems that can supply node metadata.
// The publisher calls ProvideMetadata() every cycle and stores the result in
// the peerstore; returning nil skips that cycle.
type MetadataProvider interface {
	ProvideMetadata() *RQLiteNodeMetadata
}

// MetadataPublisher periodically writes local node metadata to the peerstore so
// it is included in every peer exchange response. This decouples metadata
// production (lifecycle, RQLite status, service health) from the exchange
// protocol itself.
type MetadataPublisher struct {
	host     host.Host        // local libp2p host; metadata is stored under its own peer ID
	provider MetadataProvider // source of fresh metadata each publish cycle
	interval time.Duration    // publish period; defaulted in NewMetadataPublisher
	logger   *zap.Logger      // tagged with component=metadata-publisher
}
// NewMetadataPublisher creates a publisher that writes metadata every interval.
// A non-positive interval falls back to a 10-second default.
func NewMetadataPublisher(h host.Host, provider MetadataProvider, interval time.Duration, logger *zap.Logger) *MetadataPublisher {
	const fallbackInterval = 10 * time.Second
	if interval <= 0 {
		interval = fallbackInterval
	}
	pub := &MetadataPublisher{
		host:     h,
		provider: provider,
		interval: interval,
		logger:   logger.With(zap.String("component", "metadata-publisher")),
	}
	return pub
}
// Start begins the periodic publish loop. It blocks until ctx is cancelled,
// so callers typically run it in its own goroutine.
func (p *MetadataPublisher) Start(ctx context.Context) {
	// Publish once up front so peers see metadata before the first tick fires.
	p.publish()

	tick := time.NewTicker(p.interval)
	defer tick.Stop()
	for {
		select {
		case <-tick.C:
			p.publish()
		case <-ctx.Done():
			return
		}
	}
}
// PublishNow performs a single immediate metadata publish.
// Useful after lifecycle transitions or other state changes, so peers see
// the new state without waiting for the next ticker cycle.
func (p *MetadataPublisher) PublishNow() {
	p.publish()
}
// publish asks the provider for fresh metadata and stores the JSON-encoded
// result in the local peerstore under the "rqlite_metadata" key, where the
// peer exchange handler picks it up. A nil result skips the cycle silently.
func (p *MetadataPublisher) publish() {
	md := p.provider.ProvideMetadata()
	if md == nil {
		return
	}
	payload, err := json.Marshal(md)
	if err != nil {
		p.logger.Error("Failed to marshal metadata", zap.Error(err))
		return
	}
	if putErr := p.host.Peerstore().Put(p.host.ID(), "rqlite_metadata", payload); putErr != nil {
		p.logger.Error("Failed to store metadata in peerstore", zap.Error(putErr))
	}
}

View File

@ -4,18 +4,101 @@ import (
"time"
)
// RQLiteNodeMetadata contains RQLite-specific information announced via LibP2P
// ServiceStatus represents the health of an individual service on a node.
// Running and Healthy are reported separately: a process can be up but
// failing its health check.
type ServiceStatus struct {
	Name    string `json:"name"`              // e.g. "rqlite", "gateway", "olric"
	Running bool   `json:"running"`           // whether the process is up
	Healthy bool   `json:"healthy"`           // whether it passed its health check
	Message string `json:"message,omitempty"` // optional detail ("leader", "follower", etc.)
}
// NamespaceStatus represents a namespace's status on a node.
type NamespaceStatus struct {
	Name   string `json:"name"`   // namespace identifier
	Status string `json:"status"` // "healthy", "degraded", "recovering"
}
// RQLiteNodeMetadata contains node information announced via LibP2P peer exchange.
// This struct is the single source of truth for node metadata propagated through
// the cluster. Go's json.Unmarshal silently ignores unknown fields, so old nodes
// reading metadata from new nodes simply skip the new fields — no protocol
// version change is needed.
type RQLiteNodeMetadata struct {
NodeID string `json:"node_id"` // RQLite node ID (from config)
RaftAddress string `json:"raft_address"` // Raft port address (e.g., "51.83.128.181:7001")
HTTPAddress string `json:"http_address"` // HTTP API address (e.g., "51.83.128.181:5001")
// --- Existing fields (unchanged) ---
NodeID string `json:"node_id"` // RQLite node ID (raft address)
RaftAddress string `json:"raft_address"` // Raft port address (e.g., "10.0.0.1:7001")
HTTPAddress string `json:"http_address"` // HTTP API address (e.g., "10.0.0.1:5001")
NodeType string `json:"node_type"` // Node type identifier
RaftLogIndex uint64 `json:"raft_log_index"` // Current Raft log index (for data comparison)
LastSeen time.Time `json:"last_seen"` // Updated on every announcement
ClusterVersion string `json:"cluster_version"` // For compatibility checking
// --- New: Identity ---
// PeerID is the LibP2P peer ID of the node. Used for metadata authentication:
// on receipt, the receiver verifies PeerID == stream sender to prevent spoofing.
PeerID string `json:"peer_id,omitempty"`
// WireGuardIP is the node's WireGuard VPN address (e.g., "10.0.0.1").
WireGuardIP string `json:"wireguard_ip,omitempty"`
// --- New: Lifecycle ---
// LifecycleState is the node's current lifecycle state:
// "joining", "active", "draining", or "maintenance".
// Zero value (empty string) from old nodes is treated as "active".
LifecycleState string `json:"lifecycle_state,omitempty"`
// MaintenanceTTL is the time at which maintenance mode expires.
// Only meaningful when LifecycleState == "maintenance".
MaintenanceTTL time.Time `json:"maintenance_ttl,omitempty"`
// --- New: Services ---
// Services reports the status of each service running on the node.
Services map[string]*ServiceStatus `json:"services,omitempty"`
// Namespaces reports the status of each namespace on the node.
Namespaces map[string]*NamespaceStatus `json:"namespaces,omitempty"`
// --- New: Version ---
// BinaryVersion is the node's binary version string (e.g., "1.2.3").
BinaryVersion string `json:"binary_version,omitempty"`
}
// PeerExchangeResponseV2 extends the original response with RQLite metadata
// EffectiveLifecycleState returns the lifecycle state, defaulting to "active"
// for old nodes that don't populate the field.
func (m *RQLiteNodeMetadata) EffectiveLifecycleState() string {
	if state := m.LifecycleState; state != "" {
		return state
	}
	return "active"
}
// IsInMaintenance returns true if the node has announced maintenance mode.
func (m *RQLiteNodeMetadata) IsInMaintenance() bool {
	state := m.EffectiveLifecycleState()
	return state == "maintenance"
}
// IsAvailable returns true if the node is in a state that can serve requests.
func (m *RQLiteNodeMetadata) IsAvailable() bool {
	state := m.EffectiveLifecycleState()
	return state == "active"
}
// IsMaintenanceExpired returns true if the node is in maintenance and the
// TTL has passed. Used by the leader's health monitor to enforce expiry.
// A zero TTL never expires; non-maintenance states never report expiry.
func (m *RQLiteNodeMetadata) IsMaintenanceExpired() bool {
	if !m.IsInMaintenance() || m.MaintenanceTTL.IsZero() {
		return false
	}
	return time.Now().After(m.MaintenanceTTL)
}
// PeerExchangeResponseV2 extends the original response with RQLite metadata.
// Kept for backward compatibility — the V1 PeerExchangeResponse in discovery.go
// already includes the same RQLiteMetadata field, so this is effectively unused.
type PeerExchangeResponseV2 struct {
Peers []PeerInfo `json:"peers"`
RQLiteMetadata *RQLiteNodeMetadata `json:"rqlite_metadata,omitempty"`

View File

@ -0,0 +1,235 @@
package discovery
import (
"encoding/json"
"testing"
"time"
)
// TestEffectiveLifecycleState covers the default-to-"active" behavior for the
// zero value plus pass-through for every explicit state.
func TestEffectiveLifecycleState(t *testing.T) {
	cases := []struct {
		name  string
		state string
		want  string
	}{
		{"empty defaults to active", "", "active"},
		{"explicit active", "active", "active"},
		{"joining", "joining", "joining"},
		{"maintenance", "maintenance", "maintenance"},
		{"draining", "draining", "draining"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			meta := &RQLiteNodeMetadata{LifecycleState: tc.state}
			if got := meta.EffectiveLifecycleState(); got != tc.want {
				t.Fatalf("got %q, want %q", got, tc.want)
			}
		})
	}
}
// TestIsInMaintenance checks the maintenance predicate for explicit states
// and the zero-value state sent by old nodes.
func TestIsInMaintenance(t *testing.T) {
	meta := &RQLiteNodeMetadata{LifecycleState: "maintenance"}
	if !meta.IsInMaintenance() {
		t.Fatal("expected maintenance")
	}
	meta.LifecycleState = "active"
	if meta.IsInMaintenance() {
		t.Fatal("expected not maintenance")
	}
	// Old nodes send no lifecycle state at all.
	meta.LifecycleState = ""
	if meta.IsInMaintenance() {
		t.Fatal("empty state should not be maintenance")
	}
}
// TestIsAvailable checks availability for active, empty (backward-compat),
// and maintenance states.
func TestIsAvailable(t *testing.T) {
	meta := &RQLiteNodeMetadata{LifecycleState: "active"}
	if !meta.IsAvailable() {
		t.Fatal("expected available")
	}
	// Old nodes (empty state) default to active and therefore available.
	meta.LifecycleState = ""
	if !meta.IsAvailable() {
		t.Fatal("empty state should be available (backward compat)")
	}
	meta.LifecycleState = "maintenance"
	if meta.IsAvailable() {
		t.Fatal("maintenance should not be available")
	}
}
// TestIsMaintenanceExpired exercises all four expiry branches: past TTL,
// future TTL, zero TTL, and a non-maintenance state.
func TestIsMaintenanceExpired(t *testing.T) {
	meta := &RQLiteNodeMetadata{
		LifecycleState: "maintenance",
		MaintenanceTTL: time.Now().Add(-1 * time.Minute),
	}
	if !meta.IsMaintenanceExpired() {
		t.Fatal("expected expired")
	}

	meta.MaintenanceTTL = time.Now().Add(5 * time.Minute)
	if meta.IsMaintenanceExpired() {
		t.Fatal("expected not expired")
	}

	// A zero TTL means "no deadline" and never expires.
	meta.MaintenanceTTL = time.Time{}
	if meta.IsMaintenanceExpired() {
		t.Fatal("zero TTL should not be considered expired")
	}

	// Outside maintenance the TTL is ignored entirely.
	meta.LifecycleState = "active"
	meta.MaintenanceTTL = time.Now().Add(-1 * time.Minute)
	if meta.IsMaintenanceExpired() {
		t.Fatal("active state should not report expired")
	}
}
// TestBackwardCompatibility verifies that old metadata (without new fields)
// unmarshals correctly — new fields get zero values, helpers return sane defaults.
func TestBackwardCompatibility(t *testing.T) {
	oldJSON := `{
		"node_id": "10.0.0.1:7001",
		"raft_address": "10.0.0.1:7001",
		"http_address": "10.0.0.1:5001",
		"node_type": "node",
		"raft_log_index": 42,
		"cluster_version": "1.0"
	}`

	var meta RQLiteNodeMetadata
	if err := json.Unmarshal([]byte(oldJSON), &meta); err != nil {
		t.Fatalf("unmarshal old metadata: %v", err)
	}

	// Existing fields preserved.
	if meta.NodeID != "10.0.0.1:7001" {
		t.Fatalf("expected node_id 10.0.0.1:7001, got %s", meta.NodeID)
	}
	if meta.RaftLogIndex != 42 {
		t.Fatalf("expected raft_log_index 42, got %d", meta.RaftLogIndex)
	}

	// New fields default to zero values.
	if meta.PeerID != "" {
		t.Fatalf("expected empty PeerID, got %q", meta.PeerID)
	}
	if meta.LifecycleState != "" {
		t.Fatalf("expected empty LifecycleState, got %q", meta.LifecycleState)
	}
	if meta.Services != nil {
		t.Fatal("expected nil Services")
	}

	// Helpers return correct defaults for the zero values.
	if meta.EffectiveLifecycleState() != "active" {
		t.Fatalf("expected effective state 'active', got %q", meta.EffectiveLifecycleState())
	}
	if !meta.IsAvailable() {
		t.Fatal("old metadata should be available")
	}
	if meta.IsInMaintenance() {
		t.Fatal("old metadata should not be in maintenance")
	}
}
// TestNewFieldsRoundTrip verifies that new fields marshal/unmarshal correctly.
func TestNewFieldsRoundTrip(t *testing.T) {
	want := &RQLiteNodeMetadata{
		NodeID:         "10.0.0.1:7001",
		RaftAddress:    "10.0.0.1:7001",
		HTTPAddress:    "10.0.0.1:5001",
		NodeType:       "node",
		RaftLogIndex:   100,
		ClusterVersion: "1.0",
		PeerID:         "QmPeerID123",
		WireGuardIP:    "10.0.0.1",
		LifecycleState: "maintenance",
		MaintenanceTTL: time.Now().Add(10 * time.Minute).Truncate(time.Millisecond),
		BinaryVersion:  "1.2.3",
		Services: map[string]*ServiceStatus{
			"rqlite": {Name: "rqlite", Running: true, Healthy: true, Message: "leader"},
		},
		Namespaces: map[string]*NamespaceStatus{
			"myapp": {Name: "myapp", Status: "healthy"},
		},
	}

	data, err := json.Marshal(want)
	if err != nil {
		t.Fatalf("marshal: %v", err)
	}
	var got RQLiteNodeMetadata
	if err := json.Unmarshal(data, &got); err != nil {
		t.Fatalf("unmarshal: %v", err)
	}

	if got.PeerID != want.PeerID {
		t.Fatalf("PeerID: got %q, want %q", got.PeerID, want.PeerID)
	}
	if got.WireGuardIP != want.WireGuardIP {
		t.Fatalf("WireGuardIP: got %q, want %q", got.WireGuardIP, want.WireGuardIP)
	}
	if got.LifecycleState != want.LifecycleState {
		t.Fatalf("LifecycleState: got %q, want %q", got.LifecycleState, want.LifecycleState)
	}
	if got.BinaryVersion != want.BinaryVersion {
		t.Fatalf("BinaryVersion: got %q, want %q", got.BinaryVersion, want.BinaryVersion)
	}
	if got.Services["rqlite"] == nil || !got.Services["rqlite"].Running {
		t.Fatal("expected rqlite service to be running")
	}
	if got.Namespaces["myapp"] == nil || got.Namespaces["myapp"].Status != "healthy" {
		t.Fatal("expected myapp namespace to be healthy")
	}
}
// TestOldNodeReadsNewMetadata simulates an old node (that doesn't know about new fields)
// reading metadata from a new node. Go's JSON unmarshalling silently ignores unknown fields.
func TestOldNodeReadsNewMetadata(t *testing.T) {
	newJSON := `{
		"node_id": "10.0.0.1:7001",
		"raft_address": "10.0.0.1:7001",
		"http_address": "10.0.0.1:5001",
		"node_type": "node",
		"raft_log_index": 42,
		"cluster_version": "1.0",
		"peer_id": "QmSomePeerID",
		"wireguard_ip": "10.0.0.1",
		"lifecycle_state": "maintenance",
		"maintenance_ttl": "2025-01-01T00:00:00Z",
		"binary_version": "2.0.0",
		"services": {"rqlite": {"name": "rqlite", "running": true, "healthy": true}},
		"namespaces": {"app": {"name": "app", "status": "healthy"}},
		"some_future_field": "unknown"
	}`

	// Local struct standing in for the pre-upgrade metadata shape.
	type legacyMetadata struct {
		NodeID         string `json:"node_id"`
		RaftAddress    string `json:"raft_address"`
		HTTPAddress    string `json:"http_address"`
		NodeType       string `json:"node_type"`
		RaftLogIndex   uint64 `json:"raft_log_index"`
		ClusterVersion string `json:"cluster_version"`
	}

	var legacy legacyMetadata
	if err := json.Unmarshal([]byte(newJSON), &legacy); err != nil {
		t.Fatalf("old node should unmarshal new metadata without error: %v", err)
	}
	if legacy.NodeID != "10.0.0.1:7001" || legacy.RaftLogIndex != 42 {
		t.Fatal("old fields should be preserved")
	}
}

View File

@ -0,0 +1,194 @@
package encryption
import (
	"crypto/ed25519"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"io"
	"os"
	"os/exec"
	"strings"

	"golang.org/x/crypto/curve25519"
	"golang.org/x/crypto/hkdf"
)
// NodeKeys holds all cryptographic keys derived from a wallet's master key.
// All keys are derived deterministically via HKDF (see ExpandNodeKeys), so
// the same master key always yields the same set.
type NodeKeys struct {
	LibP2PPrivateKey  ed25519.PrivateKey // Ed25519 for LibP2P identity
	LibP2PPublicKey   ed25519.PublicKey  // public half of the LibP2P identity
	WireGuardKey      [32]byte           // Curve25519 private key (clamped)
	WireGuardPubKey   [32]byte           // Curve25519 public key
	IPFSPrivateKey    ed25519.PrivateKey // IPFS node identity
	IPFSPublicKey     ed25519.PublicKey  // public half of the IPFS identity
	ClusterPrivateKey ed25519.PrivateKey // IPFS Cluster identity
	ClusterPublicKey  ed25519.PublicKey  // public half of the cluster identity
	JWTPrivateKey     ed25519.PrivateKey // EdDSA JWT signing key
	JWTPublicKey      ed25519.PublicKey  // JWT verification key
}
// DeriveNodeKeysFromWallet calls `rw derive` to get a master key from the user's
// Root Wallet, then expands it into all node keys. The wallet's private key never
// leaves the `rw` process.
//
// vpsIP is used as the HKDF info parameter, so each VPS gets unique keys from the
// same wallet. Stdin is passed through so rw can prompt for the wallet password.
//
// Returns an error when rw is missing, the subprocess fails, or its output is
// not exactly 32 bytes of hex.
func DeriveNodeKeysFromWallet(vpsIP string) (*NodeKeys, error) {
	if vpsIP == "" {
		return nil, fmt.Errorf("VPS IP is required for key derivation")
	}
	// Check rw is installed
	if _, err := exec.LookPath("rw"); err != nil {
		return nil, fmt.Errorf("Root Wallet (rw) not found in PATH — install it first")
	}
	// Call rw derive to get master key bytes.
	// Only stdout is captured (cmd.Output); prompts and UI go to the terminal.
	cmd := exec.Command("rw", "derive", "--salt", "orama-node", "--info", vpsIP)
	cmd.Stdin = os.Stdin   // pass through for password prompts
	cmd.Stderr = os.Stderr // rw UI messages go to terminal
	out, err := cmd.Output()
	if err != nil {
		return nil, fmt.Errorf("rw derive failed: %w", err)
	}
	masterHex := strings.TrimSpace(string(out))
	if len(masterHex) != 64 { // 32 bytes = 64 hex chars
		return nil, fmt.Errorf("rw derive returned unexpected output length: %d (expected 64 hex chars)", len(masterHex))
	}
	masterKey, err := hexToBytes(masterHex)
	if err != nil {
		return nil, fmt.Errorf("rw derive returned invalid hex: %w", err)
	}
	// Best-effort scrub of the master key from memory once expansion is done.
	defer zeroBytes(masterKey)
	return ExpandNodeKeys(masterKey)
}
// ExpandNodeKeys expands a 32-byte master key into all node keys using HKDF-SHA256.
// The master key should come from `rw derive --salt "orama-node" --info "<IP>"`.
//
// Each key type uses a different HKDF info string under the salt "orama-expand",
// ensuring cryptographic independence between key types. The derivation is
// deterministic: the same master key always yields the same NodeKeys.
func ExpandNodeKeys(masterKey []byte) (*NodeKeys, error) {
	if len(masterKey) != 32 {
		return nil, fmt.Errorf("master key must be 32 bytes, got %d", len(masterKey))
	}
	salt := []byte("orama-expand")
	keys := &NodeKeys{}

	// All Ed25519 identities follow the same derive-seed → NewKeyFromSeed
	// pattern; only the HKDF info string differs per key type.
	var err error
	if keys.LibP2PPrivateKey, keys.LibP2PPublicKey, err = deriveEd25519(masterKey, salt, "libp2p-identity"); err != nil {
		return nil, fmt.Errorf("failed to derive libp2p key: %w", err)
	}

	// WireGuard Curve25519 key: derive raw bytes, clamp per RFC 7748, then
	// compute the public point on the curve's base point.
	wgSeed, err := deriveBytes(masterKey, salt, []byte("wireguard-key"), 32)
	if err != nil {
		return nil, fmt.Errorf("failed to derive wireguard key: %w", err)
	}
	copy(keys.WireGuardKey[:], wgSeed)
	zeroBytes(wgSeed)
	clampCurve25519Key(&keys.WireGuardKey)
	pubKey, err := curve25519.X25519(keys.WireGuardKey[:], curve25519.Basepoint)
	if err != nil {
		return nil, fmt.Errorf("failed to compute wireguard public key: %w", err)
	}
	copy(keys.WireGuardPubKey[:], pubKey)

	if keys.IPFSPrivateKey, keys.IPFSPublicKey, err = deriveEd25519(masterKey, salt, "ipfs-identity"); err != nil {
		return nil, fmt.Errorf("failed to derive ipfs key: %w", err)
	}
	if keys.ClusterPrivateKey, keys.ClusterPublicKey, err = deriveEd25519(masterKey, salt, "ipfs-cluster"); err != nil {
		return nil, fmt.Errorf("failed to derive cluster key: %w", err)
	}
	if keys.JWTPrivateKey, keys.JWTPublicKey, err = deriveEd25519(masterKey, salt, "jwt-signing"); err != nil {
		return nil, fmt.Errorf("failed to derive jwt key: %w", err)
	}
	return keys, nil
}

// deriveEd25519 derives one Ed25519 key pair from the master key using the
// given HKDF info string; the seed is scrubbed from memory before returning.
func deriveEd25519(masterKey, salt []byte, info string) (ed25519.PrivateKey, ed25519.PublicKey, error) {
	seed, err := deriveBytes(masterKey, salt, []byte(info), ed25519.SeedSize)
	if err != nil {
		return nil, nil, err
	}
	defer zeroBytes(seed)
	priv := ed25519.NewKeyFromSeed(seed)
	return priv, priv.Public().(ed25519.PublicKey), nil
}
// deriveBytes uses HKDF-SHA256 to derive n bytes from the given IKM, salt, and info.
func deriveBytes(ikm, salt, info []byte, n int) ([]byte, error) {
	buf := make([]byte, n)
	reader := hkdf.New(sha256.New, ikm, salt, info)
	if _, err := io.ReadFull(reader, buf); err != nil {
		return nil, err
	}
	return buf, nil
}
// clampCurve25519Key applies the standard Curve25519 clamping (RFC 7748) to a
// private key in place: clear the low 3 bits, clear the top bit, set bit 254.
func clampCurve25519Key(key *[32]byte) {
	key[0] &^= 0x07  // clear low 3 bits (multiple of the cofactor 8)
	key[31] &^= 0x80 // clear the highest bit
	key[31] |= 0x40  // set the second-highest bit
}
// hexToBytes decodes a hex string (upper- or lowercase) to bytes.
// It delegates to the standard library's encoding/hex, which rejects
// odd-length input and non-hex characters — replacing the previous
// hand-rolled nibble decoder with identical accepted inputs.
func hexToBytes(s string) ([]byte, error) {
	b, err := hex.DecodeString(s)
	if err != nil {
		return nil, err
	}
	return b, nil
}
// hexCharToByte converts a single hex digit (0-9, a-f, A-F) to its value.
func hexCharToByte(c byte) (byte, error) {
	if c >= '0' && c <= '9' {
		return c - '0', nil
	}
	if c >= 'a' && c <= 'f' {
		return c - 'a' + 10, nil
	}
	if c >= 'A' && c <= 'F' {
		return c - 'A' + 10, nil
	}
	return 0, fmt.Errorf("invalid hex character: %c", c)
}
// zeroBytes zeroes a byte slice to clear sensitive data from memory.
// Safe to call on a nil or empty slice.
func zeroBytes(b []byte) {
	for i := 0; i < len(b); i++ {
		b[i] = 0
	}
}

View File

@ -0,0 +1,202 @@
package encryption
import (
"bytes"
"crypto/ed25519"
"testing"
)
// testMasterKey is a deterministic 32-byte key for testing ExpandNodeKeys.
// In production, this comes from `rw derive --salt "orama-node" --info "<IP>"`.
var testMasterKey = bytes.Repeat([]byte{0xab}, 32)
var testMasterKey2 = bytes.Repeat([]byte{0xcd}, 32)
// TestExpandNodeKeys_Determinism asserts that expanding the same master key
// twice yields byte-identical keys of every type.
func TestExpandNodeKeys_Determinism(t *testing.T) {
	keysA, err := ExpandNodeKeys(testMasterKey)
	if err != nil {
		t.Fatalf("ExpandNodeKeys: %v", err)
	}
	keysB, err := ExpandNodeKeys(testMasterKey)
	if err != nil {
		t.Fatalf("ExpandNodeKeys (second): %v", err)
	}
	cases := []struct {
		msg  string
		a, b []byte
	}{
		{"LibP2P private keys differ for same input", keysA.LibP2PPrivateKey, keysB.LibP2PPrivateKey},
		{"WireGuard keys differ for same input", keysA.WireGuardKey[:], keysB.WireGuardKey[:]},
		{"IPFS private keys differ for same input", keysA.IPFSPrivateKey, keysB.IPFSPrivateKey},
		{"Cluster private keys differ for same input", keysA.ClusterPrivateKey, keysB.ClusterPrivateKey},
		{"JWT private keys differ for same input", keysA.JWTPrivateKey, keysB.JWTPrivateKey},
	}
	for _, tc := range cases {
		if !bytes.Equal(tc.a, tc.b) {
			t.Error(tc.msg)
		}
	}
}
// TestExpandNodeKeys_Uniqueness asserts that two different master keys never
// expand into a shared key of any type.
func TestExpandNodeKeys_Uniqueness(t *testing.T) {
	keysA, err := ExpandNodeKeys(testMasterKey)
	if err != nil {
		t.Fatalf("ExpandNodeKeys(master1): %v", err)
	}
	keysB, err := ExpandNodeKeys(testMasterKey2)
	if err != nil {
		t.Fatalf("ExpandNodeKeys(master2): %v", err)
	}
	cases := []struct {
		msg  string
		a, b []byte
	}{
		{"LibP2P keys should differ for different master keys", keysA.LibP2PPrivateKey, keysB.LibP2PPrivateKey},
		{"WireGuard keys should differ for different master keys", keysA.WireGuardKey[:], keysB.WireGuardKey[:]},
		{"IPFS keys should differ for different master keys", keysA.IPFSPrivateKey, keysB.IPFSPrivateKey},
		{"Cluster keys should differ for different master keys", keysA.ClusterPrivateKey, keysB.ClusterPrivateKey},
		{"JWT keys should differ for different master keys", keysA.JWTPrivateKey, keysB.JWTPrivateKey},
	}
	for _, tc := range cases {
		if bytes.Equal(tc.a, tc.b) {
			t.Error(tc.msg)
		}
	}
}
// TestExpandNodeKeys_KeysAreMutuallyUnique asserts pairwise independence of
// all key types derived from a single master key.
func TestExpandNodeKeys_KeysAreMutuallyUnique(t *testing.T) {
	keys, err := ExpandNodeKeys(testMasterKey)
	if err != nil {
		t.Fatalf("ExpandNodeKeys: %v", err)
	}
	seeds := [][]byte{
		keys.LibP2PPrivateKey.Seed(),
		keys.IPFSPrivateKey.Seed(),
		keys.ClusterPrivateKey.Seed(),
		keys.JWTPrivateKey.Seed(),
		keys.WireGuardKey[:],
	}
	names := []string{"LibP2P", "IPFS", "Cluster", "JWT", "WireGuard"}
	for a := range seeds {
		for b := a + 1; b < len(seeds); b++ {
			if bytes.Equal(seeds[a], seeds[b]) {
				t.Errorf("%s and %s keys should differ", names[a], names[b])
			}
		}
	}
}
// TestExpandNodeKeys_Ed25519Validity confirms that every derived Ed25519
// key pair produces signatures that verify against its public key.
func TestExpandNodeKeys_Ed25519Validity(t *testing.T) {
	keys, err := ExpandNodeKeys(testMasterKey)
	if err != nil {
		t.Fatalf("ExpandNodeKeys: %v", err)
	}
	message := []byte("test message for verification")
	keyPairs := []struct {
		name string
		priv ed25519.PrivateKey
		pub  ed25519.PublicKey
	}{
		{"LibP2P", keys.LibP2PPrivateKey, keys.LibP2PPublicKey},
		{"IPFS", keys.IPFSPrivateKey, keys.IPFSPublicKey},
		{"Cluster", keys.ClusterPrivateKey, keys.ClusterPublicKey},
		{"JWT", keys.JWTPrivateKey, keys.JWTPublicKey},
	}
	for _, kp := range keyPairs {
		// Sign with the private half and verify with the public half.
		sig := ed25519.Sign(kp.priv, message)
		if !ed25519.Verify(kp.pub, message, sig) {
			t.Errorf("%s key pair: signature verification failed", kp.name)
		}
	}
}
// TestExpandNodeKeys_WireGuardClamping verifies the derived Curve25519
// private key follows the standard clamping rules (RFC 7748) and that the
// matching public key is non-zero.
func TestExpandNodeKeys_WireGuardClamping(t *testing.T) {
	keys, err := ExpandNodeKeys(testMasterKey)
	if err != nil {
		t.Fatalf("ExpandNodeKeys: %v", err)
	}
	first, last := keys.WireGuardKey[0], keys.WireGuardKey[31]
	// Clamping: clear the low 3 bits of byte 0.
	if first&7 != 0 {
		t.Errorf("WireGuard key not properly clamped: low 3 bits of first byte should be 0, got %08b", first)
	}
	// Clamping: clear the top bit of byte 31.
	if last&128 != 0 {
		t.Errorf("WireGuard key not properly clamped: high bit of last byte should be 0, got %08b", last)
	}
	// Clamping: set the second-highest bit of byte 31.
	if last&64 != 64 {
		t.Errorf("WireGuard key not properly clamped: second-high bit of last byte should be 1, got %08b", last)
	}
	if keys.WireGuardPubKey == [32]byte{} {
		t.Error("WireGuard public key is all zeros")
	}
}
// TestExpandNodeKeys_InvalidMasterKeyLength ensures master keys of any
// invalid length (nil, empty, too short, too long) are rejected.
func TestExpandNodeKeys_InvalidMasterKeyLength(t *testing.T) {
	for _, tc := range []struct {
		desc   string
		master []byte
	}{
		{"nil master key", nil},
		{"empty master key", []byte{}},
		{"16-byte master key", make([]byte, 16)},
		{"64-byte master key", make([]byte, 64)},
	} {
		if _, err := ExpandNodeKeys(tc.master); err == nil {
			t.Errorf("expected error for %s", tc.desc)
		}
	}
}
// TestHexToBytes exercises hexToBytes over well-formed hex strings
// (including mixed case and the empty string) and malformed inputs
// (odd length, non-hex characters).
func TestHexToBytes(t *testing.T) {
	// Valid inputs must decode to the exact expected bytes.
	valid := []struct {
		in   string
		want []byte
	}{
		{"", []byte{}},
		{"00", []byte{0}},
		{"ff", []byte{255}},
		{"FF", []byte{255}},
		{"0a1b2c", []byte{10, 27, 44}},
	}
	for _, tc := range valid {
		got, err := hexToBytes(tc.in)
		if err != nil {
			t.Errorf("hexToBytes(%q): unexpected error: %v", tc.in, err)
			continue
		}
		if !bytes.Equal(got, tc.want) {
			t.Errorf("hexToBytes(%q) = %v, want %v", tc.in, got, tc.want)
		}
	}
	// Odd-length and non-hex inputs must produce an error.
	for _, bad := range []string{"0", "zz", "gg"} {
		if _, err := hexToBytes(bad); err == nil {
			t.Errorf("hexToBytes(%q): expected error", bad)
		}
	}
}
// TestDeriveNodeKeysFromWallet_EmptyIP verifies an empty VPS IP is rejected.
func TestDeriveNodeKeysFromWallet_EmptyIP(t *testing.T) {
	if _, err := DeriveNodeKeysFromWallet(""); err == nil {
		t.Error("expected error for empty VPS IP")
	}
}

View File

@ -2,6 +2,7 @@ package auth
import (
"crypto"
"crypto/ed25519"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
@ -15,13 +16,13 @@ import (
func (s *Service) JWKSHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if s.signingKey == nil {
_ = json.NewEncoder(w).Encode(map[string]any{"keys": []any{}})
return
}
keys := make([]any, 0, 2)
// RSA key (RS256)
if s.signingKey != nil {
pub := s.signingKey.Public().(*rsa.PublicKey)
n := pub.N.Bytes()
// Encode exponent as big-endian bytes
eVal := pub.E
eb := make([]byte, 0)
for eVal > 0 {
@ -31,15 +32,30 @@ func (s *Service) JWKSHandler(w http.ResponseWriter, r *http.Request) {
if len(eb) == 0 {
eb = []byte{0}
}
jwk := map[string]string{
keys = append(keys, map[string]string{
"kty": "RSA",
"use": "sig",
"alg": "RS256",
"kid": s.keyID,
"n": base64.RawURLEncoding.EncodeToString(n),
"e": base64.RawURLEncoding.EncodeToString(eb),
})
}
_ = json.NewEncoder(w).Encode(map[string]any{"keys": []any{jwk}})
// Ed25519 key (EdDSA)
if s.edSigningKey != nil {
pubKey := s.edSigningKey.Public().(ed25519.PublicKey)
keys = append(keys, map[string]string{
"kty": "OKP",
"use": "sig",
"alg": "EdDSA",
"kid": s.edKeyID,
"crv": "Ed25519",
"x": base64.RawURLEncoding.EncodeToString(pubKey),
})
}
_ = json.NewEncoder(w).Encode(map[string]any{"keys": keys})
}
// Internal types for JWT handling
@ -59,11 +75,12 @@ type JWTClaims struct {
Namespace string `json:"namespace"`
}
// ParseAndVerifyJWT verifies an RS256 JWT created by this gateway and returns claims
// ParseAndVerifyJWT verifies a JWT created by this gateway using kid-based key
// selection. It accepts both RS256 (legacy) and EdDSA (new) tokens.
//
// Security (C3 fix): The key is selected by kid, then cross-checked against alg
// to prevent algorithm confusion attacks. Only RS256 and EdDSA are accepted.
func (s *Service) ParseAndVerifyJWT(token string) (*JWTClaims, error) {
if s.signingKey == nil {
return nil, errors.New("signing key unavailable")
}
parts := strings.Split(token, ".")
if len(parts) != 3 {
return nil, errors.New("invalid token format")
@ -80,20 +97,60 @@ func (s *Service) ParseAndVerifyJWT(token string) (*JWTClaims, error) {
if err != nil {
return nil, errors.New("invalid signature encoding")
}
var header jwtHeader
if err := json.Unmarshal(hb, &header); err != nil {
return nil, errors.New("invalid header json")
}
if header.Alg != "RS256" {
return nil, errors.New("unsupported alg")
// Explicit algorithm allowlist — reject everything else before verification
if header.Alg != "RS256" && header.Alg != "EdDSA" {
return nil, errors.New("unsupported algorithm")
}
// Verify signature
signingInput := parts[0] + "." + parts[1]
// Key selection by kid (not alg) — prevents algorithm confusion (C3 fix)
switch {
case header.Kid != "" && header.Kid == s.edKeyID && s.edSigningKey != nil:
// EdDSA key matched by kid — cross-check alg
if header.Alg != "EdDSA" {
return nil, errors.New("algorithm mismatch for key")
}
pubKey := s.edSigningKey.Public().(ed25519.PublicKey)
if !ed25519.Verify(pubKey, []byte(signingInput), sb) {
return nil, errors.New("invalid signature")
}
case header.Kid != "" && header.Kid == s.keyID && s.signingKey != nil:
// RSA key matched by kid — cross-check alg
if header.Alg != "RS256" {
return nil, errors.New("algorithm mismatch for key")
}
sum := sha256.Sum256([]byte(signingInput))
pub := s.signingKey.Public().(*rsa.PublicKey)
if err := rsa.VerifyPKCS1v15(pub, crypto.SHA256, sum[:], sb); err != nil {
return nil, errors.New("invalid signature")
}
case header.Kid == "":
// Legacy token without kid — RS256 only (backward compat)
if header.Alg != "RS256" {
return nil, errors.New("legacy token must be RS256")
}
if s.signingKey == nil {
return nil, errors.New("signing key unavailable")
}
sum := sha256.Sum256([]byte(signingInput))
pub := s.signingKey.Public().(*rsa.PublicKey)
if err := rsa.VerifyPKCS1v15(pub, crypto.SHA256, sum[:], sb); err != nil {
return nil, errors.New("invalid signature")
}
default:
return nil, errors.New("unknown key ID")
}
// Parse claims
var claims JWTClaims
if err := json.Unmarshal(pb, &claims); err != nil {
@ -105,8 +162,7 @@ func (s *Service) ParseAndVerifyJWT(token string) (*JWTClaims, error) {
}
// Validate registered claims
now := time.Now().Unix()
// allow small clock skew ±60s
const skew = int64(60)
const skew = int64(60) // allow small clock skew ±60s
if claims.Nbf != 0 && now+skew < claims.Nbf {
return nil, errors.New("token not yet valid")
}
@ -123,6 +179,44 @@ func (s *Service) ParseAndVerifyJWT(token string) (*JWTClaims, error) {
}
func (s *Service) GenerateJWT(ns, subject string, ttl time.Duration) (string, int64, error) {
// Prefer EdDSA when available
if s.preferEdDSA && s.edSigningKey != nil {
return s.generateEdDSAJWT(ns, subject, ttl)
}
return s.generateRSAJWT(ns, subject, ttl)
}
// generateEdDSAJWT builds and signs a compact JWT using the Ed25519 key
// (alg=EdDSA). The header carries the Ed25519 key ID (kid) so verifiers can
// select the matching key. Returns the token, its expiry as a unix
// timestamp, and an error if the EdDSA key is not configured.
func (s *Service) generateEdDSAJWT(ns, subject string, ttl time.Duration) (string, int64, error) {
if s.edSigningKey == nil {
return "", 0, errors.New("EdDSA signing key unavailable")
}
// Header advertises EdDSA and pins this service's Ed25519 kid.
header := map[string]string{
"alg": "EdDSA",
"typ": "JWT",
"kid": s.edKeyID,
}
hb, _ := json.Marshal(header)
now := time.Now().UTC()
exp := now.Add(ttl)
// Registered + custom claims; iat/nbf are both "now", exp is now+ttl.
payload := map[string]any{
"iss": "debros-gateway",
"sub": subject,
"aud": "gateway",
"iat": now.Unix(),
"nbf": now.Unix(),
"exp": exp.Unix(),
"namespace": ns,
}
pb, _ := json.Marshal(payload)
// JWT compact serialization: base64url(header).base64url(payload).base64url(sig)
hb64 := base64.RawURLEncoding.EncodeToString(hb)
pb64 := base64.RawURLEncoding.EncodeToString(pb)
signingInput := hb64 + "." + pb64
sig := ed25519.Sign(s.edSigningKey, []byte(signingInput))
sb64 := base64.RawURLEncoding.EncodeToString(sig)
return signingInput + "." + sb64, exp.Unix(), nil
}
func (s *Service) generateRSAJWT(ns, subject string, ttl time.Duration) (string, int64, error) {
if s.signingKey == nil {
return "", 0, errors.New("signing key unavailable")
}

View File

@ -28,6 +28,9 @@ type Service struct {
orm client.NetworkClient
signingKey *rsa.PrivateKey
keyID string
edSigningKey ed25519.PrivateKey
edKeyID string
preferEdDSA bool
defaultNS string
}
@ -58,6 +61,16 @@ func NewService(logger *logging.ColoredLogger, orm client.NetworkClient, signing
return s, nil
}
// SetEdDSAKey configures an Ed25519 signing key for EdDSA JWT support.
// When set, new tokens are signed with EdDSA; RS256 is still accepted for verification.
func (s *Service) SetEdDSAKey(privKey ed25519.PrivateKey) {
s.edSigningKey = privKey
pubBytes := []byte(privKey.Public().(ed25519.PublicKey))
sum := sha256.Sum256(pubBytes)
s.edKeyID = "ed_" + hex.EncodeToString(sum[:8])
s.preferEdDSA = true
}
// CreateNonce generates a new nonce and stores it in the database
func (s *Service) CreateNonce(ctx context.Context, wallet, purpose, namespace string) (string, error) {
// Generate a URL-safe random nonce (32 bytes)

View File

@ -2,11 +2,18 @@ package auth
import (
"context"
"crypto/ed25519"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/hex"
"encoding/json"
"encoding/pem"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
@ -164,3 +171,248 @@ func TestVerifySolSignature(t *testing.T) {
t.Error("VerifySignature should have failed for invalid base64 signature")
}
}
// createDualKeyService creates a service with both RSA and EdDSA keys configured.
// It starts from createTestService (RSA-only) and adds a freshly generated
// Ed25519 key via SetEdDSAKey, which also flips the service to prefer EdDSA.
func createDualKeyService(t *testing.T) *Service {
t.Helper()
s := createTestService(t) // has RSA
_, edPriv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatalf("failed to generate ed25519 key: %v", err)
}
s.SetEdDSAKey(edPriv)
return s
}
// TestEdDSAJWTFlow exercises the full EdDSA token lifecycle on a dual-key
// service: generation (which must prefer EdDSA), header inspection
// (alg/kid), and round-trip verification of the claims.
func TestEdDSAJWTFlow(t *testing.T) {
s := createDualKeyService(t)
ns := "test-ns"
sub := "0xabcdef1234567890abcdef1234567890abcdef12"
ttl := 15 * time.Minute
// With EdDSA preferred, GenerateJWT should produce an EdDSA token
token, exp, err := s.GenerateJWT(ns, sub, ttl)
if err != nil {
t.Fatalf("GenerateJWT (EdDSA) failed: %v", err)
}
if token == "" {
t.Fatal("generated EdDSA token is empty")
}
if exp <= time.Now().Unix() {
t.Errorf("expiration time %d is in the past", exp)
}
// Verify the header contains EdDSA
parts := strings.Split(token, ".")
hb, _ := base64.RawURLEncoding.DecodeString(parts[0])
var header map[string]string
json.Unmarshal(hb, &header)
if header["alg"] != "EdDSA" {
t.Errorf("expected alg EdDSA, got %s", header["alg"])
}
if header["kid"] != s.edKeyID {
t.Errorf("expected kid %s, got %s", s.edKeyID, header["kid"])
}
// Verify the token round-trips through ParseAndVerifyJWT with the
// original subject and namespace intact.
claims, err := s.ParseAndVerifyJWT(token)
if err != nil {
t.Fatalf("ParseAndVerifyJWT (EdDSA) failed: %v", err)
}
if claims.Sub != sub {
t.Errorf("expected subject %s, got %s", sub, claims.Sub)
}
if claims.Namespace != ns {
t.Errorf("expected namespace %s, got %s", ns, claims.Namespace)
}
}
// TestRS256BackwardCompat ensures a dual-key service still verifies RS256
// tokens: it temporarily disables the EdDSA preference to mint a legacy
// RS256 token, re-enables it, then verifies the RS256 token.
func TestRS256BackwardCompat(t *testing.T) {
s := createDualKeyService(t)
// Generate an RS256 token directly (simulating a legacy token)
s.preferEdDSA = false
token, _, err := s.GenerateJWT("test-ns", "user1", 15*time.Minute)
if err != nil {
t.Fatalf("GenerateJWT (RS256) failed: %v", err)
}
s.preferEdDSA = true // re-enable EdDSA preference
// Verify the RS256 token still works with dual-key service
claims, err := s.ParseAndVerifyJWT(token)
if err != nil {
t.Fatalf("ParseAndVerifyJWT should accept RS256 token: %v", err)
}
if claims.Sub != "user1" {
t.Errorf("expected subject user1, got %s", claims.Sub)
}
}
// TestAlgorithmConfusion_Rejected covers the algorithm-confusion attack
// surface of ParseAndVerifyJWT: alg=none, HMAC downgrade (HS256), a kid/alg
// mismatch, an unknown kid, and a kid-less legacy token claiming EdDSA.
// All five crafted tokens must be rejected.
func TestAlgorithmConfusion_Rejected(t *testing.T) {
s := createDualKeyService(t)
t.Run("none_algorithm", func(t *testing.T) {
// Craft a token with alg=none
header := map[string]string{"alg": "none", "typ": "JWT"}
hb, _ := json.Marshal(header)
payload := map[string]any{
"iss": "debros-gateway", "sub": "attacker", "aud": "gateway",
"iat": time.Now().Unix(), "nbf": time.Now().Unix(),
"exp": time.Now().Add(time.Hour).Unix(), "namespace": "test-ns",
}
pb, _ := json.Marshal(payload)
// Empty signature segment — an alg=none token has no signature.
token := base64.RawURLEncoding.EncodeToString(hb) + "." +
base64.RawURLEncoding.EncodeToString(pb) + "."
_, err := s.ParseAndVerifyJWT(token)
if err == nil {
t.Error("should reject alg=none")
}
})
t.Run("HS256_algorithm", func(t *testing.T) {
// HMAC downgrade attempt: reuse the RSA kid but claim HS256.
header := map[string]string{"alg": "HS256", "typ": "JWT", "kid": s.keyID}
hb, _ := json.Marshal(header)
payload := map[string]any{
"iss": "debros-gateway", "sub": "attacker", "aud": "gateway",
"iat": time.Now().Unix(), "nbf": time.Now().Unix(),
"exp": time.Now().Add(time.Hour).Unix(), "namespace": "test-ns",
}
pb, _ := json.Marshal(payload)
token := base64.RawURLEncoding.EncodeToString(hb) + "." +
base64.RawURLEncoding.EncodeToString(pb) + "." +
base64.RawURLEncoding.EncodeToString([]byte("fake-sig"))
_, err := s.ParseAndVerifyJWT(token)
if err == nil {
t.Error("should reject alg=HS256")
}
})
t.Run("kid_alg_mismatch_EdDSA_kid_RS256_alg", func(t *testing.T) {
// Use EdDSA kid but claim RS256 alg
header := map[string]string{"alg": "RS256", "typ": "JWT", "kid": s.edKeyID}
hb, _ := json.Marshal(header)
payload := map[string]any{
"iss": "debros-gateway", "sub": "attacker", "aud": "gateway",
"iat": time.Now().Unix(), "nbf": time.Now().Unix(),
"exp": time.Now().Add(time.Hour).Unix(), "namespace": "test-ns",
}
pb, _ := json.Marshal(payload)
// Sign with RSA (trying to confuse the verifier into using RSA on EdDSA kid)
hb64 := base64.RawURLEncoding.EncodeToString(hb)
pb64 := base64.RawURLEncoding.EncodeToString(pb)
signingInput := hb64 + "." + pb64
sum := sha256.Sum256([]byte(signingInput))
rsaSig, _ := rsa.SignPKCS1v15(rand.Reader, s.signingKey, 4, sum[:]) // crypto.SHA256 = 4
token := signingInput + "." + base64.RawURLEncoding.EncodeToString(rsaSig)
_, err := s.ParseAndVerifyJWT(token)
if err == nil {
t.Error("should reject kid/alg mismatch (EdDSA kid with RS256 alg)")
}
// The verifier must fail with the specific mismatch error, not a
// generic signature failure.
if err != nil && !strings.Contains(err.Error(), "algorithm mismatch") {
t.Errorf("expected 'algorithm mismatch' error, got: %v", err)
}
})
t.Run("unknown_kid", func(t *testing.T) {
// A kid matching neither configured key must be rejected outright.
header := map[string]string{"alg": "RS256", "typ": "JWT", "kid": "unknown-kid-123"}
hb, _ := json.Marshal(header)
payload := map[string]any{
"iss": "debros-gateway", "sub": "attacker", "aud": "gateway",
"iat": time.Now().Unix(), "nbf": time.Now().Unix(),
"exp": time.Now().Add(time.Hour).Unix(), "namespace": "test-ns",
}
pb, _ := json.Marshal(payload)
token := base64.RawURLEncoding.EncodeToString(hb) + "." +
base64.RawURLEncoding.EncodeToString(pb) + "." +
base64.RawURLEncoding.EncodeToString([]byte("fake-sig"))
_, err := s.ParseAndVerifyJWT(token)
if err == nil {
t.Error("should reject unknown kid")
}
})
t.Run("legacy_token_EdDSA_rejected", func(t *testing.T) {
// Token with no kid and alg=EdDSA — should be rejected (legacy must be RS256)
header := map[string]string{"alg": "EdDSA", "typ": "JWT"}
hb, _ := json.Marshal(header)
payload := map[string]any{
"iss": "debros-gateway", "sub": "attacker", "aud": "gateway",
"iat": time.Now().Unix(), "nbf": time.Now().Unix(),
"exp": time.Now().Add(time.Hour).Unix(), "namespace": "test-ns",
}
pb, _ := json.Marshal(payload)
hb64 := base64.RawURLEncoding.EncodeToString(hb)
pb64 := base64.RawURLEncoding.EncodeToString(pb)
signingInput := hb64 + "." + pb64
// Even a VALID EdDSA signature must not rescue a kid-less token.
sig := ed25519.Sign(s.edSigningKey, []byte(signingInput))
token := signingInput + "." + base64.RawURLEncoding.EncodeToString(sig)
_, err := s.ParseAndVerifyJWT(token)
if err == nil {
t.Error("should reject legacy token (no kid) with EdDSA alg")
}
})
}
// TestJWKSHandler_DualKey verifies that JWKS exposes both keys on a
// dual-key service: exactly two entries, one RS256 and one EdDSA, each
// carrying a non-empty kid.
func TestJWKSHandler_DualKey(t *testing.T) {
s := createDualKeyService(t)
req := httptest.NewRequest(http.MethodGet, "/.well-known/jwks.json", nil)
w := httptest.NewRecorder()
s.JWKSHandler(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", w.Code)
}
var result struct {
Keys []map[string]string `json:"keys"`
}
if err := json.NewDecoder(w.Body).Decode(&result); err != nil {
t.Fatalf("failed to decode JWKS response: %v", err)
}
if len(result.Keys) != 2 {
t.Fatalf("expected 2 keys in JWKS, got %d", len(result.Keys))
}
// Verify we have both RSA and OKP keys
algSet := map[string]bool{}
for _, k := range result.Keys {
algSet[k["alg"]] = true
if k["kid"] == "" {
t.Error("key missing kid")
}
}
if !algSet["RS256"] {
t.Error("JWKS missing RS256 key")
}
if !algSet["EdDSA"] {
t.Error("JWKS missing EdDSA key")
}
}
// TestJWKSHandler_RSAOnly verifies that a service configured without an
// EdDSA key publishes exactly one JWKS entry, and that it is RS256.
func TestJWKSHandler_RSAOnly(t *testing.T) {
s := createTestService(t) // RSA only, no EdDSA
req := httptest.NewRequest(http.MethodGet, "/.well-known/jwks.json", nil)
w := httptest.NewRecorder()
s.JWKSHandler(w, req)
var result struct {
Keys []map[string]string `json:"keys"`
}
// Decode error intentionally ignored; the length check below fails
// anyway if the body was not valid JSON.
json.NewDecoder(w.Body).Decode(&result)
if len(result.Keys) != 1 {
t.Fatalf("expected 1 key in JWKS (RSA only), got %d", len(result.Keys))
}
if result.Keys[0]["alg"] != "RS256" {
t.Errorf("expected RS256, got %s", result.Keys[0]["alg"])
}
}

View File

@ -429,7 +429,7 @@ func initializeServerless(logger *logging.ColoredLogger, cfg *Config, deps *Depe
logger.Logger,
)
// Initialize auth service with persistent signing key
// Initialize auth service with persistent signing keys (RSA + EdDSA)
keyPEM, err := loadOrCreateSigningKey(cfg.DataDir, logger)
if err != nil {
return fmt.Errorf("failed to load or create JWT signing key: %w", err)
@ -438,6 +438,17 @@ func initializeServerless(logger *logging.ColoredLogger, cfg *Config, deps *Depe
if err != nil {
return fmt.Errorf("failed to initialize auth service: %w", err)
}
// Load or create EdDSA key for new JWT tokens
edKey, err := loadOrCreateEdSigningKey(cfg.DataDir, logger)
if err != nil {
logger.ComponentWarn(logging.ComponentGeneral, "Failed to load EdDSA signing key; new JWTs will use RS256",
zap.Error(err))
} else {
authService.SetEdDSAKey(edKey)
logger.ComponentInfo(logging.ComponentGeneral, "EdDSA signing key loaded; new JWTs will use EdDSA")
}
deps.AuthService = authService
logger.ComponentInfo(logging.ComponentGeneral, "Serverless function engine ready",

View File

@ -42,7 +42,7 @@ func (g *Gateway) startNamespaceHealthLoop(ctx context.Context) {
}
probeTicker := time.NewTicker(30 * time.Second)
reconcileTicker := time.NewTicker(1 * time.Hour)
reconcileTicker := time.NewTicker(5 * time.Minute)
defer probeTicker.Stop()
defer reconcileTicker.Stop()

View File

@ -1,6 +1,7 @@
package gateway
import (
"crypto/ed25519"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
@ -14,6 +15,7 @@ import (
)
const jwtKeyFileName = "jwt-signing-key.pem"
const eddsaKeyFileName = "jwt-eddsa-key.pem"
// loadOrCreateSigningKey loads the JWT signing key from disk, or generates a new one
// if none exists. This ensures JWTs survive gateway restarts.
@ -61,3 +63,56 @@ func loadOrCreateSigningKey(dataDir string, logger *logging.ColoredLogger) ([]by
zap.String("path", keyPath))
return keyPEM, nil
}
// loadOrCreateEdSigningKey loads or generates an Ed25519 private key for EdDSA JWT signing.
// The key is stored as a PKCS8-encoded PEM file alongside the RSA key, under
// <dataDir>/secrets. An unreadable or non-Ed25519 key file is logged and
// silently replaced by a freshly generated key, so the caller always gets a
// usable key on success.
func loadOrCreateEdSigningKey(dataDir string, logger *logging.ColoredLogger) (ed25519.PrivateKey, error) {
keyPath := filepath.Join(dataDir, "secrets", eddsaKeyFileName)
// Try to load existing key
if keyPEM, err := os.ReadFile(keyPath); err == nil && len(keyPEM) > 0 {
block, _ := pem.Decode(keyPEM)
if block != nil {
parsed, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err == nil {
// Only accept the file if PKCS8 contents are actually Ed25519.
if edKey, ok := parsed.(ed25519.PrivateKey); ok {
logger.ComponentInfo(logging.ComponentGeneral, "Loaded existing EdDSA signing key",
zap.String("path", keyPath))
return edKey, nil
}
}
}
// File exists but failed PEM/PKCS8/type checks — regenerate below.
logger.ComponentWarn(logging.ComponentGeneral, "Existing EdDSA signing key is invalid, generating new one",
zap.String("path", keyPath))
}
// Generate new Ed25519 key
_, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return nil, fmt.Errorf("generate Ed25519 key: %w", err)
}
pkcs8Bytes, err := x509.MarshalPKCS8PrivateKey(priv)
if err != nil {
return nil, fmt.Errorf("marshal Ed25519 key: %w", err)
}
keyPEM := pem.EncodeToMemory(&pem.Block{
Type: "PRIVATE KEY",
Bytes: pkcs8Bytes,
})
// Ensure secrets directory exists (0700: owner-only — holds private keys)
secretsDir := filepath.Dir(keyPath)
if err := os.MkdirAll(secretsDir, 0700); err != nil {
return nil, fmt.Errorf("create secrets directory: %w", err)
}
// 0600: key file readable/writable by owner only.
if err := os.WriteFile(keyPath, keyPEM, 0600); err != nil {
return nil, fmt.Errorf("write EdDSA signing key: %w", err)
}
logger.ComponentInfo(logging.ComponentGeneral, "Generated and saved new EdDSA signing key",
zap.String("path", keyPath))
return priv, nil
}

View File

@ -130,11 +130,14 @@ func (cm *ClusterManager) HandleRecoveredNode(ctx context.Context, nodeID string
}
if len(results) > 0 && results[0].Count > 0 {
// Node still has legitimate assignments — just mark active and return
// Node still has legitimate assignments — mark active and repair degraded clusters
cm.logger.Info("Recovered node still has cluster assignments, marking active",
zap.String("node_id", nodeID),
zap.Int("assignments", results[0].Count))
cm.markNodeActive(ctx, nodeID)
// Trigger repair for any degraded clusters this node belongs to
cm.repairDegradedClusters(ctx, nodeID)
return
}
@ -522,6 +525,37 @@ func (cm *ClusterManager) markNodeActive(ctx context.Context, nodeID string) {
}
}
// repairDegradedClusters finds degraded clusters that the recovered node
// belongs to and triggers RepairCluster for each one. Repair failures are
// logged and skipped so one bad cluster does not block the rest.
// NOTE(review): RepairCluster's exact semantics are defined elsewhere —
// assumed here to re-establish replication for the named namespace.
func (cm *ClusterManager) repairDegradedClusters(ctx context.Context, nodeID string) {
// Local row type for the single-column query result.
type clusterRef struct {
NamespaceName string `db:"namespace_name"`
}
var refs []clusterRef
// Only clusters this node is a member of AND currently marked degraded.
query := `
SELECT DISTINCT c.namespace_name
FROM namespace_cluster_nodes cn
JOIN namespace_clusters c ON cn.namespace_cluster_id = c.id
WHERE cn.node_id = ? AND c.status = 'degraded'
`
if err := cm.db.Query(ctx, &refs, query, nodeID); err != nil {
cm.logger.Warn("Failed to query degraded clusters for recovered node",
zap.String("node_id", nodeID), zap.Error(err))
return
}
for _, ref := range refs {
cm.logger.Info("Triggering repair for degraded cluster after node recovery",
zap.String("namespace", ref.NamespaceName),
zap.String("recovered_node", nodeID))
if err := cm.RepairCluster(ctx, ref.NamespaceName); err != nil {
cm.logger.Warn("Failed to repair degraded cluster",
zap.String("namespace", ref.NamespaceName),
zap.Error(err))
}
}
}
// removeDeadNodeFromRaft sends a DELETE request to a surviving RQLite node
// to remove the dead node from the Raft voter set.
func (cm *ClusterManager) removeDeadNodeFromRaft(ctx context.Context, deadRaftAddr string, survivingNodes []survivingNodePorts) {

View File

@ -25,8 +25,19 @@ const (
DefaultDeadAfter = 12 // consecutive misses → dead
DefaultQuorumWindow = 5 * time.Minute
DefaultMinQuorum = 2 // out of K observers must agree
// DefaultStartupGracePeriod prevents false dead declarations after
// cluster-wide restart. During this period, no nodes are declared dead.
DefaultStartupGracePeriod = 5 * time.Minute
)
// MetadataReader provides lifecycle metadata for peers. Implemented by
// ClusterDiscoveryService. The health monitor uses this to check maintenance
// status and LastSeen before falling through to HTTP probes.
type MetadataReader interface {
GetPeerLifecycleState(nodeID string) (state string, lastSeen time.Time, found bool)
}
// Config holds the configuration for a Monitor.
type Config struct {
NodeID string // this node's ID (dns_nodes.id / peer ID)
@ -35,6 +46,16 @@ type Config struct {
ProbeInterval time.Duration // how often to probe (default 10s)
ProbeTimeout time.Duration // per-probe HTTP timeout (default 3s)
Neighbors int // K — how many ring neighbors to monitor (default 3)
// MetadataReader provides LibP2P lifecycle metadata for peers.
// When set, the monitor checks peer maintenance state and LastSeen
// before falling through to HTTP probes.
MetadataReader MetadataReader
// StartupGracePeriod prevents false dead declarations after cluster-wide
// restart. During this period, nodes can be marked suspect but never dead.
// Default: 5 minutes.
StartupGracePeriod time.Duration
}
// nodeInfo is a row from dns_nodes used for probing.
@ -56,6 +77,7 @@ type Monitor struct {
cfg Config
httpClient *http.Client
logger *zap.Logger
startTime time.Time // when the monitor was created
mu sync.Mutex
peers map[string]*peerState // nodeID → state
@ -75,6 +97,9 @@ func NewMonitor(cfg Config) *Monitor {
if cfg.Neighbors == 0 {
cfg.Neighbors = DefaultNeighbors
}
if cfg.StartupGracePeriod == 0 {
cfg.StartupGracePeriod = DefaultStartupGracePeriod
}
if cfg.Logger == nil {
cfg.Logger = zap.NewNop()
}
@ -85,18 +110,19 @@ func NewMonitor(cfg Config) *Monitor {
Timeout: cfg.ProbeTimeout,
},
logger: cfg.Logger.With(zap.String("component", "health-monitor")),
startTime: time.Now(),
peers: make(map[string]*peerState),
}
}
// OnNodeDead registers a callback invoked when a node is confirmed dead by
// quorum. The callback runs synchronously in the monitor goroutine.
// quorum. The callback runs with the monitor lock released.
func (m *Monitor) OnNodeDead(fn func(nodeID string)) {
m.onDeadFn = fn
}
// OnNodeRecovered registers a callback invoked when a previously dead node
// transitions back to healthy. Used to clean up orphaned services.
// transitions back to healthy. The callback runs with the monitor lock released.
func (m *Monitor) OnNodeRecovered(fn func(nodeID string)) {
m.onRecoveredFn = fn
}
@ -107,6 +133,7 @@ func (m *Monitor) Start(ctx context.Context) {
zap.String("node_id", m.cfg.NodeID),
zap.Duration("probe_interval", m.cfg.ProbeInterval),
zap.Int("neighbors", m.cfg.Neighbors),
zap.Duration("startup_grace", m.cfg.StartupGracePeriod),
)
ticker := time.NewTicker(m.cfg.ProbeInterval)
@ -123,6 +150,11 @@ func (m *Monitor) Start(ctx context.Context) {
}
}
// isInStartupGrace returns true if the startup grace period is still active.
// Measured from the monitor's creation time (startTime); while active, peers
// are never escalated past "suspect" to "dead".
func (m *Monitor) isInStartupGrace() bool {
return time.Since(m.startTime) < m.cfg.StartupGracePeriod
}
// probeRound runs a single round of probing our ring neighbors.
func (m *Monitor) probeRound(ctx context.Context) {
neighbors, err := m.getRingNeighbors(ctx)
@ -140,7 +172,7 @@ func (m *Monitor) probeRound(ctx context.Context) {
wg.Add(1)
go func(node nodeInfo) {
defer wg.Done()
ok := m.probe(ctx, node)
ok := m.probeNode(ctx, node)
m.updateState(ctx, node.ID, ok)
}(n)
}
@ -150,6 +182,28 @@ func (m *Monitor) probeRound(ctx context.Context) {
m.pruneStaleState(neighbors)
}
// probeNode checks a node's health. It first checks LibP2P metadata (if
// available) to avoid unnecessary HTTP probes, then falls through to HTTP.
// Returns true when the node should be treated as healthy this round.
// The 2m (maintenance) and 30s (active) freshness windows are fixed here;
// a peer whose metadata is stale still gets a real HTTP probe.
func (m *Monitor) probeNode(ctx context.Context, node nodeInfo) bool {
if m.cfg.MetadataReader != nil {
state, lastSeen, found := m.cfg.MetadataReader.GetPeerLifecycleState(node.ID)
if found {
// Maintenance node with recent LastSeen → count as healthy
if state == "maintenance" && time.Since(lastSeen) < 2*time.Minute {
return true
}
// Recently seen active node → count as healthy (no HTTP needed)
if state == "active" && time.Since(lastSeen) < 30*time.Second {
return true
}
}
}
// Fall through to HTTP probe
return m.probe(ctx, node)
}
// probe sends an HTTP ping to a single node. Returns true if healthy.
func (m *Monitor) probe(ctx context.Context, node nodeInfo) bool {
url := fmt.Sprintf("http://%s:6001/v1/internal/ping", node.InternalIP)
@ -167,9 +221,9 @@ func (m *Monitor) probe(ctx context.Context, node nodeInfo) bool {
}
// updateState updates the in-memory state for a peer after a probe.
// Callbacks are invoked with the lock released to prevent deadlocks (C2 fix).
func (m *Monitor) updateState(ctx context.Context, nodeID string, healthy bool) {
m.mu.Lock()
defer m.mu.Unlock()
ps, exists := m.peers[nodeID]
if !exists {
@ -178,23 +232,26 @@ func (m *Monitor) updateState(ctx context.Context, nodeID string, healthy bool)
}
if healthy {
// Recovered
if ps.status != "healthy" {
wasDead := ps.status == "dead"
m.logger.Info("Node recovered", zap.String("target", nodeID),
zap.String("previous_status", ps.status))
m.writeEvent(ctx, nodeID, "recovered")
shouldCallback := wasDead && m.onRecoveredFn != nil
prevStatus := ps.status
// Fire recovery callback for nodes that were confirmed dead
if wasDead && m.onRecoveredFn != nil {
m.mu.Unlock()
m.onRecoveredFn(nodeID)
m.mu.Lock()
}
}
// Update state BEFORE releasing lock (C2 fix)
ps.missCount = 0
ps.status = "healthy"
ps.reportedDead = false
m.mu.Unlock()
if prevStatus != "healthy" {
m.logger.Info("Node recovered", zap.String("target", nodeID),
zap.String("previous_status", prevStatus))
m.writeEvent(ctx, nodeID, "recovered")
}
// Fire recovery callback without holding the lock (C2 fix)
if shouldCallback {
m.onRecoveredFn(nodeID)
}
return
}
@ -203,6 +260,23 @@ func (m *Monitor) updateState(ctx context.Context, nodeID string, healthy bool)
switch {
case ps.missCount >= DefaultDeadAfter && !ps.reportedDead:
// During startup grace period, don't declare dead — only suspect
if m.isInStartupGrace() {
if ps.status != "suspect" {
ps.status = "suspect"
ps.suspectAt = time.Now()
m.mu.Unlock()
m.logger.Warn("Node SUSPECT (startup grace — deferring dead)",
zap.String("target", nodeID),
zap.Int("misses", ps.missCount),
)
m.writeEvent(ctx, nodeID, "suspect")
return
}
m.mu.Unlock()
return
}
if ps.status != "dead" {
m.logger.Error("Node declared DEAD",
zap.String("target", nodeID),
@ -211,22 +285,34 @@ func (m *Monitor) updateState(ctx context.Context, nodeID string, healthy bool)
}
ps.status = "dead"
ps.reportedDead = true
// Copy what we need before releasing lock
shouldCheckQuorum := m.cfg.DB != nil && m.onDeadFn != nil
m.mu.Unlock()
m.writeEvent(ctx, nodeID, "dead")
if shouldCheckQuorum {
m.checkQuorum(ctx, nodeID)
}
return
case ps.missCount >= DefaultSuspectAfter && ps.status == "healthy":
ps.status = "suspect"
ps.suspectAt = time.Now()
m.mu.Unlock()
m.logger.Warn("Node SUSPECT",
zap.String("target", nodeID),
zap.Int("misses", ps.missCount),
)
m.writeEvent(ctx, nodeID, "suspect")
}
return
}
// writeEvent inserts a health event into node_health_events. Must be called
// with m.mu held.
m.mu.Unlock()
}
// writeEvent inserts a health event into node_health_events.
func (m *Monitor) writeEvent(ctx context.Context, targetID, status string) {
if m.cfg.DB == nil {
return
@ -242,7 +328,8 @@ func (m *Monitor) writeEvent(ctx context.Context, targetID, status string) {
}
// checkQuorum queries the events table to see if enough observers agree the
// target is dead, then fires the onDead callback. Must be called with m.mu held.
// target is dead, then fires the onDead callback. Called WITHOUT the lock held
// (C2 fix — previously called with lock held, causing deadlocks in callbacks).
func (m *Monitor) checkQuorum(ctx context.Context, targetID string) {
if m.cfg.DB == nil || m.onDeadFn == nil {
return
@ -287,10 +374,7 @@ func (m *Monitor) checkQuorum(ctx context.Context, targetID string) {
zap.String("target", targetID),
zap.Int("observers", count),
)
// Release the lock before calling the callback to avoid deadlocks.
m.mu.Unlock()
m.onDeadFn(targetID)
m.mu.Lock()
}
// getRingNeighbors queries dns_nodes for active nodes, sorts them, and

View File

@ -161,7 +161,9 @@ func TestStateTransitions(t *testing.T) {
NodeID: "self",
ProbeInterval: time.Second,
Neighbors: 3,
StartupGracePeriod: 1 * time.Millisecond, // disable grace for this test
})
time.Sleep(2 * time.Millisecond) // ensure grace period expired
ctx := context.Background()
@ -300,7 +302,9 @@ func TestOnNodeDead_Callback(t *testing.T) {
m := NewMonitor(Config{
NodeID: "self",
Neighbors: 3,
StartupGracePeriod: 1 * time.Millisecond,
})
time.Sleep(2 * time.Millisecond)
m.OnNodeDead(func(nodeID string) {
called.Add(1)
})
@ -316,3 +320,204 @@ func TestOnNodeDead_Callback(t *testing.T) {
t.Fatalf("expected dead, got %s", m.peers["victim"].status)
}
}
// ---------------------------------------------------------------
// Startup grace period
// ---------------------------------------------------------------
// TestStartupGrace_PreventsDead verifies that while the startup grace
// period is active, accumulated probe misses can escalate a peer to
// suspect but never all the way to dead.
func TestStartupGrace_PreventsDead(t *testing.T) {
	m := NewMonitor(Config{
		NodeID:             "self",
		Neighbors:          3,
		StartupGracePeriod: 1 * time.Hour, // grace outlives the whole test
	})
	ctx := context.Background()

	// Record well past the dead threshold.
	for i := 0; i < DefaultDeadAfter+5; i++ {
		m.updateState(ctx, "peer1", false)
	}

	m.mu.Lock()
	got := m.peers["peer1"].status
	m.mu.Unlock()

	// During grace the peer may only reach suspect, never dead.
	if got != "suspect" {
		t.Fatalf("expected suspect during startup grace, got %s", got)
	}
}
// TestStartupGrace_AllowsDeadAfterExpiry verifies that once the grace
// period has elapsed, the normal miss threshold marks a peer dead.
func TestStartupGrace_AllowsDeadAfterExpiry(t *testing.T) {
	m := NewMonitor(Config{
		NodeID:             "self",
		Neighbors:          3,
		StartupGracePeriod: 1 * time.Millisecond,
	})
	time.Sleep(2 * time.Millisecond) // let the grace window lapse

	ctx := context.Background()
	for i := 0; i < DefaultDeadAfter; i++ {
		m.updateState(ctx, "peer1", false)
	}

	m.mu.Lock()
	got := m.peers["peer1"].status
	m.mu.Unlock()

	if got != "dead" {
		t.Fatalf("expected dead after grace expired, got %s", got)
	}
}
// ---------------------------------------------------------------
// MetadataReader integration
// ---------------------------------------------------------------
// mockMetadataReader is a canned test double for the monitor's
// MetadataReader dependency. It reports the same tuple for every node ID.
type mockMetadataReader struct {
	state    string    // lifecycle state to report
	lastSeen time.Time // last-seen timestamp to report
	found    bool      // whether the peer is "known" at all
}

// GetPeerLifecycleState implements the MetadataReader interface by
// returning the fixture's fixed values; nodeID is ignored.
func (m *mockMetadataReader) GetPeerLifecycleState(nodeID string) (string, time.Time, bool) {
	return m.state, m.lastSeen, m.found
}
func TestProbeNode_MaintenanceCountsHealthy(t *testing.T) {
m := NewMonitor(Config{
NodeID: "self",
MetadataReader: &mockMetadataReader{
state: "maintenance",
lastSeen: time.Now(),
found: true,
},
})
node := nodeInfo{ID: "peer1", InternalIP: "192.0.2.1"} // unreachable
ok := m.probeNode(context.Background(), node)
if !ok {
t.Fatal("maintenance node with recent LastSeen should count as healthy")
}
}
func TestProbeNode_RecentActiveSkipsHTTP(t *testing.T) {
m := NewMonitor(Config{
NodeID: "self",
MetadataReader: &mockMetadataReader{
state: "active",
lastSeen: time.Now(),
found: true,
},
})
// Use unreachable IP — if HTTP were attempted, it would fail
node := nodeInfo{ID: "peer1", InternalIP: "192.0.2.1"}
ok := m.probeNode(context.Background(), node)
if !ok {
t.Fatal("recently seen active node should skip HTTP and count as healthy")
}
}
func TestProbeNode_StaleMetadataFallsToHTTP(t *testing.T) {
m := NewMonitor(Config{
NodeID: "self",
ProbeTimeout: 100 * time.Millisecond,
MetadataReader: &mockMetadataReader{
state: "active",
lastSeen: time.Now().Add(-5 * time.Minute), // stale
found: true,
},
})
node := nodeInfo{ID: "peer1", InternalIP: "192.0.2.1"} // unreachable
ok := m.probeNode(context.Background(), node)
if ok {
t.Fatal("stale metadata should fall through to HTTP probe, which should fail")
}
}
// TestProbeNode_UnknownPeerFallsToHTTP: a peer absent from metadata gets
// a real HTTP probe, which fails against the unreachable address.
func TestProbeNode_UnknownPeerFallsToHTTP(t *testing.T) {
	m := NewMonitor(Config{
		NodeID:         "self",
		ProbeTimeout:   100 * time.Millisecond,
		MetadataReader: &mockMetadataReader{found: false}, // nothing known
	})

	target := nodeInfo{ID: "unknown", InternalIP: "192.0.2.1"}
	if m.probeNode(context.Background(), target) {
		t.Fatal("unknown peer should fall through to HTTP probe, which should fail")
	}
}
// TestProbeNode_NoMetadataReader exercises the path where no
// MetadataReader is configured: probeNode must go straight to the HTTP
// probe without panicking.
//
// NOTE(review): the httptest server below is never actually reached —
// probe() hardcodes port 6001 (per the comment inside), so the call's
// boolean result is discarded and only absence-of-panic is verified.
// Consider making the probe port configurable so this test can assert a
// real success against the local server.
func TestProbeNode_NoMetadataReader(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer srv.Close()
	m := NewMonitor(Config{
		NodeID:         "self",
		ProbeTimeout:   2 * time.Second,
		MetadataReader: nil, // no metadata reader
	})
	// Without MetadataReader, should go straight to HTTP
	addr := strings.TrimPrefix(srv.URL, "http://")
	node := nodeInfo{ID: "peer1", InternalIP: addr}
	// Note: probe() hardcodes port 6001, so this won't hit our server.
	// But we verify it doesn't panic and falls through correctly.
	_ = m.probeNode(context.Background(), node)
}
// ---------------------------------------------------------------
// Recovery callback (C2 fix)
// ---------------------------------------------------------------
// TestRecoveryCallback_InvokedWithoutLock drives a peer to the dead state
// and then reports one successful probe, asserting that (a) the
// OnNodeRecovered callback fires synchronously from updateState and
// (b) the peer's status returns to healthy. Per the C2 fix, the callback
// must run with m.mu released — if it were held, any callback touching
// monitor state would deadlock.
func TestRecoveryCallback_InvokedWithoutLock(t *testing.T) {
	m := NewMonitor(Config{
		NodeID:             "self",
		Neighbors:          3,
		StartupGracePeriod: 1 * time.Millisecond,
	})
	time.Sleep(2 * time.Millisecond) // let the startup grace period lapse
	var recoveredNode string
	m.OnNodeRecovered(func(nodeID string) {
		recoveredNode = nodeID
		// If lock were held, this would deadlock since we try to access peers
		// Just verify callback fires correctly
	})
	ctx := context.Background()
	// Drive to dead state
	for i := 0; i < DefaultDeadAfter; i++ {
		m.updateState(ctx, "peer1", false)
	}
	m.mu.Lock()
	if m.peers["peer1"].status != "dead" {
		m.mu.Unlock()
		t.Fatal("expected dead state")
	}
	m.mu.Unlock()
	// Recover: one successful probe should flip dead → healthy and fire
	// the callback synchronously before updateState returns.
	m.updateState(ctx, "peer1", true)
	if recoveredNode != "peer1" {
		t.Fatalf("expected recovery callback for peer1, got %q", recoveredNode)
	}
	m.mu.Lock()
	if m.peers["peer1"].status != "healthy" {
		m.mu.Unlock()
		t.Fatal("expected healthy after recovery")
	}
	m.mu.Unlock()
}

View File

@ -0,0 +1,184 @@
package lifecycle
import (
"fmt"
"sync"
"time"
)
// State represents a node's lifecycle state.
type State string
const (
StateJoining State = "joining"
StateActive State = "active"
StateDraining State = "draining"
StateMaintenance State = "maintenance"
)
// MaxMaintenanceTTL is the maximum duration a node can remain in maintenance
// mode. The leader's health monitor enforces this limit — nodes that exceed
// it are treated as unreachable so they can't hide in maintenance forever.
const MaxMaintenanceTTL = 15 * time.Minute
// validTransitions defines the allowed state machine transitions.
// Each entry maps from-state → set of valid to-states.
var validTransitions = map[State]map[State]bool{
StateJoining: {StateActive: true},
StateActive: {StateDraining: true, StateMaintenance: true},
StateDraining: {StateMaintenance: true},
StateMaintenance: {StateActive: true},
}
// StateChangeCallback is called when the lifecycle state changes.
type StateChangeCallback func(old, new State)
// Manager manages a node's lifecycle state machine.
// It has no external dependencies (no LibP2P, no discovery imports)
// and is fully testable in isolation.
type Manager struct {
mu sync.RWMutex
state State
maintenanceTTL time.Time
enterTime time.Time // when the current state was entered
onStateChange []StateChangeCallback
}
// NewManager creates a new lifecycle manager in the joining state.
func NewManager() *Manager {
return &Manager{
state: StateJoining,
enterTime: time.Now(),
}
}
// State returns the current lifecycle state.
func (m *Manager) State() State {
m.mu.RLock()
defer m.mu.RUnlock()
return m.state
}
// MaintenanceTTL returns the maintenance mode expiration time.
// Returns zero value if not in maintenance.
func (m *Manager) MaintenanceTTL() time.Time {
m.mu.RLock()
defer m.mu.RUnlock()
return m.maintenanceTTL
}
// StateEnteredAt returns when the current state was entered.
func (m *Manager) StateEnteredAt() time.Time {
m.mu.RLock()
defer m.mu.RUnlock()
return m.enterTime
}
// OnStateChange registers a callback invoked on state transitions.
// Callbacks are called with the lock released to avoid deadlocks.
func (m *Manager) OnStateChange(cb StateChangeCallback) {
m.mu.Lock()
defer m.mu.Unlock()
m.onStateChange = append(m.onStateChange, cb)
}
// TransitionTo moves the node to a new lifecycle state.
// Returns an error if the transition is not valid.
func (m *Manager) TransitionTo(newState State) error {
m.mu.Lock()
old := m.state
allowed, exists := validTransitions[old]
if !exists || !allowed[newState] {
m.mu.Unlock()
return fmt.Errorf("invalid lifecycle transition: %s → %s", old, newState)
}
m.state = newState
m.enterTime = time.Now()
// Clear maintenance TTL when leaving maintenance
if newState != StateMaintenance {
m.maintenanceTTL = time.Time{}
}
// Copy callbacks before releasing lock
callbacks := make([]StateChangeCallback, len(m.onStateChange))
copy(callbacks, m.onStateChange)
m.mu.Unlock()
// Invoke callbacks without holding the lock
for _, cb := range callbacks {
cb(old, newState)
}
return nil
}
// EnterMaintenance transitions to maintenance with a TTL.
// The TTL is capped at MaxMaintenanceTTL.
func (m *Manager) EnterMaintenance(ttl time.Duration) error {
if ttl <= 0 {
ttl = MaxMaintenanceTTL
}
if ttl > MaxMaintenanceTTL {
ttl = MaxMaintenanceTTL
}
m.mu.Lock()
old := m.state
// Allow both active→maintenance and draining→maintenance
allowed, exists := validTransitions[old]
if !exists || !allowed[StateMaintenance] {
m.mu.Unlock()
return fmt.Errorf("invalid lifecycle transition: %s → %s", old, StateMaintenance)
}
m.state = StateMaintenance
m.maintenanceTTL = time.Now().Add(ttl)
m.enterTime = time.Now()
callbacks := make([]StateChangeCallback, len(m.onStateChange))
copy(callbacks, m.onStateChange)
m.mu.Unlock()
for _, cb := range callbacks {
cb(old, StateMaintenance)
}
return nil
}
// IsMaintenanceExpired returns true if the node is in maintenance and the TTL
// has expired. Used by the leader's health monitor to enforce the max TTL.
func (m *Manager) IsMaintenanceExpired() bool {
m.mu.RLock()
defer m.mu.RUnlock()
if m.state != StateMaintenance {
return false
}
return !m.maintenanceTTL.IsZero() && time.Now().After(m.maintenanceTTL)
}
// IsAvailable returns true if the node is in a state that can serve requests.
func (m *Manager) IsAvailable() bool {
m.mu.RLock()
defer m.mu.RUnlock()
return m.state == StateActive
}
// IsInMaintenance returns true if the node is in maintenance mode.
func (m *Manager) IsInMaintenance() bool {
m.mu.RLock()
defer m.mu.RUnlock()
return m.state == StateMaintenance
}
// Snapshot returns a point-in-time copy of the lifecycle state for
// embedding in metadata without holding the lock.
func (m *Manager) Snapshot() (state State, ttl time.Time) {
m.mu.RLock()
defer m.mu.RUnlock()
return m.state, m.maintenanceTTL
}

View File

@ -0,0 +1,320 @@
package lifecycle
import (
"sync"
"testing"
"time"
)
// TestNewManager checks the constructor's initial conditions: state is
// joining, and the node is neither available nor in maintenance.
func TestNewManager(t *testing.T) {
	m := NewManager()
	switch {
	case m.State() != StateJoining:
		t.Fatalf("expected initial state %q, got %q", StateJoining, m.State())
	case m.IsAvailable():
		t.Fatal("joining node should not be available")
	case m.IsInMaintenance():
		t.Fatal("joining node should not be in maintenance")
	}
}
// TestValidTransitions exercises every edge permitted by the state machine.
func TestValidTransitions(t *testing.T) {
	cases := []struct {
		name    string
		from    State
		to      State
		wantErr bool
	}{
		{"joining→active", StateJoining, StateActive, false},
		{"active→draining", StateActive, StateDraining, false},
		{"draining→maintenance", StateDraining, StateMaintenance, false},
		{"active→maintenance", StateActive, StateMaintenance, false},
		{"maintenance→active", StateMaintenance, StateActive, false},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			m := &Manager{state: tc.from, enterTime: time.Now()}
			err := m.TransitionTo(tc.to)
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Fatalf("TransitionTo(%q): err=%v, wantErr=%v", tc.to, err, tc.wantErr)
			}
			if err == nil && m.State() != tc.to {
				t.Fatalf("expected state %q, got %q", tc.to, m.State())
			}
		})
	}
}
// TestInvalidTransitions confirms every forbidden edge is rejected,
// including all self-transitions.
func TestInvalidTransitions(t *testing.T) {
	cases := []struct {
		name string
		from State
		to   State
	}{
		{"joining→draining", StateJoining, StateDraining},
		{"joining→maintenance", StateJoining, StateMaintenance},
		{"joining→joining", StateJoining, StateJoining},
		{"active→active", StateActive, StateActive},
		{"active→joining", StateActive, StateJoining},
		{"draining→active", StateDraining, StateActive},
		{"draining→joining", StateDraining, StateJoining},
		{"maintenance→draining", StateMaintenance, StateDraining},
		{"maintenance→joining", StateMaintenance, StateJoining},
		{"maintenance→maintenance", StateMaintenance, StateMaintenance},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			m := &Manager{state: tc.from, enterTime: time.Now()}
			if m.TransitionTo(tc.to) == nil {
				t.Fatalf("expected error for transition %s → %s", tc.from, tc.to)
			}
		})
	}
}
// TestEnterMaintenance verifies the happy path: active → maintenance with
// a TTL roughly equal to the requested duration.
func TestEnterMaintenance(t *testing.T) {
	m := NewManager()
	_ = m.TransitionTo(StateActive)

	if err := m.EnterMaintenance(5 * time.Minute); err != nil {
		t.Fatalf("EnterMaintenance: %v", err)
	}
	if !m.IsInMaintenance() {
		t.Fatal("expected maintenance state")
	}

	ttl := m.MaintenanceTTL()
	if ttl.IsZero() {
		t.Fatal("expected non-zero maintenance TTL")
	}
	// TTL should be roughly 5 minutes from now.
	if remaining := time.Until(ttl); remaining < 4*time.Minute || remaining > 6*time.Minute {
		t.Fatalf("expected TTL ~5min from now, got %v", remaining)
	}
}
// TestEnterMaintenanceTTLCapped verifies oversized TTL requests are
// clamped to MaxMaintenanceTTL.
func TestEnterMaintenanceTTLCapped(t *testing.T) {
	m := NewManager()
	_ = m.TransitionTo(StateActive)

	// Request a full hour; the manager must clamp it.
	if err := m.EnterMaintenance(1 * time.Hour); err != nil {
		t.Fatalf("EnterMaintenance: %v", err)
	}
	remaining := time.Until(m.MaintenanceTTL())
	if remaining > MaxMaintenanceTTL+time.Second {
		t.Fatalf("TTL should be capped at %v, got %v remaining", MaxMaintenanceTTL, remaining)
	}
}
// TestEnterMaintenanceZeroTTL verifies that a zero TTL falls back to the
// maximum allowed maintenance window.
func TestEnterMaintenanceZeroTTL(t *testing.T) {
	m := NewManager()
	_ = m.TransitionTo(StateActive)

	if err := m.EnterMaintenance(0); err != nil {
		t.Fatalf("EnterMaintenance: %v", err)
	}
	remaining := time.Until(m.MaintenanceTTL())
	if remaining < MaxMaintenanceTTL-time.Second {
		t.Fatalf("zero TTL should default to MaxMaintenanceTTL, got %v remaining", remaining)
	}
}
// TestMaintenanceTTLClearedOnExit verifies the TTL resets to the zero
// value when the node leaves maintenance.
func TestMaintenanceTTLClearedOnExit(t *testing.T) {
	m := NewManager()
	_ = m.TransitionTo(StateActive)
	_ = m.EnterMaintenance(5 * time.Minute)

	if m.MaintenanceTTL().IsZero() {
		t.Fatal("expected non-zero TTL in maintenance")
	}

	_ = m.TransitionTo(StateActive)
	if !m.MaintenanceTTL().IsZero() {
		t.Fatal("expected zero TTL after leaving maintenance")
	}
}
// TestIsMaintenanceExpired covers the three cases: TTL in the past,
// TTL in the future, and not being in maintenance at all.
func TestIsMaintenanceExpired(t *testing.T) {
	m := &Manager{
		state:          StateMaintenance,
		maintenanceTTL: time.Now().Add(-1 * time.Minute), // already lapsed
		enterTime:      time.Now().Add(-20 * time.Minute),
	}
	if !m.IsMaintenanceExpired() {
		t.Fatal("expected maintenance to be expired")
	}

	// Push the TTL into the future: no longer expired.
	m.maintenanceTTL = time.Now().Add(5 * time.Minute)
	if m.IsMaintenanceExpired() {
		t.Fatal("expected maintenance to not be expired")
	}

	// Leaving maintenance: expiry never reported regardless of TTL.
	m.state = StateActive
	if m.IsMaintenanceExpired() {
		t.Fatal("expected non-maintenance state to not report expired")
	}
}
// TestStateChangeCallback verifies a registered callback receives the
// correct old/new state pair.
func TestStateChangeCallback(t *testing.T) {
	m := NewManager()
	var (
		gotOld, gotNew State
		fired          bool
	)
	m.OnStateChange(func(old, new State) {
		gotOld, gotNew, fired = old, new, true
	})

	_ = m.TransitionTo(StateActive)
	if !fired {
		t.Fatal("callback was not called")
	}
	if gotOld != StateJoining || gotNew != StateActive {
		t.Fatalf("callback got old=%q new=%q, want old=%q new=%q",
			gotOld, gotNew, StateJoining, StateActive)
	}
}
// TestMultipleCallbacks verifies every registered callback fires once per
// transition.
func TestMultipleCallbacks(t *testing.T) {
	m := NewManager()
	fired := 0
	for i := 0; i < 2; i++ {
		m.OnStateChange(func(_, _ State) { fired++ })
	}

	_ = m.TransitionTo(StateActive)
	if fired != 2 {
		t.Fatalf("expected 2 callbacks, got %d", fired)
	}
}
// TestSnapshot verifies the point-in-time (state, ttl) pair reflects an
// active maintenance window.
func TestSnapshot(t *testing.T) {
	m := NewManager()
	_ = m.TransitionTo(StateActive)
	_ = m.EnterMaintenance(10 * time.Minute)

	state, ttl := m.Snapshot()
	switch {
	case state != StateMaintenance:
		t.Fatalf("expected maintenance, got %q", state)
	case ttl.IsZero():
		t.Fatal("expected non-zero TTL in snapshot")
	}
}
// TestConcurrentAccess hammers the manager from many goroutines at once.
// There are no assertions — the test exists to be run under -race and to
// prove that readers and maintenance enter/exit cycles can interleave
// without deadlock or data races. Transition errors are ignored on
// purpose: concurrent cycles legitimately race each other.
func TestConcurrentAccess(t *testing.T) {
	m := NewManager()
	_ = m.TransitionTo(StateActive)
	var wg sync.WaitGroup
	// Concurrent reads
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			_ = m.State()
			_ = m.IsAvailable()
			_ = m.IsInMaintenance()
			_ = m.IsMaintenanceExpired()
			_, _ = m.Snapshot()
		}()
	}
	// Concurrent maintenance enter/exit cycles
	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			_ = m.EnterMaintenance(1 * time.Minute)
			_ = m.TransitionTo(StateActive)
		}()
	}
	wg.Wait()
}
// TestStateEnteredAt verifies the enter timestamp is set at construction
// and refreshed by transitions.
func TestStateEnteredAt(t *testing.T) {
	before := time.Now()
	m := NewManager()
	after := time.Now()

	first := m.StateEnteredAt()
	if first.Before(before) || first.After(after) {
		t.Fatalf("StateEnteredAt %v not between %v and %v", first, before, after)
	}

	time.Sleep(10 * time.Millisecond)
	_ = m.TransitionTo(StateActive)
	if second := m.StateEnteredAt(); !second.After(first) {
		t.Fatal("expected StateEnteredAt to update after transition")
	}
}
// TestEnterMaintenanceFromInvalidState verifies a joining node cannot
// jump straight into maintenance.
func TestEnterMaintenanceFromInvalidState(t *testing.T) {
	m := NewManager() // still in the joining state
	if m.EnterMaintenance(5*time.Minute) == nil {
		t.Fatal("expected error entering maintenance from joining state")
	}
}
// TestFullLifecycle walks the complete happy path through the state
// machine — joining → active → draining → maintenance → active — and
// checks availability/maintenance flags after each step.
func TestFullLifecycle(t *testing.T) {
	m := NewManager()
	// joining → active
	if err := m.TransitionTo(StateActive); err != nil {
		t.Fatalf("joining→active: %v", err)
	}
	if !m.IsAvailable() {
		t.Fatal("active node should be available")
	}
	// active → draining
	if err := m.TransitionTo(StateDraining); err != nil {
		t.Fatalf("active→draining: %v", err)
	}
	if m.IsAvailable() {
		t.Fatal("draining node should not be available")
	}
	// draining → maintenance
	if err := m.EnterMaintenance(10 * time.Minute); err != nil {
		t.Fatalf("draining→maintenance: %v", err)
	}
	if !m.IsInMaintenance() {
		t.Fatal("should be in maintenance")
	}
	// maintenance → active
	if err := m.TransitionTo(StateActive); err != nil {
		t.Fatalf("maintenance→active: %v", err)
	}
	if !m.IsAvailable() {
		t.Fatal("should be available after maintenance")
	}
}

View File

@ -12,6 +12,7 @@ import (
"github.com/DeBrosOfficial/network/pkg/gateway"
"github.com/DeBrosOfficial/network/pkg/ipfs"
"github.com/DeBrosOfficial/network/pkg/logging"
"github.com/DeBrosOfficial/network/pkg/node/lifecycle"
"github.com/DeBrosOfficial/network/pkg/pubsub"
database "github.com/DeBrosOfficial/network/pkg/rqlite"
"github.com/libp2p/go-libp2p/core/host"
@ -24,6 +25,9 @@ type Node struct {
logger *logging.ColoredLogger
host host.Host
// Lifecycle state machine (joining → active ⇄ maintenance)
lifecycle *lifecycle.Manager
rqliteManager *database.RQLiteManager
rqliteAdapter *database.RQLiteAdapter
clusterDiscovery *database.ClusterDiscoveryService
@ -56,6 +60,7 @@ func NewNode(cfg *config.Config) (*Node, error) {
return &Node{
config: cfg,
logger: logger,
lifecycle: lifecycle.NewManager(),
}, nil
}
@ -124,9 +129,20 @@ func (n *Node) Start(ctx context.Context) error {
}
}
// All services started — transition lifecycle: joining → active
if err := n.lifecycle.TransitionTo(lifecycle.StateActive); err != nil {
n.logger.ComponentWarn(logging.ComponentNode, "Failed to transition lifecycle to active", zap.Error(err))
}
// Publish updated metadata with active lifecycle state
if n.clusterDiscovery != nil {
n.clusterDiscovery.UpdateOwnMetadata()
}
n.logger.ComponentInfo(logging.ComponentNode, "Network node started successfully",
zap.String("peer_id", n.GetPeerID()),
zap.Strings("listen_addrs", listenAddrs),
zap.String("lifecycle", string(n.lifecycle.State())),
)
n.startConnectionMonitoring()
@ -138,6 +154,17 @@ func (n *Node) Start(ctx context.Context) error {
func (n *Node) Stop() error {
n.logger.ComponentInfo(logging.ComponentNode, "Stopping network node")
// Enter maintenance so peers know we're shutting down
if n.lifecycle.IsAvailable() {
if err := n.lifecycle.EnterMaintenance(5 * time.Minute); err != nil {
n.logger.ComponentWarn(logging.ComponentNode, "Failed to enter maintenance on shutdown", zap.Error(err))
}
// Publish maintenance state before tearing down services
if n.clusterDiscovery != nil {
n.clusterDiscovery.UpdateOwnMetadata()
}
}
// Stop HTTP Gateway server
if n.apiGatewayServer != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)

View File

@ -34,6 +34,7 @@ func (n *Node) startRQLite(ctx context.Context) error {
n.config.Discovery.RaftAdvAddress,
n.config.Discovery.HttpAdvAddress,
n.config.Node.DataDir,
n.lifecycle,
n.logger.Logger,
)

View File

@ -71,9 +71,10 @@ func (r *RQLiteManager) waitForMinClusterSizeBeforeStart(ctx context.Context, rq
}
if remotePeerCount >= requiredRemotePeers {
peersPath := filepath.Join(rqliteDataDir, "raft", "peers.json")
// Check discovery-peers.json (safe location outside raft dir)
peersPath := filepath.Join(rqliteDataDir, "discovery-peers.json")
r.discoveryService.TriggerSync()
time.Sleep(2 * time.Second)
time.Sleep(500 * time.Millisecond)
if info, err := os.Stat(peersPath); err == nil && info.Size() > 10 {
data, err := os.ReadFile(peersPath)
@ -97,13 +98,11 @@ func (r *RQLiteManager) performPreStartClusterDiscovery(ctx context.Context, rql
if err := r.discoveryService.TriggerPeerExchange(ctx); err != nil {
r.logger.Warn("Failed to trigger peer exchange during pre-start discovery", zap.Error(err))
}
time.Sleep(1 * time.Second)
r.discoveryService.TriggerSync()
time.Sleep(2 * time.Second)
time.Sleep(500 * time.Millisecond)
// Wait up to 2 minutes for peer discovery - LibP2P DHT can take 60+ seconds
// to re-establish connections after simultaneous restart
discoveryDeadline := time.Now().Add(2 * time.Minute)
// Wait up to 45s for peer discovery — parallel dials compensate for the shorter deadline
discoveryDeadline := time.Now().Add(45 * time.Second)
var discoveredPeers int
for time.Now().Before(discoveryDeadline) {
@ -151,7 +150,7 @@ func (r *RQLiteManager) performPreStartClusterDiscovery(ctx context.Context, rql
}
r.discoveryService.TriggerSync()
time.Sleep(2 * time.Second)
time.Sleep(500 * time.Millisecond)
return nil
}
@ -182,9 +181,8 @@ func (r *RQLiteManager) recoverFromSplitBrain(ctx context.Context) error {
}
r.discoveryService.TriggerPeerExchange(ctx)
time.Sleep(2 * time.Second)
r.discoveryService.TriggerSync()
time.Sleep(2 * time.Second)
time.Sleep(500 * time.Millisecond)
rqliteDataDir, _ := r.rqliteDataDirPath()
ourIndex := r.getRaftLogIndex()
@ -201,7 +199,7 @@ func (r *RQLiteManager) recoverFromSplitBrain(ctx context.Context) error {
r.logger.Warn("Failed to clear raft state during split-brain recovery", zap.Error(err))
}
r.discoveryService.TriggerPeerExchange(ctx)
time.Sleep(1 * time.Second)
time.Sleep(500 * time.Millisecond)
if err := r.discoveryService.ForceWritePeersJSON(); err != nil {
r.logger.Warn("Failed to write peers.json during split-brain recovery", zap.Error(err))
}
@ -326,14 +324,15 @@ func (r *RQLiteManager) hasExistingRaftState(rqliteDataDir string) bool {
if info, err := os.Stat(raftLogPath); err == nil && info.Size() > 1024 {
return true
}
peersPath := filepath.Join(rqliteDataDir, "raft", "peers.json")
_, err := os.Stat(peersPath)
return err == nil
// Don't check peers.json — discovery-peers.json is now written outside
// the raft dir and should not be treated as existing Raft state.
return false
}
func (r *RQLiteManager) clearRaftState(rqliteDataDir string) error {
_ = os.Remove(filepath.Join(rqliteDataDir, "raft.db"))
_ = os.Remove(filepath.Join(rqliteDataDir, "raft", "peers.json"))
_ = os.Remove(filepath.Join(rqliteDataDir, "discovery-peers.json"))
return nil
}

View File

@ -3,10 +3,12 @@ package rqlite
import (
"context"
"fmt"
"net"
"sync"
"time"
"github.com/DeBrosOfficial/network/pkg/discovery"
"github.com/DeBrosOfficial/network/pkg/node/lifecycle"
"github.com/libp2p/go-libp2p/core/host"
"go.uber.org/zap"
)
@ -20,9 +22,13 @@ type ClusterDiscoveryService struct {
nodeType string
raftAddress string
httpAddress string
wireGuardIP string // extracted from raftAddress (IP component)
dataDir string
minClusterSize int // Minimum cluster size required
// Lifecycle manager for this node's state machine
lifecycle *lifecycle.Manager
knownPeers map[string]*discovery.RQLiteNodeMetadata // NodeID -> Metadata
peerHealth map[string]*PeerHealth // NodeID -> Health
lastUpdate time.Time
@ -45,6 +51,7 @@ func NewClusterDiscoveryService(
raftAddress string,
httpAddress string,
dataDir string,
lm *lifecycle.Manager,
logger *zap.Logger,
) *ClusterDiscoveryService {
minClusterSize := 1
@ -52,6 +59,12 @@ func NewClusterDiscoveryService(
minClusterSize = rqliteManager.config.MinClusterSize
}
// Extract WireGuard IP from the raft address (e.g., "10.0.0.1" from "10.0.0.1:7001")
wgIP := ""
if host, _, err := net.SplitHostPort(raftAddress); err == nil {
wgIP = host
}
return &ClusterDiscoveryService{
host: h,
discoveryMgr: discoveryMgr,
@ -60,8 +73,10 @@ func NewClusterDiscoveryService(
nodeType: nodeType,
raftAddress: raftAddress,
httpAddress: httpAddress,
wireGuardIP: wgIP,
dataDir: dataDir,
minClusterSize: minClusterSize,
lifecycle: lm,
knownPeers: make(map[string]*discovery.RQLiteNodeMetadata),
peerHealth: make(map[string]*PeerHealth),
updateInterval: 30 * time.Second,
@ -119,6 +134,25 @@ func (c *ClusterDiscoveryService) Stop() {
c.logger.Info("Cluster discovery service stopped")
}
// Lifecycle returns the lifecycle manager this service was constructed
// with, allowing callers to read or drive the node's state machine.
func (c *ClusterDiscoveryService) Lifecycle() *lifecycle.Manager {
	return c.lifecycle
}
// GetPeerLifecycleState returns the lifecycle state and last-seen time for a
// peer identified by its RQLite node ID (raft address). This method implements
// the MetadataReader interface used by the health monitor.
func (c *ClusterDiscoveryService) GetPeerLifecycleState(nodeID string) (state string, lastSeen time.Time, found bool) {
	c.mu.RLock()
	defer c.mu.RUnlock()
	if peer, ok := c.knownPeers[nodeID]; ok {
		return peer.EffectiveLifecycleState(), peer.LastSeen, true
	}
	return "", time.Time{}, false
}
// IsVoter returns true if the given raft address should be a voter
// in the default cluster based on the current known peers.
func (c *ClusterDiscoveryService) IsVoter(raftAddress string) bool {

View File

@ -39,6 +39,17 @@ func (c *ClusterDiscoveryService) collectPeerMetadata() []*discovery.RQLiteNodeM
RaftLogIndex: c.rqliteManager.getRaftLogIndex(),
LastSeen: time.Now(),
ClusterVersion: "1.0",
PeerID: c.host.ID().String(),
WireGuardIP: c.wireGuardIP,
}
// Populate lifecycle state
if c.lifecycle != nil {
state, ttl := c.lifecycle.Snapshot()
ourMetadata.LifecycleState = string(state)
if state == "maintenance" {
ourMetadata.MaintenanceTTL = ttl
}
}
if c.adjustSelfAdvertisedAddresses(ourMetadata) {
@ -272,7 +283,7 @@ func (c *ClusterDiscoveryService) getPeersJSONUnlocked() []map[string]interface{
}
// computeVoterSet returns the set of raft addresses that should be voters.
// It sorts addresses by their IP component and selects the first maxVoters.
// It sorts addresses by their numeric IP and selects the first maxVoters.
// This is deterministic — all nodes compute the same voter set from the same peer list.
func computeVoterSet(raftAddrs []string, maxVoters int) map[string]struct{} {
sorted := make([]string, len(raftAddrs))
@ -281,7 +292,7 @@ func computeVoterSet(raftAddrs []string, maxVoters int) map[string]struct{} {
sort.Slice(sorted, func(i, j int) bool {
ipI := extractIPForSort(sorted[i])
ipJ := extractIPForSort(sorted[j])
return ipI < ipJ
return compareIPs(ipI, ipJ)
})
voters := make(map[string]struct{})
@ -303,6 +314,31 @@ func extractIPForSort(raftAddr string) string {
return host
}
// compareIPs reports whether IP string a sorts strictly before b using
// numeric (byte-wise) ordering rather than lexicographic ordering.
// Plain string sort misorders dotted quads ("10.0.0.10" < "10.0.0.2"
// alphabetically), which previously caused the wrong nodes to be selected
// as voters (e.g. 10.0.0.1, 10.0.0.10, 10.0.0.11 instead of 10.0.0.1-5).
func compareIPs(a, b string) bool {
	ipA, ipB := net.ParseIP(a), net.ParseIP(b)
	if ipA == nil || ipB == nil {
		// Unparseable input: degrade gracefully to string ordering.
		return a < b
	}
	// Normalize both to the 16-byte form so IPv4 and IPv6 compare consistently.
	a16, b16 := ipA.To16(), ipB.To16()
	for i := 0; i < len(a16); i++ {
		if a16[i] == b16[i] {
			continue
		}
		return a16[i] < b16[i]
	}
	// All bytes equal: not strictly less.
	return false
}
// IsVoter returns true if the given raft address is in the voter set
// based on the current known peers. Must be called with c.mu held.
func (c *ClusterDiscoveryService) IsVoterLocked(raftAddress string) bool {
@ -328,6 +364,14 @@ func (c *ClusterDiscoveryService) writePeersJSON() error {
return c.writePeersJSONWithData(peers)
}
// writePeersJSONWithData writes the discovery peers file to a SAFE location
// outside the raft directory. This is critical: rqlite v8 treats any
// peers.json inside <dataDir>/raft/ as a recovery signal and RESETS
// the Raft configuration on startup. Writing there on every periodic sync
// caused split-brain on every node restart.
//
// Safe location: <dataDir>/rqlite/discovery-peers.json
// Dangerous location: <dataDir>/rqlite/raft/peers.json (only for explicit recovery)
func (c *ClusterDiscoveryService) writePeersJSONWithData(peers []map[string]interface{}) error {
dataDir := os.ExpandEnv(c.dataDir)
if strings.HasPrefix(dataDir, "~") {
@ -338,30 +382,25 @@ func (c *ClusterDiscoveryService) writePeersJSONWithData(peers []map[string]inte
dataDir = filepath.Join(home, dataDir[1:])
}
rqliteDir := filepath.Join(dataDir, "rqlite", "raft")
// Write to <dataDir>/rqlite/ — NOT inside raft/ subdirectory.
// rqlite v8 auto-recovers from raft/peers.json on every startup,
// which resets the Raft config and causes split-brain.
rqliteDir := filepath.Join(dataDir, "rqlite")
if err := os.MkdirAll(rqliteDir, 0755); err != nil {
return fmt.Errorf("failed to create raft directory %s: %w", rqliteDir, err)
return fmt.Errorf("failed to create rqlite directory %s: %w", rqliteDir, err)
}
peersFile := filepath.Join(rqliteDir, "peers.json")
backupFile := filepath.Join(rqliteDir, "peers.json.backup")
if _, err := os.Stat(peersFile); err == nil {
data, err := os.ReadFile(peersFile)
if err == nil {
_ = os.WriteFile(backupFile, data, 0644)
}
}
peersFile := filepath.Join(rqliteDir, "discovery-peers.json")
data, err := json.MarshalIndent(peers, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal peers.json: %w", err)
return fmt.Errorf("failed to marshal discovery-peers.json: %w", err)
}
tempFile := peersFile + ".tmp"
if err := os.WriteFile(tempFile, data, 0644); err != nil {
return fmt.Errorf("failed to write temp peers.json %s: %w", tempFile, err)
return fmt.Errorf("failed to write temp discovery-peers.json %s: %w", tempFile, err)
}
if err := os.Rename(tempFile, peersFile); err != nil {
@ -375,7 +414,57 @@ func (c *ClusterDiscoveryService) writePeersJSONWithData(peers []map[string]inte
}
}
c.logger.Info("peers.json written",
c.logger.Debug("discovery-peers.json written",
zap.Int("peers", len(peers)),
zap.Strings("nodes", nodeIDs))
return nil
}
// writeRecoveryPeersJSON writes peers.json to the raft directory for
// INTENTIONAL cluster recovery only. rqlite v8 will read this file on
// startup and reset the Raft configuration accordingly. Only call this
// when you explicitly want to trigger Raft recovery.
func (c *ClusterDiscoveryService) writeRecoveryPeersJSON(peers []map[string]interface{}) error {
dataDir := os.ExpandEnv(c.dataDir)
if strings.HasPrefix(dataDir, "~") {
home, err := os.UserHomeDir()
if err != nil {
return fmt.Errorf("failed to determine home directory: %w", err)
}
dataDir = filepath.Join(home, dataDir[1:])
}
raftDir := filepath.Join(dataDir, "rqlite", "raft")
if err := os.MkdirAll(raftDir, 0755); err != nil {
return fmt.Errorf("failed to create raft directory %s: %w", raftDir, err)
}
peersFile := filepath.Join(raftDir, "peers.json")
data, err := json.MarshalIndent(peers, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal recovery peers.json: %w", err)
}
tempFile := peersFile + ".tmp"
if err := os.WriteFile(tempFile, data, 0644); err != nil {
return fmt.Errorf("failed to write temp recovery peers.json %s: %w", tempFile, err)
}
if err := os.Rename(tempFile, peersFile); err != nil {
return fmt.Errorf("failed to rename %s to %s: %w", tempFile, peersFile, err)
}
nodeIDs := make([]string, 0, len(peers))
for _, p := range peers {
if id, ok := p["id"].(string); ok {
nodeIDs = append(nodeIDs, id)
}
}
c.logger.Warn("RECOVERY peers.json written to raft directory — rqlited will reset Raft config on next startup",
zap.Int("peers", len(peers)),
zap.Strings("nodes", nodeIDs))

View File

@ -128,9 +128,12 @@ func (c *ClusterDiscoveryService) TriggerSync() {
c.updateClusterMembership()
}
// ForceWritePeersJSON forces writing peers.json regardless of membership changes
// ForceWritePeersJSON writes peers.json to the RAFT directory for intentional
// cluster recovery. rqlite v8 will read this on startup and reset its Raft
// configuration. Only call this when you explicitly want Raft recovery
// (e.g., after clearing raft state or during split-brain recovery).
func (c *ClusterDiscoveryService) ForceWritePeersJSON() error {
c.logger.Info("Force writing peers.json")
c.logger.Info("Force writing recovery peers.json to raft directory")
metadata := c.collectPeerMetadata()
@ -153,16 +156,17 @@ func (c *ClusterDiscoveryService) ForceWritePeersJSON() error {
peers := c.getPeersJSONUnlocked()
c.mu.Unlock()
if err := c.writePeersJSONWithData(peers); err != nil {
c.logger.Error("Failed to force write peers.json",
// Write to RAFT directory — this is intentional recovery
if err := c.writeRecoveryPeersJSON(peers); err != nil {
c.logger.Error("Failed to force write recovery peers.json",
zap.Error(err),
zap.String("data_dir", c.dataDir),
zap.Int("peers", len(peers)))
return err
}
c.logger.Info("peers.json written",
zap.Int("peers", len(peers)))
// Also update discovery location
_ = c.writePeersJSONWithData(peers)
return nil
}
@ -179,7 +183,9 @@ func (c *ClusterDiscoveryService) TriggerPeerExchange(ctx context.Context) error
return nil
}
// UpdateOwnMetadata updates our own RQLite metadata in the peerstore
// UpdateOwnMetadata updates our own RQLite metadata in the peerstore.
// This is called periodically and after significant state changes (lifecycle
// transitions, service status updates) to ensure peers have current info.
func (c *ClusterDiscoveryService) UpdateOwnMetadata() {
c.mu.RLock()
currentRaftAddr := c.raftAddress
@ -194,6 +200,17 @@ func (c *ClusterDiscoveryService) UpdateOwnMetadata() {
RaftLogIndex: c.rqliteManager.getRaftLogIndex(),
LastSeen: time.Now(),
ClusterVersion: "1.0",
PeerID: c.host.ID().String(),
WireGuardIP: c.wireGuardIP,
}
// Populate lifecycle state from the lifecycle manager
if c.lifecycle != nil {
state, ttl := c.lifecycle.Snapshot()
metadata.LifecycleState = string(state)
if state == "maintenance" {
metadata.MaintenanceTTL = ttl
}
}
if c.adjustSelfAdvertisedAddresses(metadata) {
@ -215,7 +232,41 @@ func (c *ClusterDiscoveryService) UpdateOwnMetadata() {
c.logger.Debug("Metadata updated",
zap.String("node", metadata.NodeID),
zap.Uint64("log_index", metadata.RaftLogIndex))
zap.Uint64("log_index", metadata.RaftLogIndex),
zap.String("lifecycle", metadata.LifecycleState))
}
// ProvideMetadata builds and returns a fresh snapshot of this node's RQLite
// metadata without persisting it anywhere. It satisfies the
// discovery.MetadataProvider interface so the MetadataPublisher can poll it
// on a regular interval and store the result in the peerstore.
func (c *ClusterDiscoveryService) ProvideMetadata() *discovery.RQLiteNodeMetadata {
	// Take a consistent snapshot of the mutable addresses under the read lock.
	c.mu.RLock()
	raftAddr, httpAddr := c.raftAddress, c.httpAddress
	c.mu.RUnlock()

	md := &discovery.RQLiteNodeMetadata{
		NodeID:         raftAddr,
		RaftAddress:    raftAddr,
		HTTPAddress:    httpAddr,
		NodeType:       c.nodeType,
		RaftLogIndex:   c.rqliteManager.getRaftLogIndex(),
		LastSeen:       time.Now(),
		ClusterVersion: "1.0",
		PeerID:         c.host.ID().String(),
		WireGuardIP:    c.wireGuardIP,
	}

	// Attach lifecycle info when a lifecycle manager is present; the TTL is
	// only meaningful while the node is in maintenance mode.
	if lc := c.lifecycle; lc != nil {
		st, ttl := lc.Snapshot()
		md.LifecycleState = string(st)
		if st == "maintenance" {
			md.MaintenanceTTL = ttl
		}
	}

	c.adjustSelfAdvertisedAddresses(md)
	return md
}
// StoreRemotePeerMetadata stores metadata received from a remote peer

View File

@ -83,6 +83,14 @@ func (is *InstanceSpawner) SpawnInstance(ctx context.Context, cfg InstanceConfig
"-raft-adv-addr", cfg.RaftAdvAddress,
}
// Raft tuning — match the global node's tuning for consistency
args = append(args,
"-raft-election-timeout", "5s",
"-raft-heartbeat-timeout", "2s",
"-raft-apply-timeout", "30s",
"-raft-leader-lease-timeout", "5s",
)
// Add join addresses if not the leader (must be before data directory)
if !cfg.IsLeader && len(cfg.JoinAddresses) > 0 {
for _, addr := range cfg.JoinAddresses {

131
pkg/rqlite/leadership.go Normal file
View File

@ -0,0 +1,131 @@
package rqlite
import (
"encoding/json"
"fmt"
"io"
"net/http"
"time"
"go.uber.org/zap"
)
// TransferLeadership attempts to transfer Raft leadership to another voter.
// Used by both the RQLiteManager (on Stop) and the CLI (pre-upgrade).
//
// The transfer is strictly best-effort: if this node is not the leader, if no
// eligible target exists, or if the running rqlite version lacks the transfer
// API, the function returns nil and leaves the graceful step-down to SIGTERM.
// A non-nil error is returned only when the local status/nodes endpoints
// cannot be queried or parsed at all.
func TransferLeadership(port int, logger *zap.Logger) error {
	client := &http.Client{Timeout: 5 * time.Second}

	// 1. Check if we're the leader.
	statusURL := fmt.Sprintf("http://localhost:%d/status", port)
	resp, err := client.Get(statusURL)
	if err != nil {
		return fmt.Errorf("failed to query status: %w", err)
	}
	defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("failed to read status: %w", err)
	}
	// Robustness fix: don't feed an HTML/error page into the JSON parser.
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("status endpoint returned HTTP %d", resp.StatusCode)
	}
	var status RQLiteStatus
	if err := json.Unmarshal(body, &status); err != nil {
		return fmt.Errorf("failed to parse status: %w", err)
	}
	if status.Store.Raft.State != "Leader" {
		logger.Debug("Not the leader, skipping transfer", zap.Int("port", port))
		return nil
	}
	logger.Info("This node is the Raft leader, attempting leadership transfer",
		zap.Int("port", port),
		zap.String("leader_id", status.Store.Raft.LeaderID))

	// 2. Find an eligible voter to transfer to.
	nodesURL := fmt.Sprintf("http://localhost:%d/nodes?nonvoters&ver=2&timeout=5s", port)
	nodesResp, err := client.Get(nodesURL)
	if err != nil {
		return fmt.Errorf("failed to query nodes: %w", err)
	}
	defer nodesResp.Body.Close()
	nodesBody, err := io.ReadAll(nodesResp.Body)
	if err != nil {
		return fmt.Errorf("failed to read nodes: %w", err)
	}
	if nodesResp.StatusCode != http.StatusOK {
		return fmt.Errorf("nodes endpoint returned HTTP %d", nodesResp.StatusCode)
	}

	// Try ver=2 wrapped format, fall back to plain array. A parse failure is
	// deliberately tolerated: it leaves an empty node list, which falls
	// through to the SIGTERM path below.
	var nodes RQLiteNodes
	var wrapped struct {
		Nodes RQLiteNodes `json:"nodes"`
	}
	if err := json.Unmarshal(nodesBody, &wrapped); err == nil && wrapped.Nodes != nil {
		nodes = wrapped.Nodes
	} else {
		_ = json.Unmarshal(nodesBody, &nodes)
	}

	// Find a reachable voter that is NOT us (our ID == the current LeaderID,
	// since we verified above that we are the leader).
	var targetID string
	for _, n := range nodes {
		if n.Voter && n.Reachable && n.ID != status.Store.Raft.LeaderID {
			targetID = n.ID
			break
		}
	}
	if targetID == "" {
		logger.Warn("No eligible voter found for leadership transfer — will rely on SIGTERM graceful step-down",
			zap.Int("port", port))
		return nil
	}

	// 3. Attempt transfer via rqlite v8+ API:
	//    POST /nodes/<target>/transfer-leadership
	// If the API doesn't exist (404), fall back to relying on SIGTERM.
	transferURL := fmt.Sprintf("http://localhost:%d/nodes/%s/transfer-leadership", port, targetID)
	transferResp, err := client.Post(transferURL, "application/json", nil)
	if err != nil {
		logger.Warn("Leadership transfer request failed, relying on SIGTERM",
			zap.Error(err))
		return nil
	}
	// Drain the body before closing so the transport can reuse the connection.
	_, _ = io.Copy(io.Discard, transferResp.Body)
	transferResp.Body.Close()
	if transferResp.StatusCode == http.StatusNotFound {
		logger.Info("Leadership transfer API not available (rqlite version), relying on SIGTERM")
		return nil
	}
	if transferResp.StatusCode != http.StatusOK {
		logger.Warn("Leadership transfer returned unexpected status",
			zap.Int("status", transferResp.StatusCode))
		return nil
	}

	// 4. Verify the transfer after giving the cluster a moment to elect.
	// Verification is best-effort only; every outcome returns nil.
	time.Sleep(2 * time.Second)
	verifyResp, err := client.Get(statusURL)
	if err != nil {
		logger.Info("Could not verify transfer (node may have already stepped down)")
		return nil
	}
	defer verifyResp.Body.Close()
	verifyBody, _ := io.ReadAll(verifyResp.Body)
	var newStatus RQLiteStatus
	if err := json.Unmarshal(verifyBody, &newStatus); err == nil {
		if newStatus.Store.Raft.State != "Leader" {
			logger.Info("Leadership transferred successfully",
				zap.String("new_leader", newStatus.Store.Raft.LeaderID),
				zap.Int("port", port))
		} else {
			logger.Warn("Still leader after transfer attempt — will rely on SIGTERM",
				zap.Int("port", port))
		}
	}
	return nil
}

View File

@ -66,6 +66,29 @@ func (r *RQLiteManager) launchProcess(ctx context.Context, rqliteDataDir string)
// Kill any orphaned rqlited from a previous crash
r.killOrphanedRQLite()
// Remove stale peers.json from the raft directory to prevent rqlite v8
// from triggering automatic Raft recovery on normal restarts.
//
// Only delete when raft.db EXISTS (normal restart). If raft.db does NOT
// exist, peers.json was likely placed intentionally by ForceWritePeersJSON()
// as part of a recovery flow (clearRaftState + ForceWritePeersJSON + launch).
stalePeersPath := filepath.Join(rqliteDataDir, "raft", "peers.json")
raftDBPath := filepath.Join(rqliteDataDir, "raft.db")
if _, err := os.Stat(stalePeersPath); err == nil {
if _, err := os.Stat(raftDBPath); err == nil {
// raft.db exists → this is a normal restart, peers.json is stale
r.logger.Warn("Removing stale peers.json from raft directory to prevent accidental recovery",
zap.String("path", stalePeersPath))
_ = os.Remove(stalePeersPath)
_ = os.Remove(stalePeersPath + ".backup")
_ = os.Remove(stalePeersPath + ".tmp")
} else {
// raft.db missing → intentional recovery, keep peers.json for rqlited
r.logger.Info("Keeping peers.json in raft directory for intentional cluster recovery",
zap.String("path", stalePeersPath))
}
}
// Build RQLite command
args := []string{
"-http-addr", fmt.Sprintf("0.0.0.0:%d", r.config.RQLitePort),
@ -90,6 +113,30 @@ func (r *RQLiteManager) launchProcess(ctx context.Context, rqliteDataDir string)
}
}
// Raft tuning — higher timeouts suit WireGuard latency
raftElection := r.config.RaftElectionTimeout
if raftElection == 0 {
raftElection = 5 * time.Second
}
raftHeartbeat := r.config.RaftHeartbeatTimeout
if raftHeartbeat == 0 {
raftHeartbeat = 2 * time.Second
}
raftApply := r.config.RaftApplyTimeout
if raftApply == 0 {
raftApply = 30 * time.Second
}
raftLeaderLease := r.config.RaftLeaderLeaseTimeout
if raftLeaderLease == 0 {
raftLeaderLease = 5 * time.Second
}
args = append(args,
"-raft-election-timeout", raftElection.String(),
"-raft-heartbeat-timeout", raftHeartbeat.String(),
"-raft-apply-timeout", raftApply.String(),
"-raft-leader-lease-timeout", raftLeaderLease.String(),
)
if r.config.RQLiteJoinAddress != "" && !r.hasExistingState(rqliteDataDir) {
r.logger.Info("First-time join to RQLite cluster", zap.String("join_address", r.config.RQLiteJoinAddress))

View File

@ -70,6 +70,7 @@ func (r *RQLiteManager) Start(ctx context.Context) error {
if r.discoveryService != nil {
go r.startHealthMonitoring(ctx)
go r.startVoterReconciliation(ctx)
go r.startOrphanedNodeRecovery(ctx) // C1 fix: recover nodes orphaned by failed voter changes
}
// Start child process watchdog to detect and recover from crashes
@ -138,21 +139,9 @@ func (r *RQLiteManager) Stop() error {
// transferLeadershipIfLeader checks if this node is the Raft leader and
// requests a leadership transfer to minimize election disruption.
func (r *RQLiteManager) transferLeadershipIfLeader() {
status, err := r.getRQLiteStatus()
if err != nil {
return
if err := TransferLeadership(r.config.RQLitePort, r.logger); err != nil {
r.logger.Warn("Leadership transfer failed, relying on SIGTERM", zap.Error(err))
}
if status.Store.Raft.State != "Leader" {
return
}
r.logger.Info("This node is the Raft leader, requesting leadership transfer before shutdown")
// RQLite doesn't have a direct leadership transfer API, but we can
// signal readiness to step down. The fastest approach is to let the
// SIGTERM handler in rqlited handle this — rqlite v8 gracefully
// steps down on SIGTERM when possible. We log the state for visibility.
r.logger.Info("Leader will transfer on SIGTERM (rqlite built-in behavior)")
}
// cleanupPIDFile removes the PID file on shutdown

View File

@ -7,15 +7,32 @@ import (
"fmt"
"io"
"net/http"
"sync"
"time"
"go.uber.org/zap"
)
const (
	// voterChangeCooldown is how long to wait after a failed voter change
	// before retrying the same node. Prevents tight retry loops against a
	// node that keeps failing remove/rejoin.
	voterChangeCooldown = 10 * time.Minute
)

// voterReconciler holds voter change cooldown state shared across
// reconciliation cycles of a single reconciliation goroutine.
type voterReconciler struct {
	mu        sync.Mutex           // guards cooldowns
	cooldowns map[string]time.Time // nodeID → earliest next attempt
}
// startVoterReconciliation periodically checks and corrects voter/non-voter
// assignments. Only takes effect on the leader node. Corrects at most one
// node per cycle to minimize disruption.
func (r *RQLiteManager) startVoterReconciliation(ctx context.Context) {
reconciler := &voterReconciler{
cooldowns: make(map[string]time.Time),
}
// Wait for cluster to stabilize after startup
time.Sleep(3 * time.Minute)
@ -27,21 +44,104 @@ func (r *RQLiteManager) startVoterReconciliation(ctx context.Context) {
case <-ctx.Done():
return
case <-ticker.C:
if err := r.reconcileVoters(); err != nil {
if err := r.reconcileVoters(reconciler); err != nil {
r.logger.Debug("Voter reconciliation skipped", zap.Error(err))
}
}
}
}
// startOrphanedNodeRecovery runs every 5 minutes on the leader. It scans for
// nodes that appear in the discovery peer list but NOT in the Raft cluster
// (orphaned by a failed remove+rejoin during voter reconciliation). For each
// orphaned node, it re-adds them via POST /join. (C1 fix)
//
// The goroutine exits promptly when ctx is cancelled, including during the
// initial stabilization delay.
func (r *RQLiteManager) startOrphanedNodeRecovery(ctx context.Context) {
	// Wait for the cluster to stabilize before the first scan, but abort
	// immediately on cancellation. (The previous bare time.Sleep kept this
	// goroutine alive for up to 5 minutes after shutdown.)
	select {
	case <-ctx.Done():
		return
	case <-time.After(5 * time.Minute):
	}

	ticker := time.NewTicker(5 * time.Minute)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			r.recoverOrphanedNodes()
		}
	}
}
// recoverOrphanedNodes finds nodes known to discovery but missing from the
// Raft cluster and re-adds them.
//
// Only the current Raft leader performs recovery. For each discovery peer
// absent from the Raft membership, the node is re-added via POST /join with
// a voter flag derived from the deterministic voter set. Failures are logged
// and retried on the next cycle.
func (r *RQLiteManager) recoverOrphanedNodes() {
	if r.discoveryService == nil {
		return
	}

	// Only the leader runs orphan recovery.
	status, err := r.getRQLiteStatus()
	if err != nil || status.Store.Raft.State != "Leader" {
		return
	}

	// Snapshot the current Raft cluster membership.
	raftNodes, err := r.getAllClusterNodes()
	if err != nil {
		return
	}
	raftNodeSet := make(map[string]bool, len(raftNodes))
	for _, n := range raftNodes {
		raftNodeSet[n.ID] = true
	}

	// Get all discovery peers.
	discoveryPeers := r.discoveryService.GetAllPeers()

	// Compute the desired voter set once. It depends only on the full peer
	// list, not on which orphan is being examined, so it is hoisted out of
	// the loop below (the original recomputed it per orphan: accidental O(n²)).
	raftAddrs := make([]string, 0, len(discoveryPeers))
	for _, p := range discoveryPeers {
		raftAddrs = append(raftAddrs, p.RaftAddress)
	}
	voters := computeVoterSet(raftAddrs, MaxDefaultVoters)

	for _, peer := range discoveryPeers {
		if peer.RaftAddress == r.discoverConfig.RaftAdvAddress {
			continue // skip self
		}
		if raftNodeSet[peer.RaftAddress] {
			continue // already in cluster
		}

		// This peer is in discovery but not in Raft — it's orphaned.
		r.logger.Warn("Found orphaned node (in discovery but not in Raft cluster), re-adding",
			zap.String("node_raft_addr", peer.RaftAddress),
			zap.String("node_id", peer.NodeID))

		_, shouldBeVoter := voters[peer.RaftAddress]
		if err := r.joinClusterNode(peer.RaftAddress, peer.RaftAddress, shouldBeVoter); err != nil {
			r.logger.Error("Failed to re-add orphaned node",
				zap.String("node", peer.RaftAddress),
				zap.Bool("voter", shouldBeVoter),
				zap.Error(err))
		} else {
			r.logger.Info("Successfully re-added orphaned node to Raft cluster",
				zap.String("node", peer.RaftAddress),
				zap.Bool("voter", shouldBeVoter))
		}
	}
}
// reconcileVoters compares actual cluster voter assignments (from RQLite's
// /nodes endpoint) against the deterministic desired set (computeVoterSet)
// and corrects mismatches. Uses remove + re-join since RQLite's /join
// ignores voter flag changes for existing members.
// and corrects mismatches.
//
// Safety: only runs on the leader, only when all nodes are reachable,
// never demotes the leader, and fixes at most one node per cycle.
func (r *RQLiteManager) reconcileVoters() error {
// Improvements over original:
// - Promotion: tries direct POST /join with voter=true first (no remove needed)
// - Leader stability: verifies leader is stable before demotion
// - Cooldown: skips nodes that recently failed a voter change
// - Fixes at most one node per cycle
func (r *RQLiteManager) reconcileVoters(reconciler *voterReconciler) error {
// 1. Only the leader reconciles
status, err := r.getRQLiteStatus()
if err != nil {
@ -68,14 +168,21 @@ func (r *RQLiteManager) reconcileVoters() error {
}
}
// 4. Compute desired voter set from raft addresses
// 4. Leader stability: verify term hasn't changed recently
// (Re-check status to confirm we're still the stable leader)
status2, err := r.getRQLiteStatus()
if err != nil || status2.Store.Raft.State != "Leader" || status2.Store.Raft.Term != status.Store.Raft.Term {
return fmt.Errorf("leader state changed during reconciliation check")
}
// 5. Compute desired voter set from raft addresses
raftAddrs := make([]string, 0, len(nodes))
for _, n := range nodes {
raftAddrs = append(raftAddrs, n.ID)
}
desiredVoters := computeVoterSet(raftAddrs, MaxDefaultVoters)
// 5. Safety: never demote ourselves (the current leader)
// 6. Safety: never demote ourselves (the current leader)
myRaftAddr := status.Store.Raft.LeaderID
if _, shouldBeVoter := desiredVoters[myRaftAddr]; !shouldBeVoter {
r.logger.Warn("Leader is not in computed voter set — skipping reconciliation",
@ -83,10 +190,19 @@ func (r *RQLiteManager) reconcileVoters() error {
return nil
}
// 6. Find one mismatch to fix (one change per cycle)
// 7. Find one mismatch to fix (one change per cycle)
for _, n := range nodes {
_, shouldBeVoter := desiredVoters[n.ID]
// Check cooldown
reconciler.mu.Lock()
cooldownUntil, hasCooldown := reconciler.cooldowns[n.ID]
if hasCooldown && time.Now().Before(cooldownUntil) {
reconciler.mu.Unlock()
continue
}
reconciler.mu.Unlock()
if n.Voter && !shouldBeVoter {
// Skip if this is the leader
if n.ID == myRaftAddr {
@ -100,6 +216,9 @@ func (r *RQLiteManager) reconcileVoters() error {
r.logger.Warn("Failed to demote voter",
zap.String("node_id", n.ID),
zap.Error(err))
reconciler.mu.Lock()
reconciler.cooldowns[n.ID] = time.Now().Add(voterChangeCooldown)
reconciler.mu.Unlock()
return err
}
@ -112,10 +231,24 @@ func (r *RQLiteManager) reconcileVoters() error {
r.logger.Info("Promoting non-voter to voter",
zap.String("node_id", n.ID))
// Try direct promotion first (POST /join with voter=true)
if err := r.joinClusterNode(n.ID, n.ID, true); err == nil {
r.logger.Info("Successfully promoted non-voter to voter (direct join)",
zap.String("node_id", n.ID))
return nil
}
// Direct join didn't change voter status, fall back to remove+rejoin
r.logger.Info("Direct promotion didn't work, trying remove+rejoin",
zap.String("node_id", n.ID))
if err := r.changeNodeVoterStatus(n.ID, true); err != nil {
r.logger.Warn("Failed to promote non-voter",
zap.String("node_id", n.ID),
zap.Error(err))
reconciler.mu.Lock()
reconciler.cooldowns[n.ID] = time.Now().Add(voterChangeCooldown)
reconciler.mu.Unlock()
return err
}
@ -130,13 +263,12 @@ func (r *RQLiteManager) reconcileVoters() error {
// changeNodeVoterStatus changes a node's voter status by removing it from the
// cluster and immediately re-adding it with the desired voter flag.
// This is necessary because RQLite's /join endpoint ignores voter flag changes
// for nodes that are already cluster members with the same ID and address.
//
// Safety improvements:
// - Pre-check: verify quorum would survive the temporary removal
// - Pre-check: verify target node is still reachable
// - Rollback: if rejoin fails, attempt to re-add with original status
// - Retry: attempt rejoin up to 3 times with backoff
// - Retry: 5 attempts with exponential backoff (2s, 4s, 8s, 15s, 30s)
func (r *RQLiteManager) changeNodeVoterStatus(nodeID string, voter bool) error {
// Pre-check: if demoting a voter, verify quorum safety
if !voter {
@ -145,34 +277,53 @@ func (r *RQLiteManager) changeNodeVoterStatus(nodeID string, voter bool) error {
return fmt.Errorf("quorum pre-check: %w", err)
}
voterCount := 0
targetReachable := false
for _, n := range nodes {
if n.Voter && n.Reachable {
voterCount++
}
if n.ID == nodeID && n.Reachable {
targetReachable = true
}
}
if !targetReachable {
return fmt.Errorf("target node %s is not reachable, skipping voter change", nodeID)
}
// After removing this voter, we need (voterCount-1)/2 + 1 for quorum
// which means voterCount-1 > (voterCount-1)/2, i.e., voterCount >= 3
if voterCount <= 2 {
return fmt.Errorf("cannot remove voter: only %d reachable voters, quorum would be lost", voterCount)
}
}
// Fresh quorum check immediately before removal
nodes, err := r.getAllClusterNodes()
if err != nil {
return fmt.Errorf("fresh quorum check: %w", err)
}
for _, n := range nodes {
if !n.Reachable {
return fmt.Errorf("node %s is unreachable, aborting voter change", n.ID)
}
}
// Step 1: Remove the node from the cluster
if err := r.removeClusterNode(nodeID); err != nil {
return fmt.Errorf("remove node: %w", err)
}
// Wait for Raft to commit the configuration change, then rejoin with retries
// Exponential backoff: 2s, 4s, 8s, 15s, 30s
backoffs := []time.Duration{2 * time.Second, 4 * time.Second, 8 * time.Second, 15 * time.Second, 30 * time.Second}
var lastErr error
for attempt := 0; attempt < 3; attempt++ {
waitTime := time.Duration(2+attempt*2) * time.Second // 2s, 4s, 6s
time.Sleep(waitTime)
for attempt, wait := range backoffs {
time.Sleep(wait)
if err := r.joinClusterNode(nodeID, nodeID, voter); err != nil {
lastErr = err
r.logger.Warn("Rejoin attempt failed, retrying",
zap.String("node_id", nodeID),
zap.Int("attempt", attempt+1),
zap.Int("max_attempts", len(backoffs)),
zap.Error(err))
continue
}
@ -187,12 +338,12 @@ func (r *RQLiteManager) changeNodeVoterStatus(nodeID string, voter bool) error {
originalVoter := !voter
if err := r.joinClusterNode(nodeID, nodeID, originalVoter); err != nil {
r.logger.Error("Rollback also failed — node may be orphaned from cluster",
r.logger.Error("Rollback also failed — node may be orphaned (orphan recovery will re-add it)",
zap.String("node_id", nodeID),
zap.Error(err))
}
return fmt.Errorf("rejoin node after 3 attempts: %w", lastErr)
return fmt.Errorf("rejoin node after %d attempts: %w", len(backoffs), lastErr)
}
// getAllClusterNodes queries /nodes?nonvoters&ver=2 to get all cluster members

View File

@ -10,13 +10,34 @@ import (
)
const (
// watchdogInterval is how often we check if rqlited is alive.
watchdogInterval = 30 * time.Second
// watchdogMaxRestart is the maximum number of restart attempts before giving up.
watchdogMaxRestart = 3
// watchdogGracePeriod is how long to wait after a restart before
// the watchdog starts checking. This gives rqlited time to rejoin
// the Raft cluster — Raft election timeouts + log replay can take
// 60-120 seconds after a restart.
watchdogGracePeriod = 120 * time.Second
)
// startProcessWatchdog monitors the RQLite child process and restarts it if it crashes.
// It checks both process liveness and HTTP responsiveness.
// It only restarts when the process has actually DIED (exited). It does NOT kill
// rqlited for being slow to find a leader — that's normal during cluster rejoin.
func (r *RQLiteManager) startProcessWatchdog(ctx context.Context) {
// Wait for the grace period before starting to monitor.
// rqlited needs time to:
// 1. Open the raft log and snapshots
// 2. Reconnect to existing Raft peers
// 3. Either rejoin as follower or participate in a new election
select {
case <-ctx.Done():
return
case <-time.After(watchdogGracePeriod):
}
ticker := time.NewTicker(watchdogInterval)
defer ticker.Stop()
@ -46,10 +67,21 @@ func (r *RQLiteManager) startProcessWatchdog(ctx context.Context) {
restartCount++
r.logger.Info("RQLite process restarted by watchdog",
zap.Int("restart_count", restartCount))
// Give the restarted process time to stabilize before checking again
select {
case <-ctx.Done():
return
case <-time.After(watchdogGracePeriod):
}
} else {
// Process is alive — check HTTP responsiveness
if !r.isHTTPResponsive() {
r.logger.Warn("RQLite process is alive but not responding to HTTP")
// Process is alive — reset restart counter on sustained health
if r.isHTTPResponsive() {
if restartCount > 0 {
r.logger.Info("RQLite process has stabilized, resetting restart counter",
zap.Int("previous_restart_count", restartCount))
restartCount = 0
}
}
}
}

208
scripts/clean-testnet.sh Executable file
View File

@ -0,0 +1,208 @@
#!/usr/bin/env bash
#
# Clean all testnet nodes for fresh reinstall.
# Preserves Anyone relay keys (/var/lib/anon/) for --anyone-migrate.
# DOES NOT TOUCH DEVNET NODES.
#
# Usage: scripts/clean-testnet.sh [--nuclear]
#   --nuclear   Also remove shared binaries (rqlited, ipfs, coredns, caddy, etc.)
#
# Reads node credentials from scripts/remote-nodes.conf (pipe-separated:
# env|user@host|password|role|key) and runs a cleanup script over SSH on
# every node whose env field is exactly "testnet". Requires sshpass locally.
set -euo pipefail
ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
CONF="$ROOT_DIR/scripts/remote-nodes.conf"
# Fail fast if prerequisites are missing.
[[ -f "$CONF" ]] || { echo "ERROR: Missing $CONF"; exit 1; }
command -v sshpass >/dev/null 2>&1 || { echo "ERROR: sshpass not installed (brew install sshpass / apt install sshpass)"; exit 1; }
NUCLEAR=false
[[ "${1:-}" == "--nuclear" ]] && NUCLEAR=true
# PubkeyAuthentication=no forces password auth so sshpass is actually used.
SSH_OPTS=(-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -o ConnectTimeout=10 -o LogLevel=ERROR -o PubkeyAuthentication=no)
# ── Cleanup script (runs as root on each remote node) ─────────────────────
# Uses a quoted heredoc so NO local variable expansion happens.
# This script is uploaded to /tmp/orama-clean.sh and executed remotely.
CLEANUP_SCRIPT=$(cat <<'SCRIPT_END'
#!/bin/bash
set -e
export DEBIAN_FRONTEND=noninteractive
echo "  Stopping services..."
systemctl stop debros-node debros-gateway debros-ipfs debros-ipfs-cluster debros-olric debros-anyone-relay debros-anyone-client coredns caddy 2>/dev/null || true
systemctl disable debros-node debros-gateway debros-ipfs debros-ipfs-cluster debros-olric debros-anyone-relay debros-anyone-client coredns caddy 2>/dev/null || true
echo "  Removing systemd service files..."
rm -f /etc/systemd/system/debros-*.service
rm -f /etc/systemd/system/coredns.service
rm -f /etc/systemd/system/caddy.service
rm -f /etc/systemd/system/orama-deploy-*.service
systemctl daemon-reload
echo "  Tearing down WireGuard..."
systemctl stop wg-quick@wg0 2>/dev/null || true
wg-quick down wg0 2>/dev/null || true
systemctl disable wg-quick@wg0 2>/dev/null || true
rm -f /etc/wireguard/wg0.conf
echo "  Resetting UFW firewall..."
ufw --force reset
ufw allow 22/tcp
ufw --force enable
echo "  Killing debros processes..."
pkill -u debros 2>/dev/null || true
sleep 1
echo "  Removing debros user and data..."
userdel -r debros 2>/dev/null || true
rm -rf /home/debros
echo "  Removing sudoers files..."
rm -f /etc/sudoers.d/debros-access
rm -f /etc/sudoers.d/debros-deployments
rm -f /etc/sudoers.d/debros-wireguard
echo "  Removing CoreDNS and Caddy configs..."
rm -rf /etc/coredns
rm -rf /etc/caddy
rm -rf /var/lib/caddy
echo "  Cleaning temp files..."
rm -f /tmp/orama /tmp/network-source.tar.gz /tmp/network-source.zip
rm -rf /tmp/network-extract /tmp/coredns-build /tmp/caddy-build
# Nuclear: also remove shared binaries
if [ "${1:-}" = "--nuclear" ]; then
  echo "  Removing shared binaries (nuclear)..."
  rm -f /usr/local/bin/rqlited
  rm -f /usr/local/bin/ipfs
  rm -f /usr/local/bin/ipfs-cluster-service
  rm -f /usr/local/bin/olric-server
  rm -f /usr/local/bin/coredns
  rm -f /usr/local/bin/xcaddy
  rm -f /usr/bin/caddy
  rm -f /usr/local/bin/orama
fi
# Verify Anyone relay keys are preserved
if [ -d /var/lib/anon/keys ]; then
  echo "  Anyone relay keys PRESERVED at /var/lib/anon/keys"
  if [ -f /var/lib/anon/fingerprint ]; then
    fp=$(cat /var/lib/anon/fingerprint 2>/dev/null || true)
    echo "  Relay fingerprint: $fp"
  fi
  if [ -f /var/lib/anon/wallet ]; then
    wallet=$(cat /var/lib/anon/wallet 2>/dev/null || true)
    echo "  Relay wallet: $wallet"
  fi
else
  echo "  WARNING: No Anyone relay keys found at /var/lib/anon/"
fi
echo "  DONE"
SCRIPT_END
)
# ── Parse testnet nodes only ──────────────────────────────────────────────
# Conf format per line: env|user@host|password|role|key. Lines starting with
# '#' and blank lines are skipped; trailing '#' comments on env are stripped.
hosts=()
passes=()
users=()
while IFS='|' read -r env hostspec pass role key; do
  [[ -z "$env" || "$env" == \#* ]] && continue
  env="${env%%#*}"
  # xargs trims surrounding whitespace from the env field.
  env="$(echo "$env" | xargs)"
  [[ "$env" != "testnet" ]] && continue
  hosts+=("$hostspec")
  passes+=("$pass")
  # Username portion of user@host; used later to decide whether sudo is needed.
  users+=("${hostspec%%@*}")
done < "$CONF"
if [[ ${#hosts[@]} -eq 0 ]]; then
  echo "ERROR: No testnet nodes found in $CONF"
  exit 1
fi
echo "== clean-testnet.sh — ${#hosts[@]} testnet nodes =="
for i in "${!hosts[@]}"; do
  echo "  [$((i+1))] ${hosts[$i]}"
done
echo ""
echo "This will CLEAN all testnet nodes (stop services, remove data)."
echo "Anyone relay keys (/var/lib/anon/) will be PRESERVED."
echo "Devnet nodes will NOT be touched."
$NUCLEAR && echo "Nuclear mode: shared binaries will also be removed."
echo ""
# Interactive confirmation gate — only the literal word 'yes' proceeds.
read -rp "Type 'yes' to continue: " confirm
if [[ "$confirm" != "yes" ]]; then
  echo "Aborted."
  exit 0
fi
# ── Execute cleanup on each node ──────────────────────────────────────────
failed=()
succeeded=0
NUCLEAR_FLAG=""
$NUCLEAR && NUCLEAR_FLAG="--nuclear"
for i in "${!hosts[@]}"; do
  h="${hosts[$i]}"
  p="${passes[$i]}"
  u="${users[$i]}"
  echo ""
  echo "== [$((i+1))/${#hosts[@]}] Cleaning $h =="
  # Step 1: Upload cleanup script
  # No -n flag here — we're piping the script content via stdin
  if ! echo "$CLEANUP_SCRIPT" | sshpass -p "$p" ssh "${SSH_OPTS[@]}" "$h" \
      "cat > /tmp/orama-clean.sh && chmod +x /tmp/orama-clean.sh" 2>&1; then
    echo "  !! FAILED to upload script to $h"
    failed+=("$h")
    continue
  fi
  # Step 2: Execute the cleanup script as root
  if [[ "$u" == "root" ]]; then
    # Root: run directly
    if ! sshpass -p "$p" ssh -n "${SSH_OPTS[@]}" "$h" \
        "bash /tmp/orama-clean.sh $NUCLEAR_FLAG; rm -f /tmp/orama-clean.sh" 2>&1; then
      echo "  !! FAILED: $h"
      failed+=("$h")
      continue
    fi
  else
    # Non-root: escape password for single-quote embedding, pipe to sudo -S
    escaped_p=$(printf '%s' "$p" | sed "s/'/'\\\\''/g")
    if ! sshpass -p "$p" ssh -n "${SSH_OPTS[@]}" "$h" \
        "printf '%s\n' '${escaped_p}' | sudo -S bash /tmp/orama-clean.sh $NUCLEAR_FLAG; rm -f /tmp/orama-clean.sh" 2>&1; then
      echo "  !! FAILED: $h"
      failed+=("$h")
      continue
    fi
  fi
  echo "  OK: $h cleaned"
  # '|| true' guards set -e: ((...)) returns non-zero when the result is 0.
  ((succeeded++)) || true
done
echo ""
echo "========================================"
echo "Cleanup complete: $succeeded succeeded, ${#failed[@]} failed"
if [[ ${#failed[@]} -gt 0 ]]; then
  echo ""
  echo "Failed nodes:"
  for f in "${failed[@]}"; do
    echo "  - $f"
  done
  echo ""
  echo "Troubleshooting:"
  echo "  1. Check connectivity: ssh <user>@<host>"
  echo "  2. Check password in remote-nodes.conf"
  echo "  3. Try cleaning manually: docs/CLEAN_NODE.md"
fi
echo ""
echo "Anyone relay keys preserved at /var/lib/anon/ on all nodes."
echo "Use --anyone-migrate during install to reuse existing relay identity."
echo "========================================"
289
scripts/recover-rqlite.sh Normal file
View File

@ -0,0 +1,289 @@
#!/usr/bin/env bash
#
# Recover RQLite cluster from split-brain.
#
# Strategy:
# 1. Stop debros-node on ALL nodes simultaneously
# 2. Keep raft/ data ONLY on the node with the highest commit index (leader candidate)
# 3. Delete raft/ on all other nodes (they'll join fresh via -join)
# 4. Start the leader candidate first, wait for it to become Leader
# 5. Start all other nodes — they discover the leader via LibP2P and join
# 6. Verify cluster health
#
# Usage:
# scripts/recover-rqlite.sh --devnet --leader 57.129.7.232
# scripts/recover-rqlite.sh --testnet --leader <ip>
#
set -euo pipefail
# ── Parse flags ──────────────────────────────────────────────────────────────
# Recognized flags: --devnet / --testnet (environment selector),
# --leader=<host> (node whose raft data is kept), -h/--help.
ENV=""
LEADER_HOST=""
for opt; do
case "$opt" in
--devnet) ENV="devnet" ;;
--testnet) ENV="testnet" ;;
--leader=*) LEADER_HOST="${opt#--leader=}" ;;
-h|--help)
echo "Usage: scripts/recover-rqlite.sh --devnet|--testnet --leader=<public_ip_or_user@host>"
exit 0
;;
*)
echo "Unknown flag: $opt" >&2
exit 1
;;
esac
done
# Both the environment and the leader candidate are mandatory.
if [[ -z "$ENV" ]]; then
echo "ERROR: specify --devnet or --testnet" >&2
exit 1
fi
if [[ -z "$LEADER_HOST" ]]; then
echo "ERROR: specify --leader=<host> (the node with highest commit index)" >&2
exit 1
fi
# ── Paths ────────────────────────────────────────────────────────────────────
# ROOT_DIR resolves to the repository root (parent of scripts/).
ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
CONF="$ROOT_DIR/scripts/remote-nodes.conf"
# die MSG... — print an error to stderr and abort the script.
die() { echo "ERROR: $*" >&2; exit 1; }
[[ -f "$CONF" ]] || die "Missing $CONF"
# ── Load nodes from conf ────────────────────────────────────────────────────
# Conf lines are pipe-separated: env|host|pass|role|key
# Only rows whose env column matches the requested $ENV are kept.
HOSTS=()
PASSES=()
ROLES=()
SSH_KEYS=()
while IFS='|' read -r env host pass role key; do
# Skip blank lines and lines starting with '#'.
[[ -z "$env" || "$env" == \#* ]] && continue
# Strip any trailing inline comment, then trim whitespace (xargs).
env="${env%%#*}"
env="$(echo "$env" | xargs)"
[[ "$env" != "$ENV" ]] && continue
HOSTS+=("$host")
PASSES+=("$pass")
# role defaults to "node", key to empty (password auth via sshpass).
ROLES+=("${role:-node}")
SSH_KEYS+=("${key:-}")
done < "$CONF"
if [[ ${#HOSTS[@]} -eq 0 ]]; then
die "No nodes found for environment '$ENV' in $CONF"
fi
echo "== recover-rqlite.sh ($ENV) — ${#HOSTS[@]} nodes =="
echo "Leader candidate: $LEADER_HOST"
echo ""
# Find leader index
# Substring match (*"$LEADER_HOST"*) so a bare IP matches a user@ip entry.
LEADER_IDX=-1
for i in "${!HOSTS[@]}"; do
if [[ "${HOSTS[$i]}" == *"$LEADER_HOST"* ]]; then
LEADER_IDX=$i
break
fi
done
if [[ $LEADER_IDX -eq -1 ]]; then
die "Leader host '$LEADER_HOST' not found in node list"
fi
# Print the node roster, marking the node whose raft data will be kept.
echo "Nodes:"
for i in "${!HOSTS[@]}"; do
marker=""
[[ $i -eq $LEADER_IDX ]] && marker=" ← LEADER (keep data)"
echo " [$i] ${HOSTS[$i]} (${ROLES[$i]})$marker"
done
echo ""
# ── SSH helpers ──────────────────────────────────────────────────────────────
# Non-interactive, host-key-agnostic options shared by every ssh invocation.
SSH_OPTS=(-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -o ConnectTimeout=10)
# node_ssh IDX CMD... — run CMD on node IDX.
# Prefers key-based auth when a key path is configured and the file exists;
# otherwise falls back to password auth via sshpass. Remote stderr is
# discarded so callers can capture clean stdout via command substitution.
node_ssh() {
local idx="$1"
shift
local h="${HOSTS[$idx]}"
local p="${PASSES[$idx]}"
local k="${SSH_KEYS[$idx]:-}"
if [[ -n "$k" ]]; then
# Expand a leading ~ in the configured key path to $HOME.
local expanded_key="${k/#\~/$HOME}"
if [[ -f "$expanded_key" ]]; then
# -n detaches stdin so the remote command cannot consume this
# script's input — consistent with the sshpass path below.
ssh -n -i "$expanded_key" "${SSH_OPTS[@]}" "$h" "$@" 2>/dev/null
return $?
fi
fi
sshpass -p "$p" ssh -n "${SSH_OPTS[@]}" "$h" "$@" 2>/dev/null
}
# ── Confirmation ─────────────────────────────────────────────────────────────
# Interactive safety prompt: this procedure is destructive (deletes raft
# state everywhere except the chosen leader), so require an explicit y/Y.
echo "⚠️ THIS WILL:"
echo " 1. Stop debros-node on ALL ${#HOSTS[@]} nodes"
echo " 2. DELETE raft/ data on ${#HOSTS[@]}-1 nodes (backup to /tmp/rqlite-raft-backup/)"
echo " 3. Keep raft/ data ONLY on ${HOSTS[$LEADER_IDX]} (leader candidate)"
echo " 4. Restart all nodes to reform the cluster"
echo ""
read -r -p "Continue? [y/N] " confirm
if [[ "$confirm" != "y" && "$confirm" != "Y" ]]; then
echo "Aborted."
exit 0
fi
echo ""
# Remote paths: raft state to clear, and where a backup copy is placed.
RAFT_DIR="/home/debros/.orama/data/rqlite/raft"
BACKUP_DIR="/tmp/rqlite-raft-backup"
# ── Phase 1: Stop debros-node on ALL nodes ───────────────────────────────────
# Stop the service everywhere first so no node writes raft state while
# Phase 2 clears it. sudo -S reads the password from stdin (fed by printf).
echo "== Phase 1: Stopping debros-node on all ${#HOSTS[@]} nodes =="
failed=()
for i in "${!HOSTS[@]}"; do
h="${HOSTS[$i]}"
p="${PASSES[$i]}"
echo -n " Stopping $h ... "
if node_ssh "$i" "printf '%s\n' '$p' | sudo -S systemctl stop debros-node 2>&1 && echo STOPPED"; then
echo ""
else
echo "FAILED"
failed+=("$h")
fi
done
# Fallback: for any node that refused a graceful stop, hard-kill the
# processes directly (best-effort — failures are ignored via || true).
if [[ ${#failed[@]} -gt 0 ]]; then
echo ""
echo "⚠️ ${#failed[@]} nodes failed to stop. Attempting kill..."
for i in "${!HOSTS[@]}"; do
h="${HOSTS[$i]}"
p="${PASSES[$i]}"
for fh in "${failed[@]}"; do
if [[ "$h" == "$fh" ]]; then
node_ssh "$i" "printf '%s\n' '$p' | sudo -S killall -9 orama-node rqlited 2>/dev/null; echo KILLED" || true
fi
done
done
fi
echo ""
echo "Waiting 5s for processes to fully stop..."
sleep 5
# ── Phase 2: Backup and delete raft/ on non-leader nodes ────────────────────
# Every node except the leader candidate gets its raft/ directory copied to
# $BACKUP_DIR and then removed, so those nodes rejoin the cluster fresh.
# Note: $RAFT_DIR/$BACKUP_DIR expand locally (outer double quotes) before
# the command string is sent over ssh.
echo "== Phase 2: Clearing raft state on non-leader nodes =="
for i in "${!HOSTS[@]}"; do
[[ $i -eq $LEADER_IDX ]] && continue
h="${HOSTS[$i]}"
p="${PASSES[$i]}"
echo -n " Clearing $h ... "
if node_ssh "$i" "
printf '%s\n' '$p' | sudo -S bash -c '
rm -rf $BACKUP_DIR
if [ -d $RAFT_DIR ]; then
cp -r $RAFT_DIR $BACKUP_DIR 2>/dev/null || true
rm -rf $RAFT_DIR
echo \"CLEARED (backup at $BACKUP_DIR)\"
else
echo \"NO_RAFT_DIR (nothing to clear)\"
fi
'
"; then
true
else
echo "FAILED"
fi
done
echo ""
echo "Leader node ${HOSTS[$LEADER_IDX]} raft/ data preserved."
# ── Phase 3: Start leader node ──────────────────────────────────────────────
# Start the leader candidate first and poll its rqlite /status endpoint
# until raft reports state "Leader" (up to max_wait seconds, 5s interval).
echo ""
echo "== Phase 3: Starting leader node (${HOSTS[$LEADER_IDX]}) =="
lp="${PASSES[$LEADER_IDX]}"
node_ssh "$LEADER_IDX" "printf '%s\n' '$lp' | sudo -S systemctl start debros-node" || die "Failed to start leader node"
echo " Waiting for leader to become Leader..."
max_wait=120
elapsed=0
while [[ $elapsed -lt $max_wait ]]; do
# "|| echo ''" keeps state defined (and set -e happy) if ssh/curl fails.
state=$(node_ssh "$LEADER_IDX" "curl -s --max-time 3 http://localhost:5001/status 2>/dev/null | python3 -c \"import sys,json; d=json.load(sys.stdin); print(d.get('store',{}).get('raft',{}).get('state',''))\" 2>/dev/null" || echo "")
if [[ "$state" == "Leader" ]]; then
echo " ✓ Leader node is Leader after ${elapsed}s"
break
fi
echo " ... state=$state (${elapsed}s / ${max_wait}s)"
sleep 5
((elapsed+=5))
done
# Non-fatal timeout: proceed anyway, the node may still converge later.
if [[ "$state" != "Leader" ]]; then
echo " ⚠️ Leader did not become Leader within ${max_wait}s (state=$state)"
echo " The node may need more time. Continuing anyway..."
fi
# ── Phase 4: Start all other nodes ──────────────────────────────────────────
# Start the remaining nodes in small batches so the newly-elected leader is
# not flooded with simultaneous join requests.
echo ""
echo "== Phase 4: Starting remaining nodes =="
# Start non-leader nodes in batches of 3 with 15s between batches
batch_size=3
batch_count=0
for i in "${!HOSTS[@]}"; do
[[ $i -eq $LEADER_IDX ]] && continue
h="${HOSTS[$i]}"
p="${PASSES[$i]}"
echo -n " Starting $h ... "
if node_ssh "$i" "printf '%s\n' '$p' | sudo -S systemctl start debros-node && echo STARTED"; then
true
else
echo "FAILED"
fi
# "|| true" guards set -e: ((batch_count++)) evaluates to the
# pre-increment value, which is 0 on the first iteration — exit status 1
# — and would otherwise abort the whole script here.
((batch_count++)) || true
if [[ $((batch_count % batch_size)) -eq 0 ]]; then
echo " (waiting 15s between batches for cluster stability)"
sleep 15
fi
done
# ── Phase 5: Wait and verify ────────────────────────────────────────────────
# Give the cluster two minutes to settle, printing progress every 30s.
echo ""
echo "== Phase 5: Waiting for cluster to form (120s) =="
sleep 30
echo " ... 30s"
sleep 30
echo " ... 60s"
sleep 30
echo " ... 90s"
sleep 30
echo " ... 120s"
echo ""
# Query each node's rqlite /status and summarize raft state, commit index,
# current leader id, and node count. Failures degrade to NO_RESPONSE
# (endpoint unreachable) or SSH_FAILED (ssh itself failed).
echo "== Cluster status =="
for i in "${!HOSTS[@]}"; do
h="${HOSTS[$i]}"
result=$(node_ssh "$i" "curl -s --max-time 5 http://localhost:5001/status 2>/dev/null | python3 -c \"
import sys,json
try:
d=json.load(sys.stdin)
r=d.get('store',{}).get('raft',{})
n=d.get('store',{}).get('num_nodes','?')
print(f'state={r.get(\"state\",\"?\")} commit={r.get(\"commit_index\",\"?\")} leader={r.get(\"leader\",{}).get(\"node_id\",\"?\")} nodes={n}')
except:
print('NO_RESPONSE')
\" 2>/dev/null" || echo "SSH_FAILED")
marker=""
[[ $i -eq $LEADER_IDX ]] && marker=" ← LEADER"
echo " ${HOSTS[$i]}: $result$marker"
done
echo ""
echo "== Recovery complete =="
echo ""
echo "Next steps:"
echo " 1. Run 'scripts/inspect.sh --devnet' to verify full cluster health"
echo " 2. If some nodes show Candidate state, give them more time (up to 5 min)"
echo " 3. If nodes fail to join, check /home/debros/.orama/logs/rqlite-node.log on the node"

View File

@ -3,18 +3,19 @@
# Redeploy to all nodes in a given environment (devnet or testnet).
# Reads node credentials from scripts/remote-nodes.conf.
#
# Flow (per docs/DEV_DEPLOY.md):
# 1) make build-linux
# Flow:
# 1) make build-linux-all
# 2) scripts/generate-source-archive.sh -> /tmp/network-source.tar.gz
# 3) scp archive + extract-deploy.sh + conf to hub node
# 4) from hub: sshpass scp to all other nodes + sudo bash /tmp/extract-deploy.sh
# 5) rolling: upgrade followers one-by-one, leader last
# 5) rolling upgrade: followers first, leader last
# per node: pre-upgrade -> stop -> extract binary -> post-upgrade
#
# Usage:
# scripts/redeploy.sh --devnet
# scripts/redeploy.sh --testnet
# scripts/redeploy.sh --devnet --no-build
# scripts/redeploy.sh --testnet --no-build
# scripts/redeploy.sh --devnet --skip-build
#
set -euo pipefail
@ -26,14 +27,14 @@ for arg in "$@"; do
case "$arg" in
--devnet) ENV="devnet" ;;
--testnet) ENV="testnet" ;;
--no-build) NO_BUILD=1 ;;
--no-build|--skip-build) NO_BUILD=1 ;;
-h|--help)
echo "Usage: scripts/redeploy.sh --devnet|--testnet [--no-build]"
echo "Usage: scripts/redeploy.sh --devnet|--testnet [--no-build|--skip-build]"
exit 0
;;
*)
echo "Unknown flag: $arg" >&2
echo "Usage: scripts/redeploy.sh --devnet|--testnet [--no-build]" >&2
echo "Usage: scripts/redeploy.sh --devnet|--testnet [--no-build|--skip-build]" >&2
exit 1
;;
esac
@ -106,9 +107,9 @@ echo "Hub: $HUB_HOST (idx=$HUB_IDX, key=${HUB_KEY:-none})"
# ── Build ────────────────────────────────────────────────────────────────────
if [[ "$NO_BUILD" -eq 0 ]]; then
echo "== build-linux =="
(cd "$ROOT_DIR" && make build-linux) || {
echo "WARN: make build-linux failed; continuing if existing bin-linux is acceptable."
echo "== build-linux-all =="
(cd "$ROOT_DIR" && make build-linux-all) || {
echo "WARN: make build-linux-all failed; continuing if existing bin-linux is acceptable."
}
else
echo "== skipping build (--no-build) =="
@ -192,12 +193,16 @@ if [[ ${#hosts[@]} -gt 0 ]] && ! command -v sshpass >/dev/null 2>&1; then
fi
echo "== fan-out: upload to ${#hosts[@]} nodes =="
upload_failed=()
for i in "${!hosts[@]}"; do
h="${hosts[$i]}"
p="${passes[$i]}"
echo " -> $h"
sshpass -p "$p" scp -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \
"$TAR" "$EX" "$h":/tmp/
if ! sshpass -p "$p" scp -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \
"$TAR" "$EX" "$h":/tmp/; then
echo " !! UPLOAD FAILED: $h"
upload_failed+=("$h")
fi
done
echo "== extract on all fan-out nodes =="
@ -205,10 +210,22 @@ for i in "${!hosts[@]}"; do
h="${hosts[$i]}"
p="${passes[$i]}"
echo " -> $h"
sshpass -p "$p" ssh -n -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \
"$h" "printf '%s\n' '$p' | sudo -S bash /tmp/extract-deploy.sh >/tmp/extract.log 2>&1 && echo OK"
if ! sshpass -p "$p" ssh -n -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \
"$h" "printf '%s\n' '$p' | sudo -S bash /tmp/extract-deploy.sh >/tmp/extract.log 2>&1 && echo OK"; then
echo " !! EXTRACT FAILED: $h"
upload_failed+=("$h")
fi
done
if [[ ${#upload_failed[@]} -gt 0 ]]; then
echo ""
echo "WARNING: ${#upload_failed[@]} nodes had upload/extract failures:"
for uf in "${upload_failed[@]}"; do
echo " - $uf"
done
echo "Continuing with rolling restart..."
fi
echo "== extract on hub =="
printf '%s\n' "$hub_pass" | sudo -S bash "$EX" >/tmp/extract.log 2>&1
@ -253,44 +270,131 @@ if [[ -z "$leader" ]]; then
fi
echo "Leader: $leader"
# Accumulates "host" entries for every node whose upgrade failed.
failed_nodes=()
# ── Per-node upgrade flow ──
# Uses pre-upgrade (maintenance + leadership transfer + propagation wait)
# then stops, deploys binary, and post-upgrade (start + health verification).
# upgrade_one HOST PASS — run the 4-step upgrade on a remote node over ssh.
upgrade_one() {
local h="$1" p="$2"
echo "== upgrade $h =="
# NOTE(review): the two lines below look like the pre-refactor one-shot
# upgrade command shown by this diff view next to its 4-step replacement —
# confirm only the 4-step flow exists in the actual script.
sshpass -p "$p" ssh -n -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \
"$h" "printf '%s\n' '$p' | sudo -S orama prod stop && printf '%s\n' '$p' | sudo -S orama upgrade --no-pull --pre-built --restart"
# 1. Pre-upgrade: enter maintenance, transfer leadership, wait for propagation
# Non-fatal: a failed pre-upgrade still proceeds to the stop step.
echo " [1/4] pre-upgrade..."
if ! sshpass -p "$p" ssh -n -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \
"$h" "printf '%s\n' '$p' | sudo -S orama prod pre-upgrade" 2>&1; then
echo " !! pre-upgrade failed on $h (continuing with stop)"
fi
# 2. Stop all services
echo " [2/4] stopping services..."
if ! sshpass -p "$p" ssh -n -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \
"$h" "printf '%s\n' '$p' | sudo -S systemctl stop 'debros-*'" 2>&1; then
echo " !! stop failed on $h"
failed_nodes+=("$h")
return 1
fi
# 3. Deploy new binary
echo " [3/4] deploying binary..."
if ! sshpass -p "$p" ssh -n -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \
"$h" "printf '%s\n' '$p' | sudo -S bash /tmp/extract-deploy.sh >/tmp/extract.log 2>&1 && echo OK" 2>&1; then
echo " !! extract failed on $h"
failed_nodes+=("$h")
return 1
fi
# 4. Post-upgrade: start services, verify health, exit maintenance
echo " [4/4] post-upgrade..."
if ! sshpass -p "$p" ssh -n -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \
"$h" "printf '%s\n' '$p' | sudo -S orama prod post-upgrade" 2>&1; then
echo " !! post-upgrade failed on $h"
failed_nodes+=("$h")
return 1
fi
echo " OK: $h"
}
# upgrade_hub — same 4-step upgrade flow as upgrade_one, but executed
# locally on the hub (sudo with $hub_pass instead of ssh to a remote host).
upgrade_hub() {
echo "== upgrade hub (localhost) =="
# NOTE(review): the two lines below look like the pre-refactor one-shot
# upgrade shown by this diff view next to its 4-step replacement —
# confirm only the 4-step flow exists in the actual script.
printf '%s\n' "$hub_pass" | sudo -S orama prod stop
printf '%s\n' "$hub_pass" | sudo -S orama upgrade --no-pull --pre-built --restart
# 1. Pre-upgrade
# Non-fatal: a failed pre-upgrade still proceeds to the stop step.
echo " [1/4] pre-upgrade..."
if ! (printf '%s\n' "$hub_pass" | sudo -S orama prod pre-upgrade) 2>&1; then
echo " !! pre-upgrade failed on hub (continuing with stop)"
fi
# 2. Stop all services
echo " [2/4] stopping services..."
if ! (printf '%s\n' "$hub_pass" | sudo -S systemctl stop 'debros-*') 2>&1; then
echo " !! stop failed on hub ($hub_host)"
failed_nodes+=("$hub_host (hub)")
return 1
fi
# 3. Deploy new binary
echo " [3/4] deploying binary..."
if ! (printf '%s\n' "$hub_pass" | sudo -S bash "$EX" >/tmp/extract.log 2>&1); then
echo " !! extract failed on hub ($hub_host)"
failed_nodes+=("$hub_host (hub)")
return 1
fi
# 4. Post-upgrade
echo " [4/4] post-upgrade..."
if ! (printf '%s\n' "$hub_pass" | sudo -S orama prod post-upgrade) 2>&1; then
echo " !! post-upgrade failed on hub ($hub_host)"
failed_nodes+=("$hub_host (hub)")
return 1
fi
echo " OK: hub ($hub_host)"
}
# Rolling upgrade order: followers first, hub (if follower), leader last.
# NOTE(review): this span is a rendered diff — several adjacent duplicate
# lines below (e.g. the two banner echoes, and each bare call next to its
# "|| true" twin) appear to be the deleted old line shown beside its
# replacement; confirm the actual script keeps only the "|| true" variants.
echo "== rolling upgrade (followers first) =="
echo "== rolling upgrade (followers first, leader last) =="
for i in "${!hosts[@]}"; do
h="${hosts[$i]}"
p="${passes[$i]}"
[[ "$h" == "$leader" ]] && continue
upgrade_one "$h" "$p"
# "|| true" keeps the rolling upgrade going past a single failed node.
upgrade_one "$h" "$p" || true
done
# Upgrade hub if not the leader
if [[ "$leader" != "HUB" ]]; then
upgrade_hub
upgrade_hub || true
fi
# Upgrade leader last
echo "== upgrade leader last =="
if [[ "$leader" == "HUB" ]]; then
upgrade_hub
upgrade_hub || true
else
upgrade_one "$leader" "$leader_pass"
upgrade_one "$leader" "$leader_pass" || true
fi
# Clean up conf from hub
rm -f "$CONF"
echo "Done."
# ── Report results ──
# Summarize failures collected in failed_nodes; nonzero exit on any failure.
echo ""
echo "========================================"
if [[ ${#failed_nodes[@]} -gt 0 ]]; then
echo "UPGRADE COMPLETED WITH FAILURES (${#failed_nodes[@]} nodes failed):"
for fn in "${failed_nodes[@]}"; do
echo " FAILED: $fn"
done
echo ""
echo "Recommended actions:"
echo " 1. SSH into the failed node(s)"
echo " 2. Check logs: sudo orama prod logs node --follow"
echo " 3. Manually run: sudo orama prod post-upgrade"
echo "========================================"
exit 1
else
echo "All nodes upgraded successfully."
echo "========================================"
fi
REMOTE
echo "== complete =="