mirror of
https://github.com/DeBrosOfficial/orama.git
synced 2026-03-17 08:36:57 +00:00
Bro i did so many things to fix the problematic discovery and redeployment and i dont even remember what i did
This commit is contained in:
parent
afbb7d4ede
commit
749d5ed5e7
@ -43,6 +43,10 @@ func HandleCommand(args []string) {
|
|||||||
case "restart":
|
case "restart":
|
||||||
force := hasFlag(subargs, "--force")
|
force := hasFlag(subargs, "--force")
|
||||||
lifecycle.HandleRestartWithFlags(force)
|
lifecycle.HandleRestartWithFlags(force)
|
||||||
|
case "pre-upgrade":
|
||||||
|
lifecycle.HandlePreUpgrade()
|
||||||
|
case "post-upgrade":
|
||||||
|
lifecycle.HandlePostUpgrade()
|
||||||
case "logs":
|
case "logs":
|
||||||
logs.Handle(subargs)
|
logs.Handle(subargs)
|
||||||
case "uninstall":
|
case "uninstall":
|
||||||
@ -105,6 +109,10 @@ func ShowHelp() {
|
|||||||
fmt.Printf(" restart - Restart all production services (requires root/sudo)\n")
|
fmt.Printf(" restart - Restart all production services (requires root/sudo)\n")
|
||||||
fmt.Printf(" Options:\n")
|
fmt.Printf(" Options:\n")
|
||||||
fmt.Printf(" --force - Bypass quorum safety check\n")
|
fmt.Printf(" --force - Bypass quorum safety check\n")
|
||||||
|
fmt.Printf(" pre-upgrade - Prepare node for safe restart (requires root/sudo)\n")
|
||||||
|
fmt.Printf(" Transfers leadership, enters maintenance mode, waits for propagation\n")
|
||||||
|
fmt.Printf(" post-upgrade - Bring node back online after restart (requires root/sudo)\n")
|
||||||
|
fmt.Printf(" Starts services, verifies RQLite health, exits maintenance\n")
|
||||||
fmt.Printf(" logs <service> - View production service logs\n")
|
fmt.Printf(" logs <service> - View production service logs\n")
|
||||||
fmt.Printf(" Service aliases: node, ipfs, cluster, gateway, olric\n")
|
fmt.Printf(" Service aliases: node, ipfs, cluster, gateway, olric\n")
|
||||||
fmt.Printf(" Options:\n")
|
fmt.Printf(" Options:\n")
|
||||||
|
|||||||
@ -19,6 +19,7 @@ import (
|
|||||||
|
|
||||||
libp2ppubsub "github.com/libp2p/go-libp2p-pubsub"
|
libp2ppubsub "github.com/libp2p/go-libp2p-pubsub"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/encryption"
|
||||||
"github.com/DeBrosOfficial/network/pkg/pubsub"
|
"github.com/DeBrosOfficial/network/pkg/pubsub"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -144,6 +145,30 @@ func (c *Client) Connect() error {
|
|||||||
libp2p.DefaultMuxers,
|
libp2p.DefaultMuxers,
|
||||||
)
|
)
|
||||||
opts = append(opts, libp2p.Transport(tcp.NewTCPTransport))
|
opts = append(opts, libp2p.Transport(tcp.NewTCPTransport))
|
||||||
|
|
||||||
|
// Load or create persistent identity if IdentityPath is configured
|
||||||
|
if c.config.IdentityPath != "" {
|
||||||
|
identity, loadErr := encryption.LoadIdentity(c.config.IdentityPath)
|
||||||
|
if loadErr != nil {
|
||||||
|
// File doesn't exist yet — generate and save
|
||||||
|
identity, loadErr = encryption.GenerateIdentity()
|
||||||
|
if loadErr != nil {
|
||||||
|
return fmt.Errorf("failed to generate identity: %w", loadErr)
|
||||||
|
}
|
||||||
|
if saveErr := encryption.SaveIdentity(identity, c.config.IdentityPath); saveErr != nil {
|
||||||
|
return fmt.Errorf("failed to save identity: %w", saveErr)
|
||||||
|
}
|
||||||
|
c.logger.Info("Generated new persistent identity",
|
||||||
|
zap.String("peer_id", identity.PeerID.String()),
|
||||||
|
zap.String("path", c.config.IdentityPath))
|
||||||
|
} else {
|
||||||
|
c.logger.Info("Loaded persistent identity",
|
||||||
|
zap.String("peer_id", identity.PeerID.String()),
|
||||||
|
zap.String("path", c.config.IdentityPath))
|
||||||
|
}
|
||||||
|
opts = append(opts, libp2p.Identity(identity.PrivateKey))
|
||||||
|
}
|
||||||
|
|
||||||
// Enable QUIC only when not proxying. When proxy is enabled, prefer TCP via SOCKS5.
|
// Enable QUIC only when not proxying. When proxy is enabled, prefer TCP via SOCKS5.
|
||||||
h, err := libp2p.New(opts...)
|
h, err := libp2p.New(opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@ -11,13 +11,14 @@ type ClientConfig struct {
|
|||||||
DatabaseName string `json:"database_name"`
|
DatabaseName string `json:"database_name"`
|
||||||
BootstrapPeers []string `json:"peers"`
|
BootstrapPeers []string `json:"peers"`
|
||||||
DatabaseEndpoints []string `json:"database_endpoints"`
|
DatabaseEndpoints []string `json:"database_endpoints"`
|
||||||
GatewayURL string `json:"gateway_url"` // Gateway URL for HTTP API access
|
GatewayURL string `json:"gateway_url"` // Gateway URL for HTTP API access
|
||||||
ConnectTimeout time.Duration `json:"connect_timeout"`
|
ConnectTimeout time.Duration `json:"connect_timeout"`
|
||||||
RetryAttempts int `json:"retry_attempts"`
|
RetryAttempts int `json:"retry_attempts"`
|
||||||
RetryDelay time.Duration `json:"retry_delay"`
|
RetryDelay time.Duration `json:"retry_delay"`
|
||||||
QuietMode bool `json:"quiet_mode"` // Suppress debug/info logs
|
QuietMode bool `json:"quiet_mode"` // Suppress debug/info logs
|
||||||
APIKey string `json:"api_key"` // API key for gateway auth
|
APIKey string `json:"api_key"` // API key for gateway auth
|
||||||
JWT string `json:"jwt"` // Optional JWT bearer token
|
JWT string `json:"jwt"` // Optional JWT bearer token
|
||||||
|
IdentityPath string `json:"identity_path"` // Path to persistent LibP2P identity key file
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultClientConfig returns a default client configuration
|
// DefaultClientConfig returns a default client configuration
|
||||||
|
|||||||
92
pkg/client/identity_test.go
Normal file
92
pkg/client/identity_test.go
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/encryption"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPersistentIdentity_NoPath(t *testing.T) {
|
||||||
|
// Without IdentityPath, Connect() generates a random ID each time.
|
||||||
|
// We can't easily test Connect() (needs network), so verify config defaults.
|
||||||
|
cfg := DefaultClientConfig("test-app")
|
||||||
|
if cfg.IdentityPath != "" {
|
||||||
|
t.Fatalf("expected empty IdentityPath by default, got %q", cfg.IdentityPath)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPersistentIdentity_GenerateAndReload(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
keyPath := filepath.Join(dir, "identity.key")
|
||||||
|
|
||||||
|
// 1. No file exists — generate + save
|
||||||
|
id1, err := encryption.GenerateIdentity()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateIdentity: %v", err)
|
||||||
|
}
|
||||||
|
if err := encryption.SaveIdentity(id1, keyPath); err != nil {
|
||||||
|
t.Fatalf("SaveIdentity: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// File should exist
|
||||||
|
if _, err := os.Stat(keyPath); os.IsNotExist(err) {
|
||||||
|
t.Fatal("identity key file was not created")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Load it back — same PeerID
|
||||||
|
id2, err := encryption.LoadIdentity(keyPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("LoadIdentity: %v", err)
|
||||||
|
}
|
||||||
|
if id1.PeerID != id2.PeerID {
|
||||||
|
t.Fatalf("PeerID mismatch: generated %s, loaded %s", id1.PeerID, id2.PeerID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Load again — still the same
|
||||||
|
id3, err := encryption.LoadIdentity(keyPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("LoadIdentity (second): %v", err)
|
||||||
|
}
|
||||||
|
if id2.PeerID != id3.PeerID {
|
||||||
|
t.Fatalf("PeerID changed across loads: %s vs %s", id2.PeerID, id3.PeerID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPersistentIdentity_DifferentFromRandom(t *testing.T) {
|
||||||
|
// A persistent identity should be different from a freshly generated one
|
||||||
|
id1, err := encryption.GenerateIdentity()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateIdentity: %v", err)
|
||||||
|
}
|
||||||
|
id2, err := encryption.GenerateIdentity()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateIdentity: %v", err)
|
||||||
|
}
|
||||||
|
if id1.PeerID == id2.PeerID {
|
||||||
|
t.Fatal("two independently generated identities should have different PeerIDs")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPersistentIdentity_FilePermissions(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
keyPath := filepath.Join(dir, "subdir", "identity.key")
|
||||||
|
|
||||||
|
id, err := encryption.GenerateIdentity()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateIdentity: %v", err)
|
||||||
|
}
|
||||||
|
if err := encryption.SaveIdentity(id, keyPath); err != nil {
|
||||||
|
t.Fatalf("SaveIdentity: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
info, err := os.Stat(keyPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Stat: %v", err)
|
||||||
|
}
|
||||||
|
perm := info.Mode().Perm()
|
||||||
|
if perm != 0600 {
|
||||||
|
t.Fatalf("expected file permissions 0600, got %o", perm)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -22,6 +22,13 @@ type DatabaseConfig struct {
|
|||||||
NodeCACert string `yaml:"node_ca_cert"` // Path to CA certificate (optional, uses system CA if not set)
|
NodeCACert string `yaml:"node_ca_cert"` // Path to CA certificate (optional, uses system CA if not set)
|
||||||
NodeNoVerify bool `yaml:"node_no_verify"` // Skip certificate verification (for testing/self-signed certs)
|
NodeNoVerify bool `yaml:"node_no_verify"` // Skip certificate verification (for testing/self-signed certs)
|
||||||
|
|
||||||
|
// Raft tuning (passed through to rqlited CLI flags).
|
||||||
|
// Higher defaults than rqlited's 1s suit WireGuard latency.
|
||||||
|
RaftElectionTimeout time.Duration `yaml:"raft_election_timeout"` // default: 5s
|
||||||
|
RaftHeartbeatTimeout time.Duration `yaml:"raft_heartbeat_timeout"` // default: 2s
|
||||||
|
RaftApplyTimeout time.Duration `yaml:"raft_apply_timeout"` // default: 30s
|
||||||
|
RaftLeaderLeaseTimeout time.Duration `yaml:"raft_leader_lease_timeout"` // default: 5s
|
||||||
|
|
||||||
// Dynamic discovery configuration (always enabled)
|
// Dynamic discovery configuration (always enabled)
|
||||||
ClusterSyncInterval time.Duration `yaml:"cluster_sync_interval"` // default: 30s
|
ClusterSyncInterval time.Duration `yaml:"cluster_sync_interval"` // default: 30s
|
||||||
PeerInactivityLimit time.Duration `yaml:"peer_inactivity_limit"` // default: 24h
|
PeerInactivityLimit time.Duration `yaml:"peer_inactivity_limit"` // default: 24h
|
||||||
|
|||||||
@ -7,6 +7,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/libp2p/go-libp2p/core/host"
|
"github.com/libp2p/go-libp2p/core/host"
|
||||||
@ -77,10 +78,14 @@ type PeerInfo struct {
|
|||||||
// interface{} to remain source-compatible with previous call sites that
|
// interface{} to remain source-compatible with previous call sites that
|
||||||
// passed a DHT instance. The value is ignored.
|
// passed a DHT instance. The value is ignored.
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
host host.Host
|
host host.Host
|
||||||
logger *zap.Logger
|
logger *zap.Logger
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
failedPeerExchanges map[peer.ID]time.Time // Track failed peer exchange attempts to suppress repeated warnings
|
|
||||||
|
// failedMu protects failedPeerExchanges from concurrent access during
|
||||||
|
// parallel peer exchange dials (H3 fix).
|
||||||
|
failedMu sync.Mutex
|
||||||
|
failedPeerExchanges map[peer.ID]time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// Config contains discovery configuration
|
// Config contains discovery configuration
|
||||||
@ -364,8 +369,8 @@ func (d *Manager) discoverViaPeerExchange(ctx context.Context, maxConnections in
|
|||||||
// Add to peerstore (only valid addresses with port 4001)
|
// Add to peerstore (only valid addresses with port 4001)
|
||||||
d.host.Peerstore().AddAddrs(parsedID, addrs, time.Hour*24)
|
d.host.Peerstore().AddAddrs(parsedID, addrs, time.Hour*24)
|
||||||
|
|
||||||
// Try to connect
|
// Try to connect (5s timeout — WireGuard peers respond fast)
|
||||||
connectCtx, cancel := context.WithTimeout(ctx, 20*time.Second)
|
connectCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||||
peerAddrInfo := peer.AddrInfo{ID: parsedID, Addrs: addrs}
|
peerAddrInfo := peer.AddrInfo{ID: parsedID, Addrs: addrs}
|
||||||
|
|
||||||
if err := d.host.Connect(connectCtx, peerAddrInfo); err != nil {
|
if err := d.host.Connect(connectCtx, peerAddrInfo); err != nil {
|
||||||
@ -401,15 +406,15 @@ func (d *Manager) requestPeersFromPeer(ctx context.Context, peerID peer.ID, limi
|
|||||||
// Open a stream to the peer
|
// Open a stream to the peer
|
||||||
stream, err := d.host.NewStream(ctx, peerID, PeerExchangeProtocol)
|
stream, err := d.host.NewStream(ctx, peerID, PeerExchangeProtocol)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Check if this is a "protocols not supported" error (expected for lightweight clients like gateway)
|
d.failedMu.Lock()
|
||||||
if strings.Contains(err.Error(), "protocols not supported") {
|
if strings.Contains(err.Error(), "protocols not supported") {
|
||||||
// This is a lightweight client (gateway, etc.) that doesn't support peer exchange - expected behavior
|
// Lightweight client (gateway, etc.) — expected, track to suppress retries
|
||||||
// Track it to avoid repeated attempts, but don't log as it's not an error
|
|
||||||
d.failedPeerExchanges[peerID] = time.Now()
|
d.failedPeerExchanges[peerID] = time.Now()
|
||||||
|
d.failedMu.Unlock()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// For actual connection errors, log but suppress repeated warnings for the same peer
|
// Actual connection error — log but suppress repeated warnings
|
||||||
lastFailure, seen := d.failedPeerExchanges[peerID]
|
lastFailure, seen := d.failedPeerExchanges[peerID]
|
||||||
if !seen || time.Since(lastFailure) > time.Minute {
|
if !seen || time.Since(lastFailure) > time.Minute {
|
||||||
d.logger.Debug("Failed to open peer exchange stream with node",
|
d.logger.Debug("Failed to open peer exchange stream with node",
|
||||||
@ -418,12 +423,15 @@ func (d *Manager) requestPeersFromPeer(ctx context.Context, peerID peer.ID, limi
|
|||||||
zap.Error(err))
|
zap.Error(err))
|
||||||
d.failedPeerExchanges[peerID] = time.Now()
|
d.failedPeerExchanges[peerID] = time.Now()
|
||||||
}
|
}
|
||||||
|
d.failedMu.Unlock()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
defer stream.Close()
|
defer stream.Close()
|
||||||
|
|
||||||
// Clear failure tracking on success
|
// Clear failure tracking on success
|
||||||
|
d.failedMu.Lock()
|
||||||
delete(d.failedPeerExchanges, peerID)
|
delete(d.failedPeerExchanges, peerID)
|
||||||
|
d.failedMu.Unlock()
|
||||||
|
|
||||||
// Send request
|
// Send request
|
||||||
req := PeerExchangeRequest{Limit: limit}
|
req := PeerExchangeRequest{Limit: limit}
|
||||||
@ -433,8 +441,8 @@ func (d *Manager) requestPeersFromPeer(ctx context.Context, peerID peer.ID, limi
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set read deadline
|
// Set read deadline (5s — small JSON payload)
|
||||||
if err := stream.SetReadDeadline(time.Now().Add(10 * time.Second)); err != nil {
|
if err := stream.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||||
d.logger.Debug("Failed to set read deadline", zap.Error(err))
|
d.logger.Debug("Failed to set read deadline", zap.Error(err))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -451,10 +459,20 @@ func (d *Manager) requestPeersFromPeer(ctx context.Context, peerID peer.ID, limi
|
|||||||
|
|
||||||
// Store remote peer's RQLite metadata if available
|
// Store remote peer's RQLite metadata if available
|
||||||
if resp.RQLiteMetadata != nil {
|
if resp.RQLiteMetadata != nil {
|
||||||
|
// Verify sender identity — prevent metadata spoofing (H2 fix).
|
||||||
|
// If the metadata contains a PeerID, it must match the stream sender.
|
||||||
|
if resp.RQLiteMetadata.PeerID != "" && resp.RQLiteMetadata.PeerID != peerID.String() {
|
||||||
|
d.logger.Warn("Rejected metadata: PeerID mismatch",
|
||||||
|
zap.String("claimed", resp.RQLiteMetadata.PeerID[:8]+"..."),
|
||||||
|
zap.String("actual", peerID.String()[:8]+"..."))
|
||||||
|
return resp.Peers
|
||||||
|
}
|
||||||
|
// Stamp verified PeerID so downstream consumers can trust it
|
||||||
|
resp.RQLiteMetadata.PeerID = peerID.String()
|
||||||
|
|
||||||
metadataJSON, err := json.Marshal(resp.RQLiteMetadata)
|
metadataJSON, err := json.Marshal(resp.RQLiteMetadata)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
_ = d.host.Peerstore().Put(peerID, "rqlite_metadata", metadataJSON)
|
_ = d.host.Peerstore().Put(peerID, "rqlite_metadata", metadataJSON)
|
||||||
// Only log when new metadata is stored (useful for debugging)
|
|
||||||
d.logger.Debug("Metadata stored",
|
d.logger.Debug("Metadata stored",
|
||||||
zap.String("peer", peerID.String()[:8]+"..."),
|
zap.String("peer", peerID.String()[:8]+"..."),
|
||||||
zap.String("node", resp.RQLiteMetadata.NodeID))
|
zap.String("node", resp.RQLiteMetadata.NodeID))
|
||||||
|
|||||||
81
pkg/discovery/metadata_publisher.go
Normal file
81
pkg/discovery/metadata_publisher.go
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
package discovery
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/libp2p/go-libp2p/core/host"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MetadataProvider is implemented by subsystems that can supply node metadata.
|
||||||
|
// The publisher calls Provide() every cycle and stores the result in the peerstore.
|
||||||
|
type MetadataProvider interface {
|
||||||
|
ProvideMetadata() *RQLiteNodeMetadata
|
||||||
|
}
|
||||||
|
|
||||||
|
// MetadataPublisher periodically writes local node metadata to the peerstore so
|
||||||
|
// it is included in every peer exchange response. This decouples metadata
|
||||||
|
// production (lifecycle, RQLite status, service health) from the exchange
|
||||||
|
// protocol itself.
|
||||||
|
type MetadataPublisher struct {
|
||||||
|
host host.Host
|
||||||
|
provider MetadataProvider
|
||||||
|
interval time.Duration
|
||||||
|
logger *zap.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMetadataPublisher creates a publisher that writes metadata every interval.
|
||||||
|
func NewMetadataPublisher(h host.Host, provider MetadataProvider, interval time.Duration, logger *zap.Logger) *MetadataPublisher {
|
||||||
|
if interval <= 0 {
|
||||||
|
interval = 10 * time.Second
|
||||||
|
}
|
||||||
|
return &MetadataPublisher{
|
||||||
|
host: h,
|
||||||
|
provider: provider,
|
||||||
|
interval: interval,
|
||||||
|
logger: logger.With(zap.String("component", "metadata-publisher")),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start begins the periodic publish loop. It blocks until ctx is cancelled.
|
||||||
|
func (p *MetadataPublisher) Start(ctx context.Context) {
|
||||||
|
// Publish immediately on start
|
||||||
|
p.publish()
|
||||||
|
|
||||||
|
ticker := time.NewTicker(p.interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
p.publish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// PublishNow performs a single immediate metadata publish.
|
||||||
|
// Useful after lifecycle transitions or other state changes.
|
||||||
|
func (p *MetadataPublisher) PublishNow() {
|
||||||
|
p.publish()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *MetadataPublisher) publish() {
|
||||||
|
meta := p.provider.ProvideMetadata()
|
||||||
|
if meta == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(meta)
|
||||||
|
if err != nil {
|
||||||
|
p.logger.Error("Failed to marshal metadata", zap.Error(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := p.host.Peerstore().Put(p.host.ID(), "rqlite_metadata", data); err != nil {
|
||||||
|
p.logger.Error("Failed to store metadata in peerstore", zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -4,18 +4,101 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RQLiteNodeMetadata contains RQLite-specific information announced via LibP2P
|
// ServiceStatus represents the health of an individual service on a node.
|
||||||
|
type ServiceStatus struct {
|
||||||
|
Name string `json:"name"` // e.g. "rqlite", "gateway", "olric"
|
||||||
|
Running bool `json:"running"` // whether the process is up
|
||||||
|
Healthy bool `json:"healthy"` // whether it passed its health check
|
||||||
|
Message string `json:"message,omitempty"` // optional detail ("leader", "follower", etc.)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NamespaceStatus represents a namespace's status on a node.
|
||||||
|
type NamespaceStatus struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Status string `json:"status"` // "healthy", "degraded", "recovering"
|
||||||
|
}
|
||||||
|
|
||||||
|
// RQLiteNodeMetadata contains node information announced via LibP2P peer exchange.
|
||||||
|
// This struct is the single source of truth for node metadata propagated through
|
||||||
|
// the cluster. Go's json.Unmarshal silently ignores unknown fields, so old nodes
|
||||||
|
// reading metadata from new nodes simply skip the new fields — no protocol
|
||||||
|
// version change is needed.
|
||||||
type RQLiteNodeMetadata struct {
|
type RQLiteNodeMetadata struct {
|
||||||
NodeID string `json:"node_id"` // RQLite node ID (from config)
|
// --- Existing fields (unchanged) ---
|
||||||
RaftAddress string `json:"raft_address"` // Raft port address (e.g., "51.83.128.181:7001")
|
|
||||||
HTTPAddress string `json:"http_address"` // HTTP API address (e.g., "51.83.128.181:5001")
|
NodeID string `json:"node_id"` // RQLite node ID (raft address)
|
||||||
|
RaftAddress string `json:"raft_address"` // Raft port address (e.g., "10.0.0.1:7001")
|
||||||
|
HTTPAddress string `json:"http_address"` // HTTP API address (e.g., "10.0.0.1:5001")
|
||||||
NodeType string `json:"node_type"` // Node type identifier
|
NodeType string `json:"node_type"` // Node type identifier
|
||||||
RaftLogIndex uint64 `json:"raft_log_index"` // Current Raft log index (for data comparison)
|
RaftLogIndex uint64 `json:"raft_log_index"` // Current Raft log index (for data comparison)
|
||||||
LastSeen time.Time `json:"last_seen"` // Updated on every announcement
|
LastSeen time.Time `json:"last_seen"` // Updated on every announcement
|
||||||
ClusterVersion string `json:"cluster_version"` // For compatibility checking
|
ClusterVersion string `json:"cluster_version"` // For compatibility checking
|
||||||
|
|
||||||
|
// --- New: Identity ---
|
||||||
|
|
||||||
|
// PeerID is the LibP2P peer ID of the node. Used for metadata authentication:
|
||||||
|
// on receipt, the receiver verifies PeerID == stream sender to prevent spoofing.
|
||||||
|
PeerID string `json:"peer_id,omitempty"`
|
||||||
|
|
||||||
|
// WireGuardIP is the node's WireGuard VPN address (e.g., "10.0.0.1").
|
||||||
|
WireGuardIP string `json:"wireguard_ip,omitempty"`
|
||||||
|
|
||||||
|
// --- New: Lifecycle ---
|
||||||
|
|
||||||
|
// LifecycleState is the node's current lifecycle state:
|
||||||
|
// "joining", "active", "draining", or "maintenance".
|
||||||
|
// Zero value (empty string) from old nodes is treated as "active".
|
||||||
|
LifecycleState string `json:"lifecycle_state,omitempty"`
|
||||||
|
|
||||||
|
// MaintenanceTTL is the time at which maintenance mode expires.
|
||||||
|
// Only meaningful when LifecycleState == "maintenance".
|
||||||
|
MaintenanceTTL time.Time `json:"maintenance_ttl,omitempty"`
|
||||||
|
|
||||||
|
// --- New: Services ---
|
||||||
|
|
||||||
|
// Services reports the status of each service running on the node.
|
||||||
|
Services map[string]*ServiceStatus `json:"services,omitempty"`
|
||||||
|
|
||||||
|
// Namespaces reports the status of each namespace on the node.
|
||||||
|
Namespaces map[string]*NamespaceStatus `json:"namespaces,omitempty"`
|
||||||
|
|
||||||
|
// --- New: Version ---
|
||||||
|
|
||||||
|
// BinaryVersion is the node's binary version string (e.g., "1.2.3").
|
||||||
|
BinaryVersion string `json:"binary_version,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// PeerExchangeResponseV2 extends the original response with RQLite metadata
|
// EffectiveLifecycleState returns the lifecycle state, defaulting to "active"
|
||||||
|
// for old nodes that don't populate the field.
|
||||||
|
func (m *RQLiteNodeMetadata) EffectiveLifecycleState() string {
|
||||||
|
if m.LifecycleState == "" {
|
||||||
|
return "active"
|
||||||
|
}
|
||||||
|
return m.LifecycleState
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsInMaintenance returns true if the node has announced maintenance mode.
|
||||||
|
func (m *RQLiteNodeMetadata) IsInMaintenance() bool {
|
||||||
|
return m.EffectiveLifecycleState() == "maintenance"
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsAvailable returns true if the node is in a state that can serve requests.
|
||||||
|
func (m *RQLiteNodeMetadata) IsAvailable() bool {
|
||||||
|
return m.EffectiveLifecycleState() == "active"
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsMaintenanceExpired returns true if the node is in maintenance and the
|
||||||
|
// TTL has passed. Used by the leader's health monitor to enforce expiry.
|
||||||
|
func (m *RQLiteNodeMetadata) IsMaintenanceExpired() bool {
|
||||||
|
if !m.IsInMaintenance() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return !m.MaintenanceTTL.IsZero() && time.Now().After(m.MaintenanceTTL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PeerExchangeResponseV2 extends the original response with RQLite metadata.
|
||||||
|
// Kept for backward compatibility — the V1 PeerExchangeResponse in discovery.go
|
||||||
|
// already includes the same RQLiteMetadata field, so this is effectively unused.
|
||||||
type PeerExchangeResponseV2 struct {
|
type PeerExchangeResponseV2 struct {
|
||||||
Peers []PeerInfo `json:"peers"`
|
Peers []PeerInfo `json:"peers"`
|
||||||
RQLiteMetadata *RQLiteNodeMetadata `json:"rqlite_metadata,omitempty"`
|
RQLiteMetadata *RQLiteNodeMetadata `json:"rqlite_metadata,omitempty"`
|
||||||
|
|||||||
235
pkg/discovery/rqlite_metadata_test.go
Normal file
235
pkg/discovery/rqlite_metadata_test.go
Normal file
@ -0,0 +1,235 @@
|
|||||||
|
package discovery
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEffectiveLifecycleState(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
state string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"empty defaults to active", "", "active"},
|
||||||
|
{"explicit active", "active", "active"},
|
||||||
|
{"joining", "joining", "joining"},
|
||||||
|
{"maintenance", "maintenance", "maintenance"},
|
||||||
|
{"draining", "draining", "draining"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
m := &RQLiteNodeMetadata{LifecycleState: tt.state}
|
||||||
|
if got := m.EffectiveLifecycleState(); got != tt.want {
|
||||||
|
t.Fatalf("got %q, want %q", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsInMaintenance(t *testing.T) {
|
||||||
|
m := &RQLiteNodeMetadata{LifecycleState: "maintenance"}
|
||||||
|
if !m.IsInMaintenance() {
|
||||||
|
t.Fatal("expected maintenance")
|
||||||
|
}
|
||||||
|
|
||||||
|
m.LifecycleState = "active"
|
||||||
|
if m.IsInMaintenance() {
|
||||||
|
t.Fatal("expected not maintenance")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Empty state (old node) should not be maintenance
|
||||||
|
m.LifecycleState = ""
|
||||||
|
if m.IsInMaintenance() {
|
||||||
|
t.Fatal("empty state should not be maintenance")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsAvailable(t *testing.T) {
|
||||||
|
m := &RQLiteNodeMetadata{LifecycleState: "active"}
|
||||||
|
if !m.IsAvailable() {
|
||||||
|
t.Fatal("expected available")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Empty state (old node) defaults to active → available
|
||||||
|
m.LifecycleState = ""
|
||||||
|
if !m.IsAvailable() {
|
||||||
|
t.Fatal("empty state should be available (backward compat)")
|
||||||
|
}
|
||||||
|
|
||||||
|
m.LifecycleState = "maintenance"
|
||||||
|
if m.IsAvailable() {
|
||||||
|
t.Fatal("maintenance should not be available")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsMaintenanceExpired(t *testing.T) {
|
||||||
|
// Expired
|
||||||
|
m := &RQLiteNodeMetadata{
|
||||||
|
LifecycleState: "maintenance",
|
||||||
|
MaintenanceTTL: time.Now().Add(-1 * time.Minute),
|
||||||
|
}
|
||||||
|
if !m.IsMaintenanceExpired() {
|
||||||
|
t.Fatal("expected expired")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not expired
|
||||||
|
m.MaintenanceTTL = time.Now().Add(5 * time.Minute)
|
||||||
|
if m.IsMaintenanceExpired() {
|
||||||
|
t.Fatal("expected not expired")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Zero TTL in maintenance
|
||||||
|
m.MaintenanceTTL = time.Time{}
|
||||||
|
if m.IsMaintenanceExpired() {
|
||||||
|
t.Fatal("zero TTL should not be considered expired")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not in maintenance
|
||||||
|
m.LifecycleState = "active"
|
||||||
|
m.MaintenanceTTL = time.Now().Add(-1 * time.Minute)
|
||||||
|
if m.IsMaintenanceExpired() {
|
||||||
|
t.Fatal("active state should not report expired")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBackwardCompatibility verifies that old metadata (without new fields)
|
||||||
|
// unmarshals correctly — new fields get zero values, helpers return sane defaults.
|
||||||
|
func TestBackwardCompatibility(t *testing.T) {
|
||||||
|
oldJSON := `{
|
||||||
|
"node_id": "10.0.0.1:7001",
|
||||||
|
"raft_address": "10.0.0.1:7001",
|
||||||
|
"http_address": "10.0.0.1:5001",
|
||||||
|
"node_type": "node",
|
||||||
|
"raft_log_index": 42,
|
||||||
|
"cluster_version": "1.0"
|
||||||
|
}`
|
||||||
|
|
||||||
|
var m RQLiteNodeMetadata
|
||||||
|
if err := json.Unmarshal([]byte(oldJSON), &m); err != nil {
|
||||||
|
t.Fatalf("unmarshal old metadata: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Existing fields preserved
|
||||||
|
if m.NodeID != "10.0.0.1:7001" {
|
||||||
|
t.Fatalf("expected node_id 10.0.0.1:7001, got %s", m.NodeID)
|
||||||
|
}
|
||||||
|
if m.RaftLogIndex != 42 {
|
||||||
|
t.Fatalf("expected raft_log_index 42, got %d", m.RaftLogIndex)
|
||||||
|
}
|
||||||
|
|
||||||
|
// New fields default to zero values
|
||||||
|
if m.PeerID != "" {
|
||||||
|
t.Fatalf("expected empty PeerID, got %q", m.PeerID)
|
||||||
|
}
|
||||||
|
if m.LifecycleState != "" {
|
||||||
|
t.Fatalf("expected empty LifecycleState, got %q", m.LifecycleState)
|
||||||
|
}
|
||||||
|
if m.Services != nil {
|
||||||
|
t.Fatal("expected nil Services")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helpers return correct defaults
|
||||||
|
if m.EffectiveLifecycleState() != "active" {
|
||||||
|
t.Fatalf("expected effective state 'active', got %q", m.EffectiveLifecycleState())
|
||||||
|
}
|
||||||
|
if !m.IsAvailable() {
|
||||||
|
t.Fatal("old metadata should be available")
|
||||||
|
}
|
||||||
|
if m.IsInMaintenance() {
|
||||||
|
t.Fatal("old metadata should not be in maintenance")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestNewFieldsRoundTrip verifies that new fields marshal/unmarshal correctly.
|
||||||
|
func TestNewFieldsRoundTrip(t *testing.T) {
|
||||||
|
original := &RQLiteNodeMetadata{
|
||||||
|
NodeID: "10.0.0.1:7001",
|
||||||
|
RaftAddress: "10.0.0.1:7001",
|
||||||
|
HTTPAddress: "10.0.0.1:5001",
|
||||||
|
NodeType: "node",
|
||||||
|
RaftLogIndex: 100,
|
||||||
|
ClusterVersion: "1.0",
|
||||||
|
PeerID: "QmPeerID123",
|
||||||
|
WireGuardIP: "10.0.0.1",
|
||||||
|
LifecycleState: "maintenance",
|
||||||
|
MaintenanceTTL: time.Now().Add(10 * time.Minute).Truncate(time.Millisecond),
|
||||||
|
BinaryVersion: "1.2.3",
|
||||||
|
Services: map[string]*ServiceStatus{
|
||||||
|
"rqlite": {Name: "rqlite", Running: true, Healthy: true, Message: "leader"},
|
||||||
|
},
|
||||||
|
Namespaces: map[string]*NamespaceStatus{
|
||||||
|
"myapp": {Name: "myapp", Status: "healthy"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(original)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("marshal: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var decoded RQLiteNodeMetadata
|
||||||
|
if err := json.Unmarshal(data, &decoded); err != nil {
|
||||||
|
t.Fatalf("unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if decoded.PeerID != original.PeerID {
|
||||||
|
t.Fatalf("PeerID: got %q, want %q", decoded.PeerID, original.PeerID)
|
||||||
|
}
|
||||||
|
if decoded.WireGuardIP != original.WireGuardIP {
|
||||||
|
t.Fatalf("WireGuardIP: got %q, want %q", decoded.WireGuardIP, original.WireGuardIP)
|
||||||
|
}
|
||||||
|
if decoded.LifecycleState != original.LifecycleState {
|
||||||
|
t.Fatalf("LifecycleState: got %q, want %q", decoded.LifecycleState, original.LifecycleState)
|
||||||
|
}
|
||||||
|
if decoded.BinaryVersion != original.BinaryVersion {
|
||||||
|
t.Fatalf("BinaryVersion: got %q, want %q", decoded.BinaryVersion, original.BinaryVersion)
|
||||||
|
}
|
||||||
|
if decoded.Services["rqlite"] == nil || !decoded.Services["rqlite"].Running {
|
||||||
|
t.Fatal("expected rqlite service to be running")
|
||||||
|
}
|
||||||
|
if decoded.Namespaces["myapp"] == nil || decoded.Namespaces["myapp"].Status != "healthy" {
|
||||||
|
t.Fatal("expected myapp namespace to be healthy")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestOldNodeReadsNewMetadata simulates an old node (that doesn't know about new fields)
// reading metadata from a new node. Go's JSON unmarshalling silently ignores unknown fields.
func TestOldNodeReadsNewMetadata(t *testing.T) {
	payload := `{
		"node_id": "10.0.0.1:7001",
		"raft_address": "10.0.0.1:7001",
		"http_address": "10.0.0.1:5001",
		"node_type": "node",
		"raft_log_index": 42,
		"cluster_version": "1.0",
		"peer_id": "QmSomePeerID",
		"wireguard_ip": "10.0.0.1",
		"lifecycle_state": "maintenance",
		"maintenance_ttl": "2025-01-01T00:00:00Z",
		"binary_version": "2.0.0",
		"services": {"rqlite": {"name": "rqlite", "running": true, "healthy": true}},
		"namespaces": {"app": {"name": "app", "status": "healthy"}},
		"some_future_field": "unknown"
	}`

	// Simulate "old" struct with only original fields
	type OldMetadata struct {
		NodeID         string `json:"node_id"`
		RaftAddress    string `json:"raft_address"`
		HTTPAddress    string `json:"http_address"`
		NodeType       string `json:"node_type"`
		RaftLogIndex   uint64 `json:"raft_log_index"`
		ClusterVersion string `json:"cluster_version"`
	}

	var legacy OldMetadata
	if err := json.Unmarshal([]byte(payload), &legacy); err != nil {
		t.Fatalf("old node should unmarshal new metadata without error: %v", err)
	}

	if legacy.NodeID != "10.0.0.1:7001" || legacy.RaftLogIndex != 42 {
		t.Fatal("old fields should be preserved")
	}
}
|
||||||
194
pkg/encryption/wallet_keygen.go
Normal file
194
pkg/encryption/wallet_keygen.go
Normal file
@ -0,0 +1,194 @@
|
|||||||
|
package encryption
|
||||||
|
|
||||||
|
import (
	"crypto/ed25519"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"io"
	"os"
	"os/exec"
	"strings"

	"golang.org/x/crypto/curve25519"
	"golang.org/x/crypto/hkdf"
)
|
||||||
|
|
||||||
|
// NodeKeys holds all cryptographic keys derived from a wallet's master key.
// Every key is produced deterministically by ExpandNodeKeys via HKDF-SHA256,
// so the same master key always yields the same set of node keys.
type NodeKeys struct {
	LibP2PPrivateKey ed25519.PrivateKey // Ed25519 for LibP2P identity
	LibP2PPublicKey ed25519.PublicKey // public half of the LibP2P identity
	WireGuardKey [32]byte // Curve25519 private key (clamped)
	WireGuardPubKey [32]byte // Curve25519 public key
	IPFSPrivateKey ed25519.PrivateKey // Ed25519 IPFS node identity
	IPFSPublicKey ed25519.PublicKey // public half of the IPFS identity
	ClusterPrivateKey ed25519.PrivateKey // IPFS Cluster identity
	ClusterPublicKey ed25519.PublicKey // public half of the cluster identity
	JWTPrivateKey ed25519.PrivateKey // EdDSA JWT signing key
	JWTPublicKey ed25519.PublicKey // public half used to verify JWTs
}
|
||||||
|
|
||||||
|
// DeriveNodeKeysFromWallet calls `rw derive` to get a master key from the user's
|
||||||
|
// Root Wallet, then expands it into all node keys. The wallet's private key never
|
||||||
|
// leaves the `rw` process.
|
||||||
|
//
|
||||||
|
// vpsIP is used as the HKDF info parameter, so each VPS gets unique keys from the
|
||||||
|
// same wallet. Stdin is passed through so rw can prompt for the wallet password.
|
||||||
|
func DeriveNodeKeysFromWallet(vpsIP string) (*NodeKeys, error) {
|
||||||
|
if vpsIP == "" {
|
||||||
|
return nil, fmt.Errorf("VPS IP is required for key derivation")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check rw is installed
|
||||||
|
if _, err := exec.LookPath("rw"); err != nil {
|
||||||
|
return nil, fmt.Errorf("Root Wallet (rw) not found in PATH — install it first")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call rw derive to get master key bytes
|
||||||
|
cmd := exec.Command("rw", "derive", "--salt", "orama-node", "--info", vpsIP)
|
||||||
|
cmd.Stdin = os.Stdin // pass through for password prompts
|
||||||
|
cmd.Stderr = os.Stderr // rw UI messages go to terminal
|
||||||
|
out, err := cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("rw derive failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
masterHex := strings.TrimSpace(string(out))
|
||||||
|
if len(masterHex) != 64 { // 32 bytes = 64 hex chars
|
||||||
|
return nil, fmt.Errorf("rw derive returned unexpected output length: %d (expected 64 hex chars)", len(masterHex))
|
||||||
|
}
|
||||||
|
|
||||||
|
masterKey, err := hexToBytes(masterHex)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("rw derive returned invalid hex: %w", err)
|
||||||
|
}
|
||||||
|
defer zeroBytes(masterKey)
|
||||||
|
|
||||||
|
return ExpandNodeKeys(masterKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpandNodeKeys expands a 32-byte master key into all node keys using HKDF-SHA256.
|
||||||
|
// The master key should come from `rw derive --salt "orama-node" --info "<IP>"`.
|
||||||
|
//
|
||||||
|
// Each key type uses a different HKDF info string under the salt "orama-expand",
|
||||||
|
// ensuring cryptographic independence between key types.
|
||||||
|
func ExpandNodeKeys(masterKey []byte) (*NodeKeys, error) {
|
||||||
|
if len(masterKey) != 32 {
|
||||||
|
return nil, fmt.Errorf("master key must be 32 bytes, got %d", len(masterKey))
|
||||||
|
}
|
||||||
|
|
||||||
|
salt := []byte("orama-expand")
|
||||||
|
keys := &NodeKeys{}
|
||||||
|
|
||||||
|
// Derive LibP2P Ed25519 key
|
||||||
|
seed, err := deriveBytes(masterKey, salt, []byte("libp2p-identity"), ed25519.SeedSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to derive libp2p key: %w", err)
|
||||||
|
}
|
||||||
|
priv := ed25519.NewKeyFromSeed(seed)
|
||||||
|
zeroBytes(seed)
|
||||||
|
keys.LibP2PPrivateKey = priv
|
||||||
|
keys.LibP2PPublicKey = priv.Public().(ed25519.PublicKey)
|
||||||
|
|
||||||
|
// Derive WireGuard Curve25519 key
|
||||||
|
wgSeed, err := deriveBytes(masterKey, salt, []byte("wireguard-key"), 32)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to derive wireguard key: %w", err)
|
||||||
|
}
|
||||||
|
copy(keys.WireGuardKey[:], wgSeed)
|
||||||
|
zeroBytes(wgSeed)
|
||||||
|
clampCurve25519Key(&keys.WireGuardKey)
|
||||||
|
pubKey, err := curve25519.X25519(keys.WireGuardKey[:], curve25519.Basepoint)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to compute wireguard public key: %w", err)
|
||||||
|
}
|
||||||
|
copy(keys.WireGuardPubKey[:], pubKey)
|
||||||
|
|
||||||
|
// Derive IPFS Ed25519 key
|
||||||
|
seed, err = deriveBytes(masterKey, salt, []byte("ipfs-identity"), ed25519.SeedSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to derive ipfs key: %w", err)
|
||||||
|
}
|
||||||
|
priv = ed25519.NewKeyFromSeed(seed)
|
||||||
|
zeroBytes(seed)
|
||||||
|
keys.IPFSPrivateKey = priv
|
||||||
|
keys.IPFSPublicKey = priv.Public().(ed25519.PublicKey)
|
||||||
|
|
||||||
|
// Derive IPFS Cluster Ed25519 key
|
||||||
|
seed, err = deriveBytes(masterKey, salt, []byte("ipfs-cluster"), ed25519.SeedSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to derive cluster key: %w", err)
|
||||||
|
}
|
||||||
|
priv = ed25519.NewKeyFromSeed(seed)
|
||||||
|
zeroBytes(seed)
|
||||||
|
keys.ClusterPrivateKey = priv
|
||||||
|
keys.ClusterPublicKey = priv.Public().(ed25519.PublicKey)
|
||||||
|
|
||||||
|
// Derive JWT EdDSA signing key
|
||||||
|
seed, err = deriveBytes(masterKey, salt, []byte("jwt-signing"), ed25519.SeedSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to derive jwt key: %w", err)
|
||||||
|
}
|
||||||
|
priv = ed25519.NewKeyFromSeed(seed)
|
||||||
|
zeroBytes(seed)
|
||||||
|
keys.JWTPrivateKey = priv
|
||||||
|
keys.JWTPublicKey = priv.Public().(ed25519.PublicKey)
|
||||||
|
|
||||||
|
return keys, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// deriveBytes uses HKDF-SHA256 to derive n bytes from the given IKM, salt, and info.
|
||||||
|
func deriveBytes(ikm, salt, info []byte, n int) ([]byte, error) {
|
||||||
|
hkdfReader := hkdf.New(sha256.New, ikm, salt, info)
|
||||||
|
out := make([]byte, n)
|
||||||
|
if _, err := io.ReadFull(hkdfReader, out); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// clampCurve25519Key applies the standard Curve25519 clamping to a private key:
// clear the low three bits of the first byte, clear the top bit of the last
// byte, and set the second-highest bit of the last byte.
func clampCurve25519Key(key *[32]byte) {
	key[0] &= 0xf8
	key[31] = (key[31] &^ 0x80) | 0x40
}
|
||||||
|
|
||||||
|
// hexToBytes decodes a hex string to bytes. Upper- and lower-case digits are
// both accepted; an odd-length string or a non-hex character is an error.
//
// Uses the standard library's encoding/hex instead of a hand-rolled decoder —
// same accepted inputs, same failure cases (error text differs slightly).
func hexToBytes(h string) ([]byte, error) {
	b, err := hex.DecodeString(h)
	if err != nil {
		return nil, fmt.Errorf("invalid hex string: %w", err)
	}
	return b, nil
}
|
||||||
|
|
||||||
|
// hexCharToByte converts a single ASCII hex digit ('0'-'9', 'a'-'f', 'A'-'F')
// to its 4-bit value, erroring on anything else.
func hexCharToByte(c byte) (byte, error) {
	switch {
	case '0' <= c && c <= '9':
		return c - '0', nil
	case 'a' <= c && c <= 'f':
		return 10 + c - 'a', nil
	case 'A' <= c && c <= 'F':
		return 10 + c - 'A', nil
	}
	return 0, fmt.Errorf("invalid hex character: %c", c)
}
|
||||||
|
|
||||||
|
// zeroBytes zeroes a byte slice to clear sensitive data from memory.
// Best-effort only: Go makes no guarantee about copies created by the GC
// or by earlier string conversions.
func zeroBytes(b []byte) {
	// Built-in clear (Go 1.21+) lowers to memclr; safe on nil/empty slices.
	clear(b)
}
|
||||||
202
pkg/encryption/wallet_keygen_test.go
Normal file
202
pkg/encryption/wallet_keygen_test.go
Normal file
@ -0,0 +1,202 @@
|
|||||||
|
package encryption
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/ed25519"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// testMasterKey is a deterministic 32-byte key for testing ExpandNodeKeys.
// In production, this comes from `rw derive --salt "orama-node" --info "<IP>"`.
var testMasterKey = bytes.Repeat([]byte{0xab}, 32)

// testMasterKey2 is a second, distinct master key used to verify that
// different master keys yield different derived node keys.
var testMasterKey2 = bytes.Repeat([]byte{0xcd}, 32)
|
||||||
|
|
||||||
|
func TestExpandNodeKeys_Determinism(t *testing.T) {
|
||||||
|
keys1, err := ExpandNodeKeys(testMasterKey)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExpandNodeKeys: %v", err)
|
||||||
|
}
|
||||||
|
keys2, err := ExpandNodeKeys(testMasterKey)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExpandNodeKeys (second): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(keys1.LibP2PPrivateKey, keys2.LibP2PPrivateKey) {
|
||||||
|
t.Error("LibP2P private keys differ for same input")
|
||||||
|
}
|
||||||
|
if !bytes.Equal(keys1.WireGuardKey[:], keys2.WireGuardKey[:]) {
|
||||||
|
t.Error("WireGuard keys differ for same input")
|
||||||
|
}
|
||||||
|
if !bytes.Equal(keys1.IPFSPrivateKey, keys2.IPFSPrivateKey) {
|
||||||
|
t.Error("IPFS private keys differ for same input")
|
||||||
|
}
|
||||||
|
if !bytes.Equal(keys1.ClusterPrivateKey, keys2.ClusterPrivateKey) {
|
||||||
|
t.Error("Cluster private keys differ for same input")
|
||||||
|
}
|
||||||
|
if !bytes.Equal(keys1.JWTPrivateKey, keys2.JWTPrivateKey) {
|
||||||
|
t.Error("JWT private keys differ for same input")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExpandNodeKeys_Uniqueness(t *testing.T) {
|
||||||
|
keys1, err := ExpandNodeKeys(testMasterKey)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExpandNodeKeys(master1): %v", err)
|
||||||
|
}
|
||||||
|
keys2, err := ExpandNodeKeys(testMasterKey2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExpandNodeKeys(master2): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if bytes.Equal(keys1.LibP2PPrivateKey, keys2.LibP2PPrivateKey) {
|
||||||
|
t.Error("LibP2P keys should differ for different master keys")
|
||||||
|
}
|
||||||
|
if bytes.Equal(keys1.WireGuardKey[:], keys2.WireGuardKey[:]) {
|
||||||
|
t.Error("WireGuard keys should differ for different master keys")
|
||||||
|
}
|
||||||
|
if bytes.Equal(keys1.IPFSPrivateKey, keys2.IPFSPrivateKey) {
|
||||||
|
t.Error("IPFS keys should differ for different master keys")
|
||||||
|
}
|
||||||
|
if bytes.Equal(keys1.ClusterPrivateKey, keys2.ClusterPrivateKey) {
|
||||||
|
t.Error("Cluster keys should differ for different master keys")
|
||||||
|
}
|
||||||
|
if bytes.Equal(keys1.JWTPrivateKey, keys2.JWTPrivateKey) {
|
||||||
|
t.Error("JWT keys should differ for different master keys")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExpandNodeKeys_KeysAreMutuallyUnique(t *testing.T) {
|
||||||
|
keys, err := ExpandNodeKeys(testMasterKey)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExpandNodeKeys: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
privKeys := [][]byte{
|
||||||
|
keys.LibP2PPrivateKey.Seed(),
|
||||||
|
keys.IPFSPrivateKey.Seed(),
|
||||||
|
keys.ClusterPrivateKey.Seed(),
|
||||||
|
keys.JWTPrivateKey.Seed(),
|
||||||
|
keys.WireGuardKey[:],
|
||||||
|
}
|
||||||
|
labels := []string{"LibP2P", "IPFS", "Cluster", "JWT", "WireGuard"}
|
||||||
|
|
||||||
|
for i := 0; i < len(privKeys); i++ {
|
||||||
|
for j := i + 1; j < len(privKeys); j++ {
|
||||||
|
if bytes.Equal(privKeys[i], privKeys[j]) {
|
||||||
|
t.Errorf("%s and %s keys should differ", labels[i], labels[j])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExpandNodeKeys_Ed25519Validity(t *testing.T) {
|
||||||
|
keys, err := ExpandNodeKeys(testMasterKey)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExpandNodeKeys: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := []byte("test message for verification")
|
||||||
|
|
||||||
|
pairs := []struct {
|
||||||
|
name string
|
||||||
|
priv ed25519.PrivateKey
|
||||||
|
pub ed25519.PublicKey
|
||||||
|
}{
|
||||||
|
{"LibP2P", keys.LibP2PPrivateKey, keys.LibP2PPublicKey},
|
||||||
|
{"IPFS", keys.IPFSPrivateKey, keys.IPFSPublicKey},
|
||||||
|
{"Cluster", keys.ClusterPrivateKey, keys.ClusterPublicKey},
|
||||||
|
{"JWT", keys.JWTPrivateKey, keys.JWTPublicKey},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, p := range pairs {
|
||||||
|
signature := ed25519.Sign(p.priv, msg)
|
||||||
|
if !ed25519.Verify(p.pub, msg, signature) {
|
||||||
|
t.Errorf("%s key pair: signature verification failed", p.name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExpandNodeKeys_WireGuardClamping(t *testing.T) {
|
||||||
|
keys, err := ExpandNodeKeys(testMasterKey)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExpandNodeKeys: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if keys.WireGuardKey[0]&7 != 0 {
|
||||||
|
t.Errorf("WireGuard key not properly clamped: low 3 bits of first byte should be 0, got %08b", keys.WireGuardKey[0])
|
||||||
|
}
|
||||||
|
if keys.WireGuardKey[31]&128 != 0 {
|
||||||
|
t.Errorf("WireGuard key not properly clamped: high bit of last byte should be 0, got %08b", keys.WireGuardKey[31])
|
||||||
|
}
|
||||||
|
if keys.WireGuardKey[31]&64 != 64 {
|
||||||
|
t.Errorf("WireGuard key not properly clamped: second-high bit of last byte should be 1, got %08b", keys.WireGuardKey[31])
|
||||||
|
}
|
||||||
|
|
||||||
|
var zero [32]byte
|
||||||
|
if keys.WireGuardPubKey == zero {
|
||||||
|
t.Error("WireGuard public key is all zeros")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExpandNodeKeys_InvalidMasterKeyLength(t *testing.T) {
|
||||||
|
_, err := ExpandNodeKeys(nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for nil master key")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = ExpandNodeKeys([]byte{})
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for empty master key")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = ExpandNodeKeys(make([]byte, 16))
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for 16-byte master key")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = ExpandNodeKeys(make([]byte, 64))
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for 64-byte master key")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHexToBytes(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
expected []byte
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"", []byte{}, false},
|
||||||
|
{"00", []byte{0}, false},
|
||||||
|
{"ff", []byte{255}, false},
|
||||||
|
{"FF", []byte{255}, false},
|
||||||
|
{"0a1b2c", []byte{10, 27, 44}, false},
|
||||||
|
{"0", nil, true}, // odd length
|
||||||
|
{"zz", nil, true}, // invalid chars
|
||||||
|
{"gg", nil, true}, // invalid chars
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
got, err := hexToBytes(tt.input)
|
||||||
|
if tt.wantErr {
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("hexToBytes(%q): expected error", tt.input)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("hexToBytes(%q): unexpected error: %v", tt.input, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !bytes.Equal(got, tt.expected) {
|
||||||
|
t.Errorf("hexToBytes(%q) = %v, want %v", tt.input, got, tt.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeriveNodeKeysFromWallet_EmptyIP(t *testing.T) {
|
||||||
|
_, err := DeriveNodeKeysFromWallet("")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for empty VPS IP")
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -2,6 +2,7 @@ package auth
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto"
|
"crypto"
|
||||||
|
"crypto/ed25519"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
@ -15,31 +16,46 @@ import (
|
|||||||
|
|
||||||
func (s *Service) JWKSHandler(w http.ResponseWriter, r *http.Request) {
|
func (s *Service) JWKSHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
if s.signingKey == nil {
|
|
||||||
_ = json.NewEncoder(w).Encode(map[string]any{"keys": []any{}})
|
keys := make([]any, 0, 2)
|
||||||
return
|
|
||||||
|
// RSA key (RS256)
|
||||||
|
if s.signingKey != nil {
|
||||||
|
pub := s.signingKey.Public().(*rsa.PublicKey)
|
||||||
|
n := pub.N.Bytes()
|
||||||
|
eVal := pub.E
|
||||||
|
eb := make([]byte, 0)
|
||||||
|
for eVal > 0 {
|
||||||
|
eb = append([]byte{byte(eVal & 0xff)}, eb...)
|
||||||
|
eVal >>= 8
|
||||||
|
}
|
||||||
|
if len(eb) == 0 {
|
||||||
|
eb = []byte{0}
|
||||||
|
}
|
||||||
|
keys = append(keys, map[string]string{
|
||||||
|
"kty": "RSA",
|
||||||
|
"use": "sig",
|
||||||
|
"alg": "RS256",
|
||||||
|
"kid": s.keyID,
|
||||||
|
"n": base64.RawURLEncoding.EncodeToString(n),
|
||||||
|
"e": base64.RawURLEncoding.EncodeToString(eb),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
pub := s.signingKey.Public().(*rsa.PublicKey)
|
|
||||||
n := pub.N.Bytes()
|
// Ed25519 key (EdDSA)
|
||||||
// Encode exponent as big-endian bytes
|
if s.edSigningKey != nil {
|
||||||
eVal := pub.E
|
pubKey := s.edSigningKey.Public().(ed25519.PublicKey)
|
||||||
eb := make([]byte, 0)
|
keys = append(keys, map[string]string{
|
||||||
for eVal > 0 {
|
"kty": "OKP",
|
||||||
eb = append([]byte{byte(eVal & 0xff)}, eb...)
|
"use": "sig",
|
||||||
eVal >>= 8
|
"alg": "EdDSA",
|
||||||
|
"kid": s.edKeyID,
|
||||||
|
"crv": "Ed25519",
|
||||||
|
"x": base64.RawURLEncoding.EncodeToString(pubKey),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
if len(eb) == 0 {
|
|
||||||
eb = []byte{0}
|
_ = json.NewEncoder(w).Encode(map[string]any{"keys": keys})
|
||||||
}
|
|
||||||
jwk := map[string]string{
|
|
||||||
"kty": "RSA",
|
|
||||||
"use": "sig",
|
|
||||||
"alg": "RS256",
|
|
||||||
"kid": s.keyID,
|
|
||||||
"n": base64.RawURLEncoding.EncodeToString(n),
|
|
||||||
"e": base64.RawURLEncoding.EncodeToString(eb),
|
|
||||||
}
|
|
||||||
_ = json.NewEncoder(w).Encode(map[string]any{"keys": []any{jwk}})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Internal types for JWT handling
|
// Internal types for JWT handling
|
||||||
@ -59,11 +75,12 @@ type JWTClaims struct {
|
|||||||
Namespace string `json:"namespace"`
|
Namespace string `json:"namespace"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseAndVerifyJWT verifies an RS256 JWT created by this gateway and returns claims
|
// ParseAndVerifyJWT verifies a JWT created by this gateway using kid-based key
|
||||||
|
// selection. It accepts both RS256 (legacy) and EdDSA (new) tokens.
|
||||||
|
//
|
||||||
|
// Security (C3 fix): The key is selected by kid, then cross-checked against alg
|
||||||
|
// to prevent algorithm confusion attacks. Only RS256 and EdDSA are accepted.
|
||||||
func (s *Service) ParseAndVerifyJWT(token string) (*JWTClaims, error) {
|
func (s *Service) ParseAndVerifyJWT(token string) (*JWTClaims, error) {
|
||||||
if s.signingKey == nil {
|
|
||||||
return nil, errors.New("signing key unavailable")
|
|
||||||
}
|
|
||||||
parts := strings.Split(token, ".")
|
parts := strings.Split(token, ".")
|
||||||
if len(parts) != 3 {
|
if len(parts) != 3 {
|
||||||
return nil, errors.New("invalid token format")
|
return nil, errors.New("invalid token format")
|
||||||
@ -80,20 +97,60 @@ func (s *Service) ParseAndVerifyJWT(token string) (*JWTClaims, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.New("invalid signature encoding")
|
return nil, errors.New("invalid signature encoding")
|
||||||
}
|
}
|
||||||
|
|
||||||
var header jwtHeader
|
var header jwtHeader
|
||||||
if err := json.Unmarshal(hb, &header); err != nil {
|
if err := json.Unmarshal(hb, &header); err != nil {
|
||||||
return nil, errors.New("invalid header json")
|
return nil, errors.New("invalid header json")
|
||||||
}
|
}
|
||||||
if header.Alg != "RS256" {
|
|
||||||
return nil, errors.New("unsupported alg")
|
// Explicit algorithm allowlist — reject everything else before verification
|
||||||
|
if header.Alg != "RS256" && header.Alg != "EdDSA" {
|
||||||
|
return nil, errors.New("unsupported algorithm")
|
||||||
}
|
}
|
||||||
// Verify signature
|
|
||||||
signingInput := parts[0] + "." + parts[1]
|
signingInput := parts[0] + "." + parts[1]
|
||||||
sum := sha256.Sum256([]byte(signingInput))
|
|
||||||
pub := s.signingKey.Public().(*rsa.PublicKey)
|
// Key selection by kid (not alg) — prevents algorithm confusion (C3 fix)
|
||||||
if err := rsa.VerifyPKCS1v15(pub, crypto.SHA256, sum[:], sb); err != nil {
|
switch {
|
||||||
return nil, errors.New("invalid signature")
|
case header.Kid != "" && header.Kid == s.edKeyID && s.edSigningKey != nil:
|
||||||
|
// EdDSA key matched by kid — cross-check alg
|
||||||
|
if header.Alg != "EdDSA" {
|
||||||
|
return nil, errors.New("algorithm mismatch for key")
|
||||||
|
}
|
||||||
|
pubKey := s.edSigningKey.Public().(ed25519.PublicKey)
|
||||||
|
if !ed25519.Verify(pubKey, []byte(signingInput), sb) {
|
||||||
|
return nil, errors.New("invalid signature")
|
||||||
|
}
|
||||||
|
|
||||||
|
case header.Kid != "" && header.Kid == s.keyID && s.signingKey != nil:
|
||||||
|
// RSA key matched by kid — cross-check alg
|
||||||
|
if header.Alg != "RS256" {
|
||||||
|
return nil, errors.New("algorithm mismatch for key")
|
||||||
|
}
|
||||||
|
sum := sha256.Sum256([]byte(signingInput))
|
||||||
|
pub := s.signingKey.Public().(*rsa.PublicKey)
|
||||||
|
if err := rsa.VerifyPKCS1v15(pub, crypto.SHA256, sum[:], sb); err != nil {
|
||||||
|
return nil, errors.New("invalid signature")
|
||||||
|
}
|
||||||
|
|
||||||
|
case header.Kid == "":
|
||||||
|
// Legacy token without kid — RS256 only (backward compat)
|
||||||
|
if header.Alg != "RS256" {
|
||||||
|
return nil, errors.New("legacy token must be RS256")
|
||||||
|
}
|
||||||
|
if s.signingKey == nil {
|
||||||
|
return nil, errors.New("signing key unavailable")
|
||||||
|
}
|
||||||
|
sum := sha256.Sum256([]byte(signingInput))
|
||||||
|
pub := s.signingKey.Public().(*rsa.PublicKey)
|
||||||
|
if err := rsa.VerifyPKCS1v15(pub, crypto.SHA256, sum[:], sb); err != nil {
|
||||||
|
return nil, errors.New("invalid signature")
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, errors.New("unknown key ID")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse claims
|
// Parse claims
|
||||||
var claims JWTClaims
|
var claims JWTClaims
|
||||||
if err := json.Unmarshal(pb, &claims); err != nil {
|
if err := json.Unmarshal(pb, &claims); err != nil {
|
||||||
@ -105,8 +162,7 @@ func (s *Service) ParseAndVerifyJWT(token string) (*JWTClaims, error) {
|
|||||||
}
|
}
|
||||||
// Validate registered claims
|
// Validate registered claims
|
||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
// allow small clock skew ±60s
|
const skew = int64(60) // allow small clock skew ±60s
|
||||||
const skew = int64(60)
|
|
||||||
if claims.Nbf != 0 && now+skew < claims.Nbf {
|
if claims.Nbf != 0 && now+skew < claims.Nbf {
|
||||||
return nil, errors.New("token not yet valid")
|
return nil, errors.New("token not yet valid")
|
||||||
}
|
}
|
||||||
@ -123,6 +179,44 @@ func (s *Service) ParseAndVerifyJWT(token string) (*JWTClaims, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) GenerateJWT(ns, subject string, ttl time.Duration) (string, int64, error) {
|
func (s *Service) GenerateJWT(ns, subject string, ttl time.Duration) (string, int64, error) {
|
||||||
|
// Prefer EdDSA when available
|
||||||
|
if s.preferEdDSA && s.edSigningKey != nil {
|
||||||
|
return s.generateEdDSAJWT(ns, subject, ttl)
|
||||||
|
}
|
||||||
|
return s.generateRSAJWT(ns, subject, ttl)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) generateEdDSAJWT(ns, subject string, ttl time.Duration) (string, int64, error) {
|
||||||
|
if s.edSigningKey == nil {
|
||||||
|
return "", 0, errors.New("EdDSA signing key unavailable")
|
||||||
|
}
|
||||||
|
header := map[string]string{
|
||||||
|
"alg": "EdDSA",
|
||||||
|
"typ": "JWT",
|
||||||
|
"kid": s.edKeyID,
|
||||||
|
}
|
||||||
|
hb, _ := json.Marshal(header)
|
||||||
|
now := time.Now().UTC()
|
||||||
|
exp := now.Add(ttl)
|
||||||
|
payload := map[string]any{
|
||||||
|
"iss": "debros-gateway",
|
||||||
|
"sub": subject,
|
||||||
|
"aud": "gateway",
|
||||||
|
"iat": now.Unix(),
|
||||||
|
"nbf": now.Unix(),
|
||||||
|
"exp": exp.Unix(),
|
||||||
|
"namespace": ns,
|
||||||
|
}
|
||||||
|
pb, _ := json.Marshal(payload)
|
||||||
|
hb64 := base64.RawURLEncoding.EncodeToString(hb)
|
||||||
|
pb64 := base64.RawURLEncoding.EncodeToString(pb)
|
||||||
|
signingInput := hb64 + "." + pb64
|
||||||
|
sig := ed25519.Sign(s.edSigningKey, []byte(signingInput))
|
||||||
|
sb64 := base64.RawURLEncoding.EncodeToString(sig)
|
||||||
|
return signingInput + "." + sb64, exp.Unix(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) generateRSAJWT(ns, subject string, ttl time.Duration) (string, int64, error) {
|
||||||
if s.signingKey == nil {
|
if s.signingKey == nil {
|
||||||
return "", 0, errors.New("signing key unavailable")
|
return "", 0, errors.New("signing key unavailable")
|
||||||
}
|
}
|
||||||
|
|||||||
@ -24,11 +24,14 @@ import (
|
|||||||
|
|
||||||
// Service handles authentication business logic
|
// Service handles authentication business logic
|
||||||
type Service struct {
|
type Service struct {
|
||||||
logger *logging.ColoredLogger
|
logger *logging.ColoredLogger
|
||||||
orm client.NetworkClient
|
orm client.NetworkClient
|
||||||
signingKey *rsa.PrivateKey
|
signingKey *rsa.PrivateKey
|
||||||
keyID string
|
keyID string
|
||||||
defaultNS string
|
edSigningKey ed25519.PrivateKey
|
||||||
|
edKeyID string
|
||||||
|
preferEdDSA bool
|
||||||
|
defaultNS string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewService(logger *logging.ColoredLogger, orm client.NetworkClient, signingKeyPEM string, defaultNS string) (*Service, error) {
|
func NewService(logger *logging.ColoredLogger, orm client.NetworkClient, signingKeyPEM string, defaultNS string) (*Service, error) {
|
||||||
@ -58,6 +61,16 @@ func NewService(logger *logging.ColoredLogger, orm client.NetworkClient, signing
|
|||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetEdDSAKey configures an Ed25519 signing key for EdDSA JWT support.
|
||||||
|
// When set, new tokens are signed with EdDSA; RS256 is still accepted for verification.
|
||||||
|
func (s *Service) SetEdDSAKey(privKey ed25519.PrivateKey) {
|
||||||
|
s.edSigningKey = privKey
|
||||||
|
pubBytes := []byte(privKey.Public().(ed25519.PublicKey))
|
||||||
|
sum := sha256.Sum256(pubBytes)
|
||||||
|
s.edKeyID = "ed_" + hex.EncodeToString(sum[:8])
|
||||||
|
s.preferEdDSA = true
|
||||||
|
}
|
||||||
|
|
||||||
// CreateNonce generates a new nonce and stores it in the database
|
// CreateNonce generates a new nonce and stores it in the database
|
||||||
func (s *Service) CreateNonce(ctx context.Context, wallet, purpose, namespace string) (string, error) {
|
func (s *Service) CreateNonce(ctx context.Context, wallet, purpose, namespace string) (string, error) {
|
||||||
// Generate a URL-safe random nonce (32 bytes)
|
// Generate a URL-safe random nonce (32 bytes)
|
||||||
|
|||||||
@ -2,11 +2,18 @@ package auth
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/ed25519"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
|
"crypto/sha256"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
"encoding/base64"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -164,3 +171,248 @@ func TestVerifySolSignature(t *testing.T) {
|
|||||||
t.Error("VerifySignature should have failed for invalid base64 signature")
|
t.Error("VerifySignature should have failed for invalid base64 signature")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// createDualKeyService creates a service with both RSA and EdDSA keys configured
|
||||||
|
func createDualKeyService(t *testing.T) *Service {
|
||||||
|
t.Helper()
|
||||||
|
s := createTestService(t) // has RSA
|
||||||
|
_, edPriv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to generate ed25519 key: %v", err)
|
||||||
|
}
|
||||||
|
s.SetEdDSAKey(edPriv)
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEdDSAJWTFlow(t *testing.T) {
|
||||||
|
s := createDualKeyService(t)
|
||||||
|
|
||||||
|
ns := "test-ns"
|
||||||
|
sub := "0xabcdef1234567890abcdef1234567890abcdef12"
|
||||||
|
ttl := 15 * time.Minute
|
||||||
|
|
||||||
|
// With EdDSA preferred, GenerateJWT should produce an EdDSA token
|
||||||
|
token, exp, err := s.GenerateJWT(ns, sub, ttl)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateJWT (EdDSA) failed: %v", err)
|
||||||
|
}
|
||||||
|
if token == "" {
|
||||||
|
t.Fatal("generated EdDSA token is empty")
|
||||||
|
}
|
||||||
|
if exp <= time.Now().Unix() {
|
||||||
|
t.Errorf("expiration time %d is in the past", exp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the header contains EdDSA
|
||||||
|
parts := strings.Split(token, ".")
|
||||||
|
hb, _ := base64.RawURLEncoding.DecodeString(parts[0])
|
||||||
|
var header map[string]string
|
||||||
|
json.Unmarshal(hb, &header)
|
||||||
|
if header["alg"] != "EdDSA" {
|
||||||
|
t.Errorf("expected alg EdDSA, got %s", header["alg"])
|
||||||
|
}
|
||||||
|
if header["kid"] != s.edKeyID {
|
||||||
|
t.Errorf("expected kid %s, got %s", s.edKeyID, header["kid"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the token
|
||||||
|
claims, err := s.ParseAndVerifyJWT(token)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ParseAndVerifyJWT (EdDSA) failed: %v", err)
|
||||||
|
}
|
||||||
|
if claims.Sub != sub {
|
||||||
|
t.Errorf("expected subject %s, got %s", sub, claims.Sub)
|
||||||
|
}
|
||||||
|
if claims.Namespace != ns {
|
||||||
|
t.Errorf("expected namespace %s, got %s", ns, claims.Namespace)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRS256BackwardCompat(t *testing.T) {
|
||||||
|
s := createDualKeyService(t)
|
||||||
|
|
||||||
|
// Generate an RS256 token directly (simulating a legacy token)
|
||||||
|
s.preferEdDSA = false
|
||||||
|
token, _, err := s.GenerateJWT("test-ns", "user1", 15*time.Minute)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateJWT (RS256) failed: %v", err)
|
||||||
|
}
|
||||||
|
s.preferEdDSA = true // re-enable EdDSA preference
|
||||||
|
|
||||||
|
// Verify the RS256 token still works with dual-key service
|
||||||
|
claims, err := s.ParseAndVerifyJWT(token)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ParseAndVerifyJWT should accept RS256 token: %v", err)
|
||||||
|
}
|
||||||
|
if claims.Sub != "user1" {
|
||||||
|
t.Errorf("expected subject user1, got %s", claims.Sub)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAlgorithmConfusion_Rejected(t *testing.T) {
|
||||||
|
s := createDualKeyService(t)
|
||||||
|
|
||||||
|
t.Run("none_algorithm", func(t *testing.T) {
|
||||||
|
// Craft a token with alg=none
|
||||||
|
header := map[string]string{"alg": "none", "typ": "JWT"}
|
||||||
|
hb, _ := json.Marshal(header)
|
||||||
|
payload := map[string]any{
|
||||||
|
"iss": "debros-gateway", "sub": "attacker", "aud": "gateway",
|
||||||
|
"iat": time.Now().Unix(), "nbf": time.Now().Unix(),
|
||||||
|
"exp": time.Now().Add(time.Hour).Unix(), "namespace": "test-ns",
|
||||||
|
}
|
||||||
|
pb, _ := json.Marshal(payload)
|
||||||
|
token := base64.RawURLEncoding.EncodeToString(hb) + "." +
|
||||||
|
base64.RawURLEncoding.EncodeToString(pb) + "."
|
||||||
|
|
||||||
|
_, err := s.ParseAndVerifyJWT(token)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("should reject alg=none")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("HS256_algorithm", func(t *testing.T) {
|
||||||
|
header := map[string]string{"alg": "HS256", "typ": "JWT", "kid": s.keyID}
|
||||||
|
hb, _ := json.Marshal(header)
|
||||||
|
payload := map[string]any{
|
||||||
|
"iss": "debros-gateway", "sub": "attacker", "aud": "gateway",
|
||||||
|
"iat": time.Now().Unix(), "nbf": time.Now().Unix(),
|
||||||
|
"exp": time.Now().Add(time.Hour).Unix(), "namespace": "test-ns",
|
||||||
|
}
|
||||||
|
pb, _ := json.Marshal(payload)
|
||||||
|
token := base64.RawURLEncoding.EncodeToString(hb) + "." +
|
||||||
|
base64.RawURLEncoding.EncodeToString(pb) + "." +
|
||||||
|
base64.RawURLEncoding.EncodeToString([]byte("fake-sig"))
|
||||||
|
|
||||||
|
_, err := s.ParseAndVerifyJWT(token)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("should reject alg=HS256")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("kid_alg_mismatch_EdDSA_kid_RS256_alg", func(t *testing.T) {
|
||||||
|
// Use EdDSA kid but claim RS256 alg
|
||||||
|
header := map[string]string{"alg": "RS256", "typ": "JWT", "kid": s.edKeyID}
|
||||||
|
hb, _ := json.Marshal(header)
|
||||||
|
payload := map[string]any{
|
||||||
|
"iss": "debros-gateway", "sub": "attacker", "aud": "gateway",
|
||||||
|
"iat": time.Now().Unix(), "nbf": time.Now().Unix(),
|
||||||
|
"exp": time.Now().Add(time.Hour).Unix(), "namespace": "test-ns",
|
||||||
|
}
|
||||||
|
pb, _ := json.Marshal(payload)
|
||||||
|
// Sign with RSA (trying to confuse the verifier into using RSA on EdDSA kid)
|
||||||
|
hb64 := base64.RawURLEncoding.EncodeToString(hb)
|
||||||
|
pb64 := base64.RawURLEncoding.EncodeToString(pb)
|
||||||
|
signingInput := hb64 + "." + pb64
|
||||||
|
sum := sha256.Sum256([]byte(signingInput))
|
||||||
|
rsaSig, _ := rsa.SignPKCS1v15(rand.Reader, s.signingKey, 4, sum[:]) // crypto.SHA256 = 4
|
||||||
|
token := signingInput + "." + base64.RawURLEncoding.EncodeToString(rsaSig)
|
||||||
|
|
||||||
|
_, err := s.ParseAndVerifyJWT(token)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("should reject kid/alg mismatch (EdDSA kid with RS256 alg)")
|
||||||
|
}
|
||||||
|
if err != nil && !strings.Contains(err.Error(), "algorithm mismatch") {
|
||||||
|
t.Errorf("expected 'algorithm mismatch' error, got: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unknown_kid", func(t *testing.T) {
|
||||||
|
header := map[string]string{"alg": "RS256", "typ": "JWT", "kid": "unknown-kid-123"}
|
||||||
|
hb, _ := json.Marshal(header)
|
||||||
|
payload := map[string]any{
|
||||||
|
"iss": "debros-gateway", "sub": "attacker", "aud": "gateway",
|
||||||
|
"iat": time.Now().Unix(), "nbf": time.Now().Unix(),
|
||||||
|
"exp": time.Now().Add(time.Hour).Unix(), "namespace": "test-ns",
|
||||||
|
}
|
||||||
|
pb, _ := json.Marshal(payload)
|
||||||
|
token := base64.RawURLEncoding.EncodeToString(hb) + "." +
|
||||||
|
base64.RawURLEncoding.EncodeToString(pb) + "." +
|
||||||
|
base64.RawURLEncoding.EncodeToString([]byte("fake-sig"))
|
||||||
|
|
||||||
|
_, err := s.ParseAndVerifyJWT(token)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("should reject unknown kid")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("legacy_token_EdDSA_rejected", func(t *testing.T) {
|
||||||
|
// Token with no kid and alg=EdDSA — should be rejected (legacy must be RS256)
|
||||||
|
header := map[string]string{"alg": "EdDSA", "typ": "JWT"}
|
||||||
|
hb, _ := json.Marshal(header)
|
||||||
|
payload := map[string]any{
|
||||||
|
"iss": "debros-gateway", "sub": "attacker", "aud": "gateway",
|
||||||
|
"iat": time.Now().Unix(), "nbf": time.Now().Unix(),
|
||||||
|
"exp": time.Now().Add(time.Hour).Unix(), "namespace": "test-ns",
|
||||||
|
}
|
||||||
|
pb, _ := json.Marshal(payload)
|
||||||
|
hb64 := base64.RawURLEncoding.EncodeToString(hb)
|
||||||
|
pb64 := base64.RawURLEncoding.EncodeToString(pb)
|
||||||
|
signingInput := hb64 + "." + pb64
|
||||||
|
sig := ed25519.Sign(s.edSigningKey, []byte(signingInput))
|
||||||
|
token := signingInput + "." + base64.RawURLEncoding.EncodeToString(sig)
|
||||||
|
|
||||||
|
_, err := s.ParseAndVerifyJWT(token)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("should reject legacy token (no kid) with EdDSA alg")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJWKSHandler_DualKey(t *testing.T) {
|
||||||
|
s := createDualKeyService(t)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/.well-known/jwks.json", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.JWKSHandler(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result struct {
|
||||||
|
Keys []map[string]string `json:"keys"`
|
||||||
|
}
|
||||||
|
if err := json.NewDecoder(w.Body).Decode(&result); err != nil {
|
||||||
|
t.Fatalf("failed to decode JWKS response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result.Keys) != 2 {
|
||||||
|
t.Fatalf("expected 2 keys in JWKS, got %d", len(result.Keys))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify we have both RSA and OKP keys
|
||||||
|
algSet := map[string]bool{}
|
||||||
|
for _, k := range result.Keys {
|
||||||
|
algSet[k["alg"]] = true
|
||||||
|
if k["kid"] == "" {
|
||||||
|
t.Error("key missing kid")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !algSet["RS256"] {
|
||||||
|
t.Error("JWKS missing RS256 key")
|
||||||
|
}
|
||||||
|
if !algSet["EdDSA"] {
|
||||||
|
t.Error("JWKS missing EdDSA key")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJWKSHandler_RSAOnly(t *testing.T) {
|
||||||
|
s := createTestService(t) // RSA only, no EdDSA
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/.well-known/jwks.json", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
s.JWKSHandler(w, req)
|
||||||
|
|
||||||
|
var result struct {
|
||||||
|
Keys []map[string]string `json:"keys"`
|
||||||
|
}
|
||||||
|
json.NewDecoder(w.Body).Decode(&result)
|
||||||
|
|
||||||
|
if len(result.Keys) != 1 {
|
||||||
|
t.Fatalf("expected 1 key in JWKS (RSA only), got %d", len(result.Keys))
|
||||||
|
}
|
||||||
|
if result.Keys[0]["alg"] != "RS256" {
|
||||||
|
t.Errorf("expected RS256, got %s", result.Keys[0]["alg"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -429,7 +429,7 @@ func initializeServerless(logger *logging.ColoredLogger, cfg *Config, deps *Depe
|
|||||||
logger.Logger,
|
logger.Logger,
|
||||||
)
|
)
|
||||||
|
|
||||||
// Initialize auth service with persistent signing key
|
// Initialize auth service with persistent signing keys (RSA + EdDSA)
|
||||||
keyPEM, err := loadOrCreateSigningKey(cfg.DataDir, logger)
|
keyPEM, err := loadOrCreateSigningKey(cfg.DataDir, logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to load or create JWT signing key: %w", err)
|
return fmt.Errorf("failed to load or create JWT signing key: %w", err)
|
||||||
@ -438,6 +438,17 @@ func initializeServerless(logger *logging.ColoredLogger, cfg *Config, deps *Depe
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to initialize auth service: %w", err)
|
return fmt.Errorf("failed to initialize auth service: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Load or create EdDSA key for new JWT tokens
|
||||||
|
edKey, err := loadOrCreateEdSigningKey(cfg.DataDir, logger)
|
||||||
|
if err != nil {
|
||||||
|
logger.ComponentWarn(logging.ComponentGeneral, "Failed to load EdDSA signing key; new JWTs will use RS256",
|
||||||
|
zap.Error(err))
|
||||||
|
} else {
|
||||||
|
authService.SetEdDSAKey(edKey)
|
||||||
|
logger.ComponentInfo(logging.ComponentGeneral, "EdDSA signing key loaded; new JWTs will use EdDSA")
|
||||||
|
}
|
||||||
|
|
||||||
deps.AuthService = authService
|
deps.AuthService = authService
|
||||||
|
|
||||||
logger.ComponentInfo(logging.ComponentGeneral, "Serverless function engine ready",
|
logger.ComponentInfo(logging.ComponentGeneral, "Serverless function engine ready",
|
||||||
|
|||||||
@ -42,7 +42,7 @@ func (g *Gateway) startNamespaceHealthLoop(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
probeTicker := time.NewTicker(30 * time.Second)
|
probeTicker := time.NewTicker(30 * time.Second)
|
||||||
reconcileTicker := time.NewTicker(1 * time.Hour)
|
reconcileTicker := time.NewTicker(5 * time.Minute)
|
||||||
defer probeTicker.Stop()
|
defer probeTicker.Stop()
|
||||||
defer reconcileTicker.Stop()
|
defer reconcileTicker.Stop()
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
package gateway
|
package gateway
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/ed25519"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
@ -14,6 +15,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const jwtKeyFileName = "jwt-signing-key.pem"
|
const jwtKeyFileName = "jwt-signing-key.pem"
|
||||||
|
const eddsaKeyFileName = "jwt-eddsa-key.pem"
|
||||||
|
|
||||||
// loadOrCreateSigningKey loads the JWT signing key from disk, or generates a new one
|
// loadOrCreateSigningKey loads the JWT signing key from disk, or generates a new one
|
||||||
// if none exists. This ensures JWTs survive gateway restarts.
|
// if none exists. This ensures JWTs survive gateway restarts.
|
||||||
@ -61,3 +63,56 @@ func loadOrCreateSigningKey(dataDir string, logger *logging.ColoredLogger) ([]by
|
|||||||
zap.String("path", keyPath))
|
zap.String("path", keyPath))
|
||||||
return keyPEM, nil
|
return keyPEM, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// loadOrCreateEdSigningKey loads or generates an Ed25519 private key for EdDSA JWT signing.
|
||||||
|
// The key is stored as a PKCS8-encoded PEM file alongside the RSA key.
|
||||||
|
func loadOrCreateEdSigningKey(dataDir string, logger *logging.ColoredLogger) (ed25519.PrivateKey, error) {
|
||||||
|
keyPath := filepath.Join(dataDir, "secrets", eddsaKeyFileName)
|
||||||
|
|
||||||
|
// Try to load existing key
|
||||||
|
if keyPEM, err := os.ReadFile(keyPath); err == nil && len(keyPEM) > 0 {
|
||||||
|
block, _ := pem.Decode(keyPEM)
|
||||||
|
if block != nil {
|
||||||
|
parsed, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||||
|
if err == nil {
|
||||||
|
if edKey, ok := parsed.(ed25519.PrivateKey); ok {
|
||||||
|
logger.ComponentInfo(logging.ComponentGeneral, "Loaded existing EdDSA signing key",
|
||||||
|
zap.String("path", keyPath))
|
||||||
|
return edKey, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
logger.ComponentWarn(logging.ComponentGeneral, "Existing EdDSA signing key is invalid, generating new one",
|
||||||
|
zap.String("path", keyPath))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate new Ed25519 key
|
||||||
|
_, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("generate Ed25519 key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pkcs8Bytes, err := x509.MarshalPKCS8PrivateKey(priv)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("marshal Ed25519 key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
keyPEM := pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "PRIVATE KEY",
|
||||||
|
Bytes: pkcs8Bytes,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Ensure secrets directory exists
|
||||||
|
secretsDir := filepath.Dir(keyPath)
|
||||||
|
if err := os.MkdirAll(secretsDir, 0700); err != nil {
|
||||||
|
return nil, fmt.Errorf("create secrets directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.WriteFile(keyPath, keyPEM, 0600); err != nil {
|
||||||
|
return nil, fmt.Errorf("write EdDSA signing key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.ComponentInfo(logging.ComponentGeneral, "Generated and saved new EdDSA signing key",
|
||||||
|
zap.String("path", keyPath))
|
||||||
|
return priv, nil
|
||||||
|
}
|
||||||
|
|||||||
@ -130,11 +130,14 @@ func (cm *ClusterManager) HandleRecoveredNode(ctx context.Context, nodeID string
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(results) > 0 && results[0].Count > 0 {
|
if len(results) > 0 && results[0].Count > 0 {
|
||||||
// Node still has legitimate assignments — just mark active and return
|
// Node still has legitimate assignments — mark active and repair degraded clusters
|
||||||
cm.logger.Info("Recovered node still has cluster assignments, marking active",
|
cm.logger.Info("Recovered node still has cluster assignments, marking active",
|
||||||
zap.String("node_id", nodeID),
|
zap.String("node_id", nodeID),
|
||||||
zap.Int("assignments", results[0].Count))
|
zap.Int("assignments", results[0].Count))
|
||||||
cm.markNodeActive(ctx, nodeID)
|
cm.markNodeActive(ctx, nodeID)
|
||||||
|
|
||||||
|
// Trigger repair for any degraded clusters this node belongs to
|
||||||
|
cm.repairDegradedClusters(ctx, nodeID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -522,6 +525,37 @@ func (cm *ClusterManager) markNodeActive(ctx context.Context, nodeID string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// repairDegradedClusters finds degraded clusters that the recovered node
|
||||||
|
// belongs to and triggers RepairCluster for each one.
|
||||||
|
func (cm *ClusterManager) repairDegradedClusters(ctx context.Context, nodeID string) {
|
||||||
|
type clusterRef struct {
|
||||||
|
NamespaceName string `db:"namespace_name"`
|
||||||
|
}
|
||||||
|
var refs []clusterRef
|
||||||
|
query := `
|
||||||
|
SELECT DISTINCT c.namespace_name
|
||||||
|
FROM namespace_cluster_nodes cn
|
||||||
|
JOIN namespace_clusters c ON cn.namespace_cluster_id = c.id
|
||||||
|
WHERE cn.node_id = ? AND c.status = 'degraded'
|
||||||
|
`
|
||||||
|
if err := cm.db.Query(ctx, &refs, query, nodeID); err != nil {
|
||||||
|
cm.logger.Warn("Failed to query degraded clusters for recovered node",
|
||||||
|
zap.String("node_id", nodeID), zap.Error(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ref := range refs {
|
||||||
|
cm.logger.Info("Triggering repair for degraded cluster after node recovery",
|
||||||
|
zap.String("namespace", ref.NamespaceName),
|
||||||
|
zap.String("recovered_node", nodeID))
|
||||||
|
if err := cm.RepairCluster(ctx, ref.NamespaceName); err != nil {
|
||||||
|
cm.logger.Warn("Failed to repair degraded cluster",
|
||||||
|
zap.String("namespace", ref.NamespaceName),
|
||||||
|
zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// removeDeadNodeFromRaft sends a DELETE request to a surviving RQLite node
|
// removeDeadNodeFromRaft sends a DELETE request to a surviving RQLite node
|
||||||
// to remove the dead node from the Raft voter set.
|
// to remove the dead node from the Raft voter set.
|
||||||
func (cm *ClusterManager) removeDeadNodeFromRaft(ctx context.Context, deadRaftAddr string, survivingNodes []survivingNodePorts) {
|
func (cm *ClusterManager) removeDeadNodeFromRaft(ctx context.Context, deadRaftAddr string, survivingNodes []survivingNodePorts) {
|
||||||
|
|||||||
@ -25,8 +25,19 @@ const (
|
|||||||
DefaultDeadAfter = 12 // consecutive misses → dead
|
DefaultDeadAfter = 12 // consecutive misses → dead
|
||||||
DefaultQuorumWindow = 5 * time.Minute
|
DefaultQuorumWindow = 5 * time.Minute
|
||||||
DefaultMinQuorum = 2 // out of K observers must agree
|
DefaultMinQuorum = 2 // out of K observers must agree
|
||||||
|
|
||||||
|
// DefaultStartupGracePeriod prevents false dead declarations after
|
||||||
|
// cluster-wide restart. During this period, no nodes are declared dead.
|
||||||
|
DefaultStartupGracePeriod = 5 * time.Minute
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// MetadataReader provides lifecycle metadata for peers. Implemented by
|
||||||
|
// ClusterDiscoveryService. The health monitor uses this to check maintenance
|
||||||
|
// status and LastSeen before falling through to HTTP probes.
|
||||||
|
type MetadataReader interface {
|
||||||
|
GetPeerLifecycleState(nodeID string) (state string, lastSeen time.Time, found bool)
|
||||||
|
}
|
||||||
|
|
||||||
// Config holds the configuration for a Monitor.
|
// Config holds the configuration for a Monitor.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
NodeID string // this node's ID (dns_nodes.id / peer ID)
|
NodeID string // this node's ID (dns_nodes.id / peer ID)
|
||||||
@ -35,6 +46,16 @@ type Config struct {
|
|||||||
ProbeInterval time.Duration // how often to probe (default 10s)
|
ProbeInterval time.Duration // how often to probe (default 10s)
|
||||||
ProbeTimeout time.Duration // per-probe HTTP timeout (default 3s)
|
ProbeTimeout time.Duration // per-probe HTTP timeout (default 3s)
|
||||||
Neighbors int // K — how many ring neighbors to monitor (default 3)
|
Neighbors int // K — how many ring neighbors to monitor (default 3)
|
||||||
|
|
||||||
|
// MetadataReader provides LibP2P lifecycle metadata for peers.
|
||||||
|
// When set, the monitor checks peer maintenance state and LastSeen
|
||||||
|
// before falling through to HTTP probes.
|
||||||
|
MetadataReader MetadataReader
|
||||||
|
|
||||||
|
// StartupGracePeriod prevents false dead declarations after cluster-wide
|
||||||
|
// restart. During this period, nodes can be marked suspect but never dead.
|
||||||
|
// Default: 5 minutes.
|
||||||
|
StartupGracePeriod time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
// nodeInfo is a row from dns_nodes used for probing.
|
// nodeInfo is a row from dns_nodes used for probing.
|
||||||
@ -56,6 +77,7 @@ type Monitor struct {
|
|||||||
cfg Config
|
cfg Config
|
||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
logger *zap.Logger
|
logger *zap.Logger
|
||||||
|
startTime time.Time // when the monitor was created
|
||||||
|
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
peers map[string]*peerState // nodeID → state
|
peers map[string]*peerState // nodeID → state
|
||||||
@ -75,6 +97,9 @@ func NewMonitor(cfg Config) *Monitor {
|
|||||||
if cfg.Neighbors == 0 {
|
if cfg.Neighbors == 0 {
|
||||||
cfg.Neighbors = DefaultNeighbors
|
cfg.Neighbors = DefaultNeighbors
|
||||||
}
|
}
|
||||||
|
if cfg.StartupGracePeriod == 0 {
|
||||||
|
cfg.StartupGracePeriod = DefaultStartupGracePeriod
|
||||||
|
}
|
||||||
if cfg.Logger == nil {
|
if cfg.Logger == nil {
|
||||||
cfg.Logger = zap.NewNop()
|
cfg.Logger = zap.NewNop()
|
||||||
}
|
}
|
||||||
@ -84,19 +109,20 @@ func NewMonitor(cfg Config) *Monitor {
|
|||||||
httpClient: &http.Client{
|
httpClient: &http.Client{
|
||||||
Timeout: cfg.ProbeTimeout,
|
Timeout: cfg.ProbeTimeout,
|
||||||
},
|
},
|
||||||
logger: cfg.Logger.With(zap.String("component", "health-monitor")),
|
logger: cfg.Logger.With(zap.String("component", "health-monitor")),
|
||||||
peers: make(map[string]*peerState),
|
startTime: time.Now(),
|
||||||
|
peers: make(map[string]*peerState),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnNodeDead registers a callback invoked when a node is confirmed dead by
|
// OnNodeDead registers a callback invoked when a node is confirmed dead by
|
||||||
// quorum. The callback runs synchronously in the monitor goroutine.
|
// quorum. The callback runs with the monitor lock released.
|
||||||
func (m *Monitor) OnNodeDead(fn func(nodeID string)) {
|
func (m *Monitor) OnNodeDead(fn func(nodeID string)) {
|
||||||
m.onDeadFn = fn
|
m.onDeadFn = fn
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnNodeRecovered registers a callback invoked when a previously dead node
|
// OnNodeRecovered registers a callback invoked when a previously dead node
|
||||||
// transitions back to healthy. Used to clean up orphaned services.
|
// transitions back to healthy. The callback runs with the monitor lock released.
|
||||||
func (m *Monitor) OnNodeRecovered(fn func(nodeID string)) {
|
func (m *Monitor) OnNodeRecovered(fn func(nodeID string)) {
|
||||||
m.onRecoveredFn = fn
|
m.onRecoveredFn = fn
|
||||||
}
|
}
|
||||||
@ -107,6 +133,7 @@ func (m *Monitor) Start(ctx context.Context) {
|
|||||||
zap.String("node_id", m.cfg.NodeID),
|
zap.String("node_id", m.cfg.NodeID),
|
||||||
zap.Duration("probe_interval", m.cfg.ProbeInterval),
|
zap.Duration("probe_interval", m.cfg.ProbeInterval),
|
||||||
zap.Int("neighbors", m.cfg.Neighbors),
|
zap.Int("neighbors", m.cfg.Neighbors),
|
||||||
|
zap.Duration("startup_grace", m.cfg.StartupGracePeriod),
|
||||||
)
|
)
|
||||||
|
|
||||||
ticker := time.NewTicker(m.cfg.ProbeInterval)
|
ticker := time.NewTicker(m.cfg.ProbeInterval)
|
||||||
@ -123,6 +150,11 @@ func (m *Monitor) Start(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isInStartupGrace returns true if the startup grace period is still active.
|
||||||
|
func (m *Monitor) isInStartupGrace() bool {
|
||||||
|
return time.Since(m.startTime) < m.cfg.StartupGracePeriod
|
||||||
|
}
|
||||||
|
|
||||||
// probeRound runs a single round of probing our ring neighbors.
|
// probeRound runs a single round of probing our ring neighbors.
|
||||||
func (m *Monitor) probeRound(ctx context.Context) {
|
func (m *Monitor) probeRound(ctx context.Context) {
|
||||||
neighbors, err := m.getRingNeighbors(ctx)
|
neighbors, err := m.getRingNeighbors(ctx)
|
||||||
@ -140,7 +172,7 @@ func (m *Monitor) probeRound(ctx context.Context) {
|
|||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(node nodeInfo) {
|
go func(node nodeInfo) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
ok := m.probe(ctx, node)
|
ok := m.probeNode(ctx, node)
|
||||||
m.updateState(ctx, node.ID, ok)
|
m.updateState(ctx, node.ID, ok)
|
||||||
}(n)
|
}(n)
|
||||||
}
|
}
|
||||||
@ -150,6 +182,28 @@ func (m *Monitor) probeRound(ctx context.Context) {
|
|||||||
m.pruneStaleState(neighbors)
|
m.pruneStaleState(neighbors)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// probeNode checks a node's health. It first checks LibP2P metadata (if
|
||||||
|
// available) to avoid unnecessary HTTP probes, then falls through to HTTP.
|
||||||
|
func (m *Monitor) probeNode(ctx context.Context, node nodeInfo) bool {
|
||||||
|
if m.cfg.MetadataReader != nil {
|
||||||
|
state, lastSeen, found := m.cfg.MetadataReader.GetPeerLifecycleState(node.ID)
|
||||||
|
if found {
|
||||||
|
// Maintenance node with recent LastSeen → count as healthy
|
||||||
|
if state == "maintenance" && time.Since(lastSeen) < 2*time.Minute {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recently seen active node → count as healthy (no HTTP needed)
|
||||||
|
if state == "active" && time.Since(lastSeen) < 30*time.Second {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall through to HTTP probe
|
||||||
|
return m.probe(ctx, node)
|
||||||
|
}
|
||||||
|
|
||||||
// probe sends an HTTP ping to a single node. Returns true if healthy.
|
// probe sends an HTTP ping to a single node. Returns true if healthy.
|
||||||
func (m *Monitor) probe(ctx context.Context, node nodeInfo) bool {
|
func (m *Monitor) probe(ctx context.Context, node nodeInfo) bool {
|
||||||
url := fmt.Sprintf("http://%s:6001/v1/internal/ping", node.InternalIP)
|
url := fmt.Sprintf("http://%s:6001/v1/internal/ping", node.InternalIP)
|
||||||
@ -167,9 +221,9 @@ func (m *Monitor) probe(ctx context.Context, node nodeInfo) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// updateState updates the in-memory state for a peer after a probe.
|
// updateState updates the in-memory state for a peer after a probe.
|
||||||
|
// Callbacks are invoked with the lock released to prevent deadlocks (C2 fix).
|
||||||
func (m *Monitor) updateState(ctx context.Context, nodeID string, healthy bool) {
|
func (m *Monitor) updateState(ctx context.Context, nodeID string, healthy bool) {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
|
||||||
|
|
||||||
ps, exists := m.peers[nodeID]
|
ps, exists := m.peers[nodeID]
|
||||||
if !exists {
|
if !exists {
|
||||||
@ -178,23 +232,26 @@ func (m *Monitor) updateState(ctx context.Context, nodeID string, healthy bool)
|
|||||||
}
|
}
|
||||||
|
|
||||||
if healthy {
|
if healthy {
|
||||||
// Recovered
|
wasDead := ps.status == "dead"
|
||||||
if ps.status != "healthy" {
|
shouldCallback := wasDead && m.onRecoveredFn != nil
|
||||||
wasDead := ps.status == "dead"
|
prevStatus := ps.status
|
||||||
m.logger.Info("Node recovered", zap.String("target", nodeID),
|
|
||||||
zap.String("previous_status", ps.status))
|
|
||||||
m.writeEvent(ctx, nodeID, "recovered")
|
|
||||||
|
|
||||||
// Fire recovery callback for nodes that were confirmed dead
|
// Update state BEFORE releasing lock (C2 fix)
|
||||||
if wasDead && m.onRecoveredFn != nil {
|
|
||||||
m.mu.Unlock()
|
|
||||||
m.onRecoveredFn(nodeID)
|
|
||||||
m.mu.Lock()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ps.missCount = 0
|
ps.missCount = 0
|
||||||
ps.status = "healthy"
|
ps.status = "healthy"
|
||||||
ps.reportedDead = false
|
ps.reportedDead = false
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
if prevStatus != "healthy" {
|
||||||
|
m.logger.Info("Node recovered", zap.String("target", nodeID),
|
||||||
|
zap.String("previous_status", prevStatus))
|
||||||
|
m.writeEvent(ctx, nodeID, "recovered")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fire recovery callback without holding the lock (C2 fix)
|
||||||
|
if shouldCallback {
|
||||||
|
m.onRecoveredFn(nodeID)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -203,6 +260,23 @@ func (m *Monitor) updateState(ctx context.Context, nodeID string, healthy bool)
|
|||||||
|
|
||||||
switch {
|
switch {
|
||||||
case ps.missCount >= DefaultDeadAfter && !ps.reportedDead:
|
case ps.missCount >= DefaultDeadAfter && !ps.reportedDead:
|
||||||
|
// During startup grace period, don't declare dead — only suspect
|
||||||
|
if m.isInStartupGrace() {
|
||||||
|
if ps.status != "suspect" {
|
||||||
|
ps.status = "suspect"
|
||||||
|
ps.suspectAt = time.Now()
|
||||||
|
m.mu.Unlock()
|
||||||
|
m.logger.Warn("Node SUSPECT (startup grace — deferring dead)",
|
||||||
|
zap.String("target", nodeID),
|
||||||
|
zap.Int("misses", ps.missCount),
|
||||||
|
)
|
||||||
|
m.writeEvent(ctx, nodeID, "suspect")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m.mu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if ps.status != "dead" {
|
if ps.status != "dead" {
|
||||||
m.logger.Error("Node declared DEAD",
|
m.logger.Error("Node declared DEAD",
|
||||||
zap.String("target", nodeID),
|
zap.String("target", nodeID),
|
||||||
@ -211,22 +285,34 @@ func (m *Monitor) updateState(ctx context.Context, nodeID string, healthy bool)
|
|||||||
}
|
}
|
||||||
ps.status = "dead"
|
ps.status = "dead"
|
||||||
ps.reportedDead = true
|
ps.reportedDead = true
|
||||||
|
|
||||||
|
// Copy what we need before releasing lock
|
||||||
|
shouldCheckQuorum := m.cfg.DB != nil && m.onDeadFn != nil
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
m.writeEvent(ctx, nodeID, "dead")
|
m.writeEvent(ctx, nodeID, "dead")
|
||||||
m.checkQuorum(ctx, nodeID)
|
if shouldCheckQuorum {
|
||||||
|
m.checkQuorum(ctx, nodeID)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
|
||||||
case ps.missCount >= DefaultSuspectAfter && ps.status == "healthy":
|
case ps.missCount >= DefaultSuspectAfter && ps.status == "healthy":
|
||||||
ps.status = "suspect"
|
ps.status = "suspect"
|
||||||
ps.suspectAt = time.Now()
|
ps.suspectAt = time.Now()
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
m.logger.Warn("Node SUSPECT",
|
m.logger.Warn("Node SUSPECT",
|
||||||
zap.String("target", nodeID),
|
zap.String("target", nodeID),
|
||||||
zap.Int("misses", ps.missCount),
|
zap.Int("misses", ps.missCount),
|
||||||
)
|
)
|
||||||
m.writeEvent(ctx, nodeID, "suspect")
|
m.writeEvent(ctx, nodeID, "suspect")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeEvent inserts a health event into node_health_events. Must be called
|
// writeEvent inserts a health event into node_health_events.
|
||||||
// with m.mu held.
|
|
||||||
func (m *Monitor) writeEvent(ctx context.Context, targetID, status string) {
|
func (m *Monitor) writeEvent(ctx context.Context, targetID, status string) {
|
||||||
if m.cfg.DB == nil {
|
if m.cfg.DB == nil {
|
||||||
return
|
return
|
||||||
@ -242,7 +328,8 @@ func (m *Monitor) writeEvent(ctx context.Context, targetID, status string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// checkQuorum queries the events table to see if enough observers agree the
|
// checkQuorum queries the events table to see if enough observers agree the
|
||||||
// target is dead, then fires the onDead callback. Must be called with m.mu held.
|
// target is dead, then fires the onDead callback. Called WITHOUT the lock held
|
||||||
|
// (C2 fix — previously called with lock held, causing deadlocks in callbacks).
|
||||||
func (m *Monitor) checkQuorum(ctx context.Context, targetID string) {
|
func (m *Monitor) checkQuorum(ctx context.Context, targetID string) {
|
||||||
if m.cfg.DB == nil || m.onDeadFn == nil {
|
if m.cfg.DB == nil || m.onDeadFn == nil {
|
||||||
return
|
return
|
||||||
@ -287,10 +374,7 @@ func (m *Monitor) checkQuorum(ctx context.Context, targetID string) {
|
|||||||
zap.String("target", targetID),
|
zap.String("target", targetID),
|
||||||
zap.Int("observers", count),
|
zap.Int("observers", count),
|
||||||
)
|
)
|
||||||
// Release the lock before calling the callback to avoid deadlocks.
|
|
||||||
m.mu.Unlock()
|
|
||||||
m.onDeadFn(targetID)
|
m.onDeadFn(targetID)
|
||||||
m.mu.Lock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// getRingNeighbors queries dns_nodes for active nodes, sorts them, and
|
// getRingNeighbors queries dns_nodes for active nodes, sorts them, and
|
||||||
|
|||||||
@ -158,10 +158,12 @@ func TestRingNeighbors_KLargerThanRing(t *testing.T) {
|
|||||||
|
|
||||||
func TestStateTransitions(t *testing.T) {
|
func TestStateTransitions(t *testing.T) {
|
||||||
m := NewMonitor(Config{
|
m := NewMonitor(Config{
|
||||||
NodeID: "self",
|
NodeID: "self",
|
||||||
ProbeInterval: time.Second,
|
ProbeInterval: time.Second,
|
||||||
Neighbors: 3,
|
Neighbors: 3,
|
||||||
|
StartupGracePeriod: 1 * time.Millisecond, // disable grace for this test
|
||||||
})
|
})
|
||||||
|
time.Sleep(2 * time.Millisecond) // ensure grace period expired
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
@ -298,9 +300,11 @@ func TestOnNodeDead_Callback(t *testing.T) {
|
|||||||
var called atomic.Int32
|
var called atomic.Int32
|
||||||
|
|
||||||
m := NewMonitor(Config{
|
m := NewMonitor(Config{
|
||||||
NodeID: "self",
|
NodeID: "self",
|
||||||
Neighbors: 3,
|
Neighbors: 3,
|
||||||
|
StartupGracePeriod: 1 * time.Millisecond,
|
||||||
})
|
})
|
||||||
|
time.Sleep(2 * time.Millisecond)
|
||||||
m.OnNodeDead(func(nodeID string) {
|
m.OnNodeDead(func(nodeID string) {
|
||||||
called.Add(1)
|
called.Add(1)
|
||||||
})
|
})
|
||||||
@ -316,3 +320,204 @@ func TestOnNodeDead_Callback(t *testing.T) {
|
|||||||
t.Fatalf("expected dead, got %s", m.peers["victim"].status)
|
t.Fatalf("expected dead, got %s", m.peers["victim"].status)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------
|
||||||
|
// Startup grace period
|
||||||
|
// ---------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestStartupGrace_PreventsDead(t *testing.T) {
|
||||||
|
m := NewMonitor(Config{
|
||||||
|
NodeID: "self",
|
||||||
|
Neighbors: 3,
|
||||||
|
StartupGracePeriod: 1 * time.Hour, // very long grace
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Accumulate enough misses for dead (12)
|
||||||
|
for i := 0; i < DefaultDeadAfter+5; i++ {
|
||||||
|
m.updateState(ctx, "peer1", false)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
status := m.peers["peer1"].status
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
// During grace, should be suspect, NOT dead
|
||||||
|
if status != "suspect" {
|
||||||
|
t.Fatalf("expected suspect during startup grace, got %s", status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStartupGrace_AllowsDeadAfterExpiry(t *testing.T) {
|
||||||
|
m := NewMonitor(Config{
|
||||||
|
NodeID: "self",
|
||||||
|
Neighbors: 3,
|
||||||
|
StartupGracePeriod: 1 * time.Millisecond,
|
||||||
|
})
|
||||||
|
time.Sleep(2 * time.Millisecond) // grace expired
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
for i := 0; i < DefaultDeadAfter; i++ {
|
||||||
|
m.updateState(ctx, "peer1", false)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
status := m.peers["peer1"].status
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
if status != "dead" {
|
||||||
|
t.Fatalf("expected dead after grace expired, got %s", status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------
|
||||||
|
// MetadataReader integration
|
||||||
|
// ---------------------------------------------------------------
|
||||||
|
|
||||||
|
type mockMetadataReader struct {
|
||||||
|
state string
|
||||||
|
lastSeen time.Time
|
||||||
|
found bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockMetadataReader) GetPeerLifecycleState(nodeID string) (string, time.Time, bool) {
|
||||||
|
return m.state, m.lastSeen, m.found
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProbeNode_MaintenanceCountsHealthy(t *testing.T) {
|
||||||
|
m := NewMonitor(Config{
|
||||||
|
NodeID: "self",
|
||||||
|
MetadataReader: &mockMetadataReader{
|
||||||
|
state: "maintenance",
|
||||||
|
lastSeen: time.Now(),
|
||||||
|
found: true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
node := nodeInfo{ID: "peer1", InternalIP: "192.0.2.1"} // unreachable
|
||||||
|
ok := m.probeNode(context.Background(), node)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("maintenance node with recent LastSeen should count as healthy")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProbeNode_RecentActiveSkipsHTTP(t *testing.T) {
|
||||||
|
m := NewMonitor(Config{
|
||||||
|
NodeID: "self",
|
||||||
|
MetadataReader: &mockMetadataReader{
|
||||||
|
state: "active",
|
||||||
|
lastSeen: time.Now(),
|
||||||
|
found: true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
// Use unreachable IP — if HTTP were attempted, it would fail
|
||||||
|
node := nodeInfo{ID: "peer1", InternalIP: "192.0.2.1"}
|
||||||
|
ok := m.probeNode(context.Background(), node)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("recently seen active node should skip HTTP and count as healthy")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProbeNode_StaleMetadataFallsToHTTP(t *testing.T) {
|
||||||
|
m := NewMonitor(Config{
|
||||||
|
NodeID: "self",
|
||||||
|
ProbeTimeout: 100 * time.Millisecond,
|
||||||
|
MetadataReader: &mockMetadataReader{
|
||||||
|
state: "active",
|
||||||
|
lastSeen: time.Now().Add(-5 * time.Minute), // stale
|
||||||
|
found: true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
node := nodeInfo{ID: "peer1", InternalIP: "192.0.2.1"} // unreachable
|
||||||
|
ok := m.probeNode(context.Background(), node)
|
||||||
|
if ok {
|
||||||
|
t.Fatal("stale metadata should fall through to HTTP probe, which should fail")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProbeNode_UnknownPeerFallsToHTTP(t *testing.T) {
|
||||||
|
m := NewMonitor(Config{
|
||||||
|
NodeID: "self",
|
||||||
|
ProbeTimeout: 100 * time.Millisecond,
|
||||||
|
MetadataReader: &mockMetadataReader{
|
||||||
|
found: false, // peer not found in metadata
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
node := nodeInfo{ID: "unknown", InternalIP: "192.0.2.1"}
|
||||||
|
ok := m.probeNode(context.Background(), node)
|
||||||
|
if ok {
|
||||||
|
t.Fatal("unknown peer should fall through to HTTP probe, which should fail")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProbeNode_NoMetadataReader(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
m := NewMonitor(Config{
|
||||||
|
NodeID: "self",
|
||||||
|
ProbeTimeout: 2 * time.Second,
|
||||||
|
MetadataReader: nil, // no metadata reader
|
||||||
|
})
|
||||||
|
|
||||||
|
// Without MetadataReader, should go straight to HTTP
|
||||||
|
addr := strings.TrimPrefix(srv.URL, "http://")
|
||||||
|
node := nodeInfo{ID: "peer1", InternalIP: addr}
|
||||||
|
// Note: probe() hardcodes port 6001, so this won't hit our server.
|
||||||
|
// But we verify it doesn't panic and falls through correctly.
|
||||||
|
_ = m.probeNode(context.Background(), node)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------
|
||||||
|
// Recovery callback (C2 fix)
|
||||||
|
// ---------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestRecoveryCallback_InvokedWithoutLock(t *testing.T) {
|
||||||
|
m := NewMonitor(Config{
|
||||||
|
NodeID: "self",
|
||||||
|
Neighbors: 3,
|
||||||
|
StartupGracePeriod: 1 * time.Millisecond,
|
||||||
|
})
|
||||||
|
time.Sleep(2 * time.Millisecond)
|
||||||
|
|
||||||
|
var recoveredNode string
|
||||||
|
m.OnNodeRecovered(func(nodeID string) {
|
||||||
|
recoveredNode = nodeID
|
||||||
|
// If lock were held, this would deadlock since we try to access peers
|
||||||
|
// Just verify callback fires correctly
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Drive to dead state
|
||||||
|
for i := 0; i < DefaultDeadAfter; i++ {
|
||||||
|
m.updateState(ctx, "peer1", false)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
if m.peers["peer1"].status != "dead" {
|
||||||
|
m.mu.Unlock()
|
||||||
|
t.Fatal("expected dead state")
|
||||||
|
}
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
// Recover
|
||||||
|
m.updateState(ctx, "peer1", true)
|
||||||
|
|
||||||
|
if recoveredNode != "peer1" {
|
||||||
|
t.Fatalf("expected recovery callback for peer1, got %q", recoveredNode)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
if m.peers["peer1"].status != "healthy" {
|
||||||
|
m.mu.Unlock()
|
||||||
|
t.Fatal("expected healthy after recovery")
|
||||||
|
}
|
||||||
|
m.mu.Unlock()
|
||||||
|
}
|
||||||
|
|||||||
184
pkg/node/lifecycle/manager.go
Normal file
184
pkg/node/lifecycle/manager.go
Normal file
@ -0,0 +1,184 @@
|
|||||||
|
package lifecycle
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// State represents a node's lifecycle state.
|
||||||
|
type State string
|
||||||
|
|
||||||
|
const (
|
||||||
|
StateJoining State = "joining"
|
||||||
|
StateActive State = "active"
|
||||||
|
StateDraining State = "draining"
|
||||||
|
StateMaintenance State = "maintenance"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MaxMaintenanceTTL is the maximum duration a node can remain in maintenance
|
||||||
|
// mode. The leader's health monitor enforces this limit — nodes that exceed
|
||||||
|
// it are treated as unreachable so they can't hide in maintenance forever.
|
||||||
|
const MaxMaintenanceTTL = 15 * time.Minute
|
||||||
|
|
||||||
|
// validTransitions defines the allowed state machine transitions.
|
||||||
|
// Each entry maps from-state → set of valid to-states.
|
||||||
|
var validTransitions = map[State]map[State]bool{
|
||||||
|
StateJoining: {StateActive: true},
|
||||||
|
StateActive: {StateDraining: true, StateMaintenance: true},
|
||||||
|
StateDraining: {StateMaintenance: true},
|
||||||
|
StateMaintenance: {StateActive: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
// StateChangeCallback is called when the lifecycle state changes.
|
||||||
|
type StateChangeCallback func(old, new State)
|
||||||
|
|
||||||
|
// Manager manages a node's lifecycle state machine.
|
||||||
|
// It has no external dependencies (no LibP2P, no discovery imports)
|
||||||
|
// and is fully testable in isolation.
|
||||||
|
type Manager struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
state State
|
||||||
|
maintenanceTTL time.Time
|
||||||
|
enterTime time.Time // when the current state was entered
|
||||||
|
onStateChange []StateChangeCallback
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewManager creates a new lifecycle manager in the joining state.
|
||||||
|
func NewManager() *Manager {
|
||||||
|
return &Manager{
|
||||||
|
state: StateJoining,
|
||||||
|
enterTime: time.Now(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// State returns the current lifecycle state.
|
||||||
|
func (m *Manager) State() State {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
return m.state
|
||||||
|
}
|
||||||
|
|
||||||
|
// MaintenanceTTL returns the maintenance mode expiration time.
|
||||||
|
// Returns zero value if not in maintenance.
|
||||||
|
func (m *Manager) MaintenanceTTL() time.Time {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
return m.maintenanceTTL
|
||||||
|
}
|
||||||
|
|
||||||
|
// StateEnteredAt returns when the current state was entered.
|
||||||
|
func (m *Manager) StateEnteredAt() time.Time {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
return m.enterTime
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnStateChange registers a callback invoked on state transitions.
|
||||||
|
// Callbacks are called with the lock released to avoid deadlocks.
|
||||||
|
func (m *Manager) OnStateChange(cb StateChangeCallback) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.onStateChange = append(m.onStateChange, cb)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TransitionTo moves the node to a new lifecycle state.
|
||||||
|
// Returns an error if the transition is not valid.
|
||||||
|
func (m *Manager) TransitionTo(newState State) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
old := m.state
|
||||||
|
|
||||||
|
allowed, exists := validTransitions[old]
|
||||||
|
if !exists || !allowed[newState] {
|
||||||
|
m.mu.Unlock()
|
||||||
|
return fmt.Errorf("invalid lifecycle transition: %s → %s", old, newState)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.state = newState
|
||||||
|
m.enterTime = time.Now()
|
||||||
|
|
||||||
|
// Clear maintenance TTL when leaving maintenance
|
||||||
|
if newState != StateMaintenance {
|
||||||
|
m.maintenanceTTL = time.Time{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy callbacks before releasing lock
|
||||||
|
callbacks := make([]StateChangeCallback, len(m.onStateChange))
|
||||||
|
copy(callbacks, m.onStateChange)
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
// Invoke callbacks without holding the lock
|
||||||
|
for _, cb := range callbacks {
|
||||||
|
cb(old, newState)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnterMaintenance transitions to maintenance with a TTL.
|
||||||
|
// The TTL is capped at MaxMaintenanceTTL.
|
||||||
|
func (m *Manager) EnterMaintenance(ttl time.Duration) error {
|
||||||
|
if ttl <= 0 {
|
||||||
|
ttl = MaxMaintenanceTTL
|
||||||
|
}
|
||||||
|
if ttl > MaxMaintenanceTTL {
|
||||||
|
ttl = MaxMaintenanceTTL
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
old := m.state
|
||||||
|
|
||||||
|
// Allow both active→maintenance and draining→maintenance
|
||||||
|
allowed, exists := validTransitions[old]
|
||||||
|
if !exists || !allowed[StateMaintenance] {
|
||||||
|
m.mu.Unlock()
|
||||||
|
return fmt.Errorf("invalid lifecycle transition: %s → %s", old, StateMaintenance)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.state = StateMaintenance
|
||||||
|
m.maintenanceTTL = time.Now().Add(ttl)
|
||||||
|
m.enterTime = time.Now()
|
||||||
|
|
||||||
|
callbacks := make([]StateChangeCallback, len(m.onStateChange))
|
||||||
|
copy(callbacks, m.onStateChange)
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
for _, cb := range callbacks {
|
||||||
|
cb(old, StateMaintenance)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsMaintenanceExpired returns true if the node is in maintenance and the TTL
|
||||||
|
// has expired. Used by the leader's health monitor to enforce the max TTL.
|
||||||
|
func (m *Manager) IsMaintenanceExpired() bool {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
if m.state != StateMaintenance {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return !m.maintenanceTTL.IsZero() && time.Now().After(m.maintenanceTTL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsAvailable returns true if the node is in a state that can serve requests.
|
||||||
|
func (m *Manager) IsAvailable() bool {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
return m.state == StateActive
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsInMaintenance returns true if the node is in maintenance mode.
|
||||||
|
func (m *Manager) IsInMaintenance() bool {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
return m.state == StateMaintenance
|
||||||
|
}
|
||||||
|
|
||||||
|
// Snapshot returns a point-in-time copy of the lifecycle state for
|
||||||
|
// embedding in metadata without holding the lock.
|
||||||
|
func (m *Manager) Snapshot() (state State, ttl time.Time) {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
return m.state, m.maintenanceTTL
|
||||||
|
}
|
||||||
320
pkg/node/lifecycle/manager_test.go
Normal file
320
pkg/node/lifecycle/manager_test.go
Normal file
@ -0,0 +1,320 @@
|
|||||||
|
package lifecycle
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewManager(t *testing.T) {
|
||||||
|
m := NewManager()
|
||||||
|
if m.State() != StateJoining {
|
||||||
|
t.Fatalf("expected initial state %q, got %q", StateJoining, m.State())
|
||||||
|
}
|
||||||
|
if m.IsAvailable() {
|
||||||
|
t.Fatal("joining node should not be available")
|
||||||
|
}
|
||||||
|
if m.IsInMaintenance() {
|
||||||
|
t.Fatal("joining node should not be in maintenance")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidTransitions(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
from State
|
||||||
|
to State
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"joining→active", StateJoining, StateActive, false},
|
||||||
|
{"active→draining", StateActive, StateDraining, false},
|
||||||
|
{"draining→maintenance", StateDraining, StateMaintenance, false},
|
||||||
|
{"active→maintenance", StateActive, StateMaintenance, false},
|
||||||
|
{"maintenance→active", StateMaintenance, StateActive, false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
m := &Manager{state: tt.from, enterTime: time.Now()}
|
||||||
|
err := m.TransitionTo(tt.to)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Fatalf("TransitionTo(%q): err=%v, wantErr=%v", tt.to, err, tt.wantErr)
|
||||||
|
}
|
||||||
|
if err == nil && m.State() != tt.to {
|
||||||
|
t.Fatalf("expected state %q, got %q", tt.to, m.State())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInvalidTransitions(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
from State
|
||||||
|
to State
|
||||||
|
}{
|
||||||
|
{"joining→draining", StateJoining, StateDraining},
|
||||||
|
{"joining→maintenance", StateJoining, StateMaintenance},
|
||||||
|
{"joining→joining", StateJoining, StateJoining},
|
||||||
|
{"active→active", StateActive, StateActive},
|
||||||
|
{"active→joining", StateActive, StateJoining},
|
||||||
|
{"draining→active", StateDraining, StateActive},
|
||||||
|
{"draining→joining", StateDraining, StateJoining},
|
||||||
|
{"maintenance→draining", StateMaintenance, StateDraining},
|
||||||
|
{"maintenance→joining", StateMaintenance, StateJoining},
|
||||||
|
{"maintenance→maintenance", StateMaintenance, StateMaintenance},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
m := &Manager{state: tt.from, enterTime: time.Now()}
|
||||||
|
err := m.TransitionTo(tt.to)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected error for transition %s → %s", tt.from, tt.to)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnterMaintenance(t *testing.T) {
|
||||||
|
m := NewManager()
|
||||||
|
_ = m.TransitionTo(StateActive)
|
||||||
|
|
||||||
|
err := m.EnterMaintenance(5 * time.Minute)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("EnterMaintenance: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !m.IsInMaintenance() {
|
||||||
|
t.Fatal("expected maintenance state")
|
||||||
|
}
|
||||||
|
|
||||||
|
ttl := m.MaintenanceTTL()
|
||||||
|
if ttl.IsZero() {
|
||||||
|
t.Fatal("expected non-zero maintenance TTL")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TTL should be roughly 5 minutes from now
|
||||||
|
remaining := time.Until(ttl)
|
||||||
|
if remaining < 4*time.Minute || remaining > 6*time.Minute {
|
||||||
|
t.Fatalf("expected TTL ~5min from now, got %v", remaining)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnterMaintenanceTTLCapped(t *testing.T) {
|
||||||
|
m := NewManager()
|
||||||
|
_ = m.TransitionTo(StateActive)
|
||||||
|
|
||||||
|
// Request 1 hour, should be capped at MaxMaintenanceTTL
|
||||||
|
err := m.EnterMaintenance(1 * time.Hour)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("EnterMaintenance: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ttl := m.MaintenanceTTL()
|
||||||
|
remaining := time.Until(ttl)
|
||||||
|
if remaining > MaxMaintenanceTTL+time.Second {
|
||||||
|
t.Fatalf("TTL should be capped at %v, got %v remaining", MaxMaintenanceTTL, remaining)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnterMaintenanceZeroTTL(t *testing.T) {
|
||||||
|
m := NewManager()
|
||||||
|
_ = m.TransitionTo(StateActive)
|
||||||
|
|
||||||
|
// Zero TTL should default to MaxMaintenanceTTL
|
||||||
|
err := m.EnterMaintenance(0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("EnterMaintenance: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ttl := m.MaintenanceTTL()
|
||||||
|
remaining := time.Until(ttl)
|
||||||
|
if remaining < MaxMaintenanceTTL-time.Second {
|
||||||
|
t.Fatalf("zero TTL should default to MaxMaintenanceTTL, got %v remaining", remaining)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMaintenanceTTLClearedOnExit(t *testing.T) {
|
||||||
|
m := NewManager()
|
||||||
|
_ = m.TransitionTo(StateActive)
|
||||||
|
_ = m.EnterMaintenance(5 * time.Minute)
|
||||||
|
|
||||||
|
if m.MaintenanceTTL().IsZero() {
|
||||||
|
t.Fatal("expected non-zero TTL in maintenance")
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = m.TransitionTo(StateActive)
|
||||||
|
|
||||||
|
if !m.MaintenanceTTL().IsZero() {
|
||||||
|
t.Fatal("expected zero TTL after leaving maintenance")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsMaintenanceExpired(t *testing.T) {
|
||||||
|
m := &Manager{
|
||||||
|
state: StateMaintenance,
|
||||||
|
maintenanceTTL: time.Now().Add(-1 * time.Minute), // expired 1 minute ago
|
||||||
|
enterTime: time.Now().Add(-20 * time.Minute),
|
||||||
|
}
|
||||||
|
|
||||||
|
if !m.IsMaintenanceExpired() {
|
||||||
|
t.Fatal("expected maintenance to be expired")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not expired
|
||||||
|
m.maintenanceTTL = time.Now().Add(5 * time.Minute)
|
||||||
|
if m.IsMaintenanceExpired() {
|
||||||
|
t.Fatal("expected maintenance to not be expired")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not in maintenance
|
||||||
|
m.state = StateActive
|
||||||
|
if m.IsMaintenanceExpired() {
|
||||||
|
t.Fatal("expected non-maintenance state to not report expired")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStateChangeCallback(t *testing.T) {
|
||||||
|
m := NewManager()
|
||||||
|
|
||||||
|
var callbackOld, callbackNew State
|
||||||
|
called := false
|
||||||
|
m.OnStateChange(func(old, new State) {
|
||||||
|
callbackOld = old
|
||||||
|
callbackNew = new
|
||||||
|
called = true
|
||||||
|
})
|
||||||
|
|
||||||
|
_ = m.TransitionTo(StateActive)
|
||||||
|
|
||||||
|
if !called {
|
||||||
|
t.Fatal("callback was not called")
|
||||||
|
}
|
||||||
|
if callbackOld != StateJoining || callbackNew != StateActive {
|
||||||
|
t.Fatalf("callback got old=%q new=%q, want old=%q new=%q",
|
||||||
|
callbackOld, callbackNew, StateJoining, StateActive)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMultipleCallbacks(t *testing.T) {
|
||||||
|
m := NewManager()
|
||||||
|
|
||||||
|
count := 0
|
||||||
|
m.OnStateChange(func(_, _ State) { count++ })
|
||||||
|
m.OnStateChange(func(_, _ State) { count++ })
|
||||||
|
|
||||||
|
_ = m.TransitionTo(StateActive)
|
||||||
|
|
||||||
|
if count != 2 {
|
||||||
|
t.Fatalf("expected 2 callbacks, got %d", count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSnapshot(t *testing.T) {
|
||||||
|
m := NewManager()
|
||||||
|
_ = m.TransitionTo(StateActive)
|
||||||
|
_ = m.EnterMaintenance(10 * time.Minute)
|
||||||
|
|
||||||
|
state, ttl := m.Snapshot()
|
||||||
|
if state != StateMaintenance {
|
||||||
|
t.Fatalf("expected maintenance, got %q", state)
|
||||||
|
}
|
||||||
|
if ttl.IsZero() {
|
||||||
|
t.Fatal("expected non-zero TTL in snapshot")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConcurrentAccess(t *testing.T) {
|
||||||
|
m := NewManager()
|
||||||
|
_ = m.TransitionTo(StateActive)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
// Concurrent reads
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
_ = m.State()
|
||||||
|
_ = m.IsAvailable()
|
||||||
|
_ = m.IsInMaintenance()
|
||||||
|
_ = m.IsMaintenanceExpired()
|
||||||
|
_, _ = m.Snapshot()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Concurrent maintenance enter/exit cycles
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
_ = m.EnterMaintenance(1 * time.Minute)
|
||||||
|
_ = m.TransitionTo(StateActive)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStateEnteredAt(t *testing.T) {
|
||||||
|
before := time.Now()
|
||||||
|
m := NewManager()
|
||||||
|
after := time.Now()
|
||||||
|
|
||||||
|
entered := m.StateEnteredAt()
|
||||||
|
if entered.Before(before) || entered.After(after) {
|
||||||
|
t.Fatalf("StateEnteredAt %v not between %v and %v", entered, before, after)
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
_ = m.TransitionTo(StateActive)
|
||||||
|
|
||||||
|
newEntered := m.StateEnteredAt()
|
||||||
|
if !newEntered.After(entered) {
|
||||||
|
t.Fatal("expected StateEnteredAt to update after transition")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnterMaintenanceFromInvalidState(t *testing.T) {
|
||||||
|
m := NewManager() // joining state
|
||||||
|
err := m.EnterMaintenance(5 * time.Minute)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error entering maintenance from joining state")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFullLifecycle(t *testing.T) {
|
||||||
|
m := NewManager()
|
||||||
|
|
||||||
|
// joining → active
|
||||||
|
if err := m.TransitionTo(StateActive); err != nil {
|
||||||
|
t.Fatalf("joining→active: %v", err)
|
||||||
|
}
|
||||||
|
if !m.IsAvailable() {
|
||||||
|
t.Fatal("active node should be available")
|
||||||
|
}
|
||||||
|
|
||||||
|
// active → draining
|
||||||
|
if err := m.TransitionTo(StateDraining); err != nil {
|
||||||
|
t.Fatalf("active→draining: %v", err)
|
||||||
|
}
|
||||||
|
if m.IsAvailable() {
|
||||||
|
t.Fatal("draining node should not be available")
|
||||||
|
}
|
||||||
|
|
||||||
|
// draining → maintenance
|
||||||
|
if err := m.EnterMaintenance(10 * time.Minute); err != nil {
|
||||||
|
t.Fatalf("draining→maintenance: %v", err)
|
||||||
|
}
|
||||||
|
if !m.IsInMaintenance() {
|
||||||
|
t.Fatal("should be in maintenance")
|
||||||
|
}
|
||||||
|
|
||||||
|
// maintenance → active
|
||||||
|
if err := m.TransitionTo(StateActive); err != nil {
|
||||||
|
t.Fatalf("maintenance→active: %v", err)
|
||||||
|
}
|
||||||
|
if !m.IsAvailable() {
|
||||||
|
t.Fatal("should be available after maintenance")
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -12,6 +12,7 @@ import (
|
|||||||
"github.com/DeBrosOfficial/network/pkg/gateway"
|
"github.com/DeBrosOfficial/network/pkg/gateway"
|
||||||
"github.com/DeBrosOfficial/network/pkg/ipfs"
|
"github.com/DeBrosOfficial/network/pkg/ipfs"
|
||||||
"github.com/DeBrosOfficial/network/pkg/logging"
|
"github.com/DeBrosOfficial/network/pkg/logging"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/node/lifecycle"
|
||||||
"github.com/DeBrosOfficial/network/pkg/pubsub"
|
"github.com/DeBrosOfficial/network/pkg/pubsub"
|
||||||
database "github.com/DeBrosOfficial/network/pkg/rqlite"
|
database "github.com/DeBrosOfficial/network/pkg/rqlite"
|
||||||
"github.com/libp2p/go-libp2p/core/host"
|
"github.com/libp2p/go-libp2p/core/host"
|
||||||
@ -24,6 +25,9 @@ type Node struct {
|
|||||||
logger *logging.ColoredLogger
|
logger *logging.ColoredLogger
|
||||||
host host.Host
|
host host.Host
|
||||||
|
|
||||||
|
// Lifecycle state machine (joining → active ⇄ maintenance)
|
||||||
|
lifecycle *lifecycle.Manager
|
||||||
|
|
||||||
rqliteManager *database.RQLiteManager
|
rqliteManager *database.RQLiteManager
|
||||||
rqliteAdapter *database.RQLiteAdapter
|
rqliteAdapter *database.RQLiteAdapter
|
||||||
clusterDiscovery *database.ClusterDiscoveryService
|
clusterDiscovery *database.ClusterDiscoveryService
|
||||||
@ -54,8 +58,9 @@ func NewNode(cfg *config.Config) (*Node, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &Node{
|
return &Node{
|
||||||
config: cfg,
|
config: cfg,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
|
lifecycle: lifecycle.NewManager(),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -124,9 +129,20 @@ func (n *Node) Start(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// All services started — transition lifecycle: joining → active
|
||||||
|
if err := n.lifecycle.TransitionTo(lifecycle.StateActive); err != nil {
|
||||||
|
n.logger.ComponentWarn(logging.ComponentNode, "Failed to transition lifecycle to active", zap.Error(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Publish updated metadata with active lifecycle state
|
||||||
|
if n.clusterDiscovery != nil {
|
||||||
|
n.clusterDiscovery.UpdateOwnMetadata()
|
||||||
|
}
|
||||||
|
|
||||||
n.logger.ComponentInfo(logging.ComponentNode, "Network node started successfully",
|
n.logger.ComponentInfo(logging.ComponentNode, "Network node started successfully",
|
||||||
zap.String("peer_id", n.GetPeerID()),
|
zap.String("peer_id", n.GetPeerID()),
|
||||||
zap.Strings("listen_addrs", listenAddrs),
|
zap.Strings("listen_addrs", listenAddrs),
|
||||||
|
zap.String("lifecycle", string(n.lifecycle.State())),
|
||||||
)
|
)
|
||||||
|
|
||||||
n.startConnectionMonitoring()
|
n.startConnectionMonitoring()
|
||||||
@ -138,6 +154,17 @@ func (n *Node) Start(ctx context.Context) error {
|
|||||||
func (n *Node) Stop() error {
|
func (n *Node) Stop() error {
|
||||||
n.logger.ComponentInfo(logging.ComponentNode, "Stopping network node")
|
n.logger.ComponentInfo(logging.ComponentNode, "Stopping network node")
|
||||||
|
|
||||||
|
// Enter maintenance so peers know we're shutting down
|
||||||
|
if n.lifecycle.IsAvailable() {
|
||||||
|
if err := n.lifecycle.EnterMaintenance(5 * time.Minute); err != nil {
|
||||||
|
n.logger.ComponentWarn(logging.ComponentNode, "Failed to enter maintenance on shutdown", zap.Error(err))
|
||||||
|
}
|
||||||
|
// Publish maintenance state before tearing down services
|
||||||
|
if n.clusterDiscovery != nil {
|
||||||
|
n.clusterDiscovery.UpdateOwnMetadata()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Stop HTTP Gateway server
|
// Stop HTTP Gateway server
|
||||||
if n.apiGatewayServer != nil {
|
if n.apiGatewayServer != nil {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
|||||||
@ -34,6 +34,7 @@ func (n *Node) startRQLite(ctx context.Context) error {
|
|||||||
n.config.Discovery.RaftAdvAddress,
|
n.config.Discovery.RaftAdvAddress,
|
||||||
n.config.Discovery.HttpAdvAddress,
|
n.config.Discovery.HttpAdvAddress,
|
||||||
n.config.Node.DataDir,
|
n.config.Node.DataDir,
|
||||||
|
n.lifecycle,
|
||||||
n.logger.Logger,
|
n.logger.Logger,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -71,9 +71,10 @@ func (r *RQLiteManager) waitForMinClusterSizeBeforeStart(ctx context.Context, rq
|
|||||||
}
|
}
|
||||||
|
|
||||||
if remotePeerCount >= requiredRemotePeers {
|
if remotePeerCount >= requiredRemotePeers {
|
||||||
peersPath := filepath.Join(rqliteDataDir, "raft", "peers.json")
|
// Check discovery-peers.json (safe location outside raft dir)
|
||||||
|
peersPath := filepath.Join(rqliteDataDir, "discovery-peers.json")
|
||||||
r.discoveryService.TriggerSync()
|
r.discoveryService.TriggerSync()
|
||||||
time.Sleep(2 * time.Second)
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
if info, err := os.Stat(peersPath); err == nil && info.Size() > 10 {
|
if info, err := os.Stat(peersPath); err == nil && info.Size() > 10 {
|
||||||
data, err := os.ReadFile(peersPath)
|
data, err := os.ReadFile(peersPath)
|
||||||
@ -97,13 +98,11 @@ func (r *RQLiteManager) performPreStartClusterDiscovery(ctx context.Context, rql
|
|||||||
if err := r.discoveryService.TriggerPeerExchange(ctx); err != nil {
|
if err := r.discoveryService.TriggerPeerExchange(ctx); err != nil {
|
||||||
r.logger.Warn("Failed to trigger peer exchange during pre-start discovery", zap.Error(err))
|
r.logger.Warn("Failed to trigger peer exchange during pre-start discovery", zap.Error(err))
|
||||||
}
|
}
|
||||||
time.Sleep(1 * time.Second)
|
|
||||||
r.discoveryService.TriggerSync()
|
r.discoveryService.TriggerSync()
|
||||||
time.Sleep(2 * time.Second)
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
// Wait up to 2 minutes for peer discovery - LibP2P DHT can take 60+ seconds
|
// Wait up to 45s for peer discovery — parallel dials compensate for the shorter deadline
|
||||||
// to re-establish connections after simultaneous restart
|
discoveryDeadline := time.Now().Add(45 * time.Second)
|
||||||
discoveryDeadline := time.Now().Add(2 * time.Minute)
|
|
||||||
var discoveredPeers int
|
var discoveredPeers int
|
||||||
|
|
||||||
for time.Now().Before(discoveryDeadline) {
|
for time.Now().Before(discoveryDeadline) {
|
||||||
@ -151,7 +150,7 @@ func (r *RQLiteManager) performPreStartClusterDiscovery(ctx context.Context, rql
|
|||||||
}
|
}
|
||||||
|
|
||||||
r.discoveryService.TriggerSync()
|
r.discoveryService.TriggerSync()
|
||||||
time.Sleep(2 * time.Second)
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -182,9 +181,8 @@ func (r *RQLiteManager) recoverFromSplitBrain(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
r.discoveryService.TriggerPeerExchange(ctx)
|
r.discoveryService.TriggerPeerExchange(ctx)
|
||||||
time.Sleep(2 * time.Second)
|
|
||||||
r.discoveryService.TriggerSync()
|
r.discoveryService.TriggerSync()
|
||||||
time.Sleep(2 * time.Second)
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
rqliteDataDir, _ := r.rqliteDataDirPath()
|
rqliteDataDir, _ := r.rqliteDataDirPath()
|
||||||
ourIndex := r.getRaftLogIndex()
|
ourIndex := r.getRaftLogIndex()
|
||||||
@ -201,7 +199,7 @@ func (r *RQLiteManager) recoverFromSplitBrain(ctx context.Context) error {
|
|||||||
r.logger.Warn("Failed to clear raft state during split-brain recovery", zap.Error(err))
|
r.logger.Warn("Failed to clear raft state during split-brain recovery", zap.Error(err))
|
||||||
}
|
}
|
||||||
r.discoveryService.TriggerPeerExchange(ctx)
|
r.discoveryService.TriggerPeerExchange(ctx)
|
||||||
time.Sleep(1 * time.Second)
|
time.Sleep(500 * time.Millisecond)
|
||||||
if err := r.discoveryService.ForceWritePeersJSON(); err != nil {
|
if err := r.discoveryService.ForceWritePeersJSON(); err != nil {
|
||||||
r.logger.Warn("Failed to write peers.json during split-brain recovery", zap.Error(err))
|
r.logger.Warn("Failed to write peers.json during split-brain recovery", zap.Error(err))
|
||||||
}
|
}
|
||||||
@ -326,14 +324,15 @@ func (r *RQLiteManager) hasExistingRaftState(rqliteDataDir string) bool {
|
|||||||
if info, err := os.Stat(raftLogPath); err == nil && info.Size() > 1024 {
|
if info, err := os.Stat(raftLogPath); err == nil && info.Size() > 1024 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
peersPath := filepath.Join(rqliteDataDir, "raft", "peers.json")
|
// Don't check peers.json — discovery-peers.json is now written outside
|
||||||
_, err := os.Stat(peersPath)
|
// the raft dir and should not be treated as existing Raft state.
|
||||||
return err == nil
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RQLiteManager) clearRaftState(rqliteDataDir string) error {
|
func (r *RQLiteManager) clearRaftState(rqliteDataDir string) error {
|
||||||
_ = os.Remove(filepath.Join(rqliteDataDir, "raft.db"))
|
_ = os.Remove(filepath.Join(rqliteDataDir, "raft.db"))
|
||||||
_ = os.Remove(filepath.Join(rqliteDataDir, "raft", "peers.json"))
|
_ = os.Remove(filepath.Join(rqliteDataDir, "raft", "peers.json"))
|
||||||
|
_ = os.Remove(filepath.Join(rqliteDataDir, "discovery-peers.json"))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -3,10 +3,12 @@ package rqlite
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/DeBrosOfficial/network/pkg/discovery"
|
"github.com/DeBrosOfficial/network/pkg/discovery"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/node/lifecycle"
|
||||||
"github.com/libp2p/go-libp2p/core/host"
|
"github.com/libp2p/go-libp2p/core/host"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
@ -20,9 +22,13 @@ type ClusterDiscoveryService struct {
|
|||||||
nodeType string
|
nodeType string
|
||||||
raftAddress string
|
raftAddress string
|
||||||
httpAddress string
|
httpAddress string
|
||||||
|
wireGuardIP string // extracted from raftAddress (IP component)
|
||||||
dataDir string
|
dataDir string
|
||||||
minClusterSize int // Minimum cluster size required
|
minClusterSize int // Minimum cluster size required
|
||||||
|
|
||||||
|
// Lifecycle manager for this node's state machine
|
||||||
|
lifecycle *lifecycle.Manager
|
||||||
|
|
||||||
knownPeers map[string]*discovery.RQLiteNodeMetadata // NodeID -> Metadata
|
knownPeers map[string]*discovery.RQLiteNodeMetadata // NodeID -> Metadata
|
||||||
peerHealth map[string]*PeerHealth // NodeID -> Health
|
peerHealth map[string]*PeerHealth // NodeID -> Health
|
||||||
lastUpdate time.Time
|
lastUpdate time.Time
|
||||||
@ -45,6 +51,7 @@ func NewClusterDiscoveryService(
|
|||||||
raftAddress string,
|
raftAddress string,
|
||||||
httpAddress string,
|
httpAddress string,
|
||||||
dataDir string,
|
dataDir string,
|
||||||
|
lm *lifecycle.Manager,
|
||||||
logger *zap.Logger,
|
logger *zap.Logger,
|
||||||
) *ClusterDiscoveryService {
|
) *ClusterDiscoveryService {
|
||||||
minClusterSize := 1
|
minClusterSize := 1
|
||||||
@ -52,6 +59,12 @@ func NewClusterDiscoveryService(
|
|||||||
minClusterSize = rqliteManager.config.MinClusterSize
|
minClusterSize = rqliteManager.config.MinClusterSize
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Extract WireGuard IP from the raft address (e.g., "10.0.0.1" from "10.0.0.1:7001")
|
||||||
|
wgIP := ""
|
||||||
|
if host, _, err := net.SplitHostPort(raftAddress); err == nil {
|
||||||
|
wgIP = host
|
||||||
|
}
|
||||||
|
|
||||||
return &ClusterDiscoveryService{
|
return &ClusterDiscoveryService{
|
||||||
host: h,
|
host: h,
|
||||||
discoveryMgr: discoveryMgr,
|
discoveryMgr: discoveryMgr,
|
||||||
@ -60,8 +73,10 @@ func NewClusterDiscoveryService(
|
|||||||
nodeType: nodeType,
|
nodeType: nodeType,
|
||||||
raftAddress: raftAddress,
|
raftAddress: raftAddress,
|
||||||
httpAddress: httpAddress,
|
httpAddress: httpAddress,
|
||||||
|
wireGuardIP: wgIP,
|
||||||
dataDir: dataDir,
|
dataDir: dataDir,
|
||||||
minClusterSize: minClusterSize,
|
minClusterSize: minClusterSize,
|
||||||
|
lifecycle: lm,
|
||||||
knownPeers: make(map[string]*discovery.RQLiteNodeMetadata),
|
knownPeers: make(map[string]*discovery.RQLiteNodeMetadata),
|
||||||
peerHealth: make(map[string]*PeerHealth),
|
peerHealth: make(map[string]*PeerHealth),
|
||||||
updateInterval: 30 * time.Second,
|
updateInterval: 30 * time.Second,
|
||||||
@ -119,6 +134,25 @@ func (c *ClusterDiscoveryService) Stop() {
|
|||||||
c.logger.Info("Cluster discovery service stopped")
|
c.logger.Info("Cluster discovery service stopped")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Lifecycle returns the node's lifecycle manager.
|
||||||
|
func (c *ClusterDiscoveryService) Lifecycle() *lifecycle.Manager {
|
||||||
|
return c.lifecycle
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPeerLifecycleState returns the lifecycle state and last-seen time for a
|
||||||
|
// peer identified by its RQLite node ID (raft address). This method implements
|
||||||
|
// the MetadataReader interface used by the health monitor.
|
||||||
|
func (c *ClusterDiscoveryService) GetPeerLifecycleState(nodeID string) (state string, lastSeen time.Time, found bool) {
|
||||||
|
c.mu.RLock()
|
||||||
|
defer c.mu.RUnlock()
|
||||||
|
|
||||||
|
peer, ok := c.knownPeers[nodeID]
|
||||||
|
if !ok {
|
||||||
|
return "", time.Time{}, false
|
||||||
|
}
|
||||||
|
return peer.EffectiveLifecycleState(), peer.LastSeen, true
|
||||||
|
}
|
||||||
|
|
||||||
// IsVoter returns true if the given raft address should be a voter
|
// IsVoter returns true if the given raft address should be a voter
|
||||||
// in the default cluster based on the current known peers.
|
// in the default cluster based on the current known peers.
|
||||||
func (c *ClusterDiscoveryService) IsVoter(raftAddress string) bool {
|
func (c *ClusterDiscoveryService) IsVoter(raftAddress string) bool {
|
||||||
|
|||||||
@ -39,6 +39,17 @@ func (c *ClusterDiscoveryService) collectPeerMetadata() []*discovery.RQLiteNodeM
|
|||||||
RaftLogIndex: c.rqliteManager.getRaftLogIndex(),
|
RaftLogIndex: c.rqliteManager.getRaftLogIndex(),
|
||||||
LastSeen: time.Now(),
|
LastSeen: time.Now(),
|
||||||
ClusterVersion: "1.0",
|
ClusterVersion: "1.0",
|
||||||
|
PeerID: c.host.ID().String(),
|
||||||
|
WireGuardIP: c.wireGuardIP,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Populate lifecycle state
|
||||||
|
if c.lifecycle != nil {
|
||||||
|
state, ttl := c.lifecycle.Snapshot()
|
||||||
|
ourMetadata.LifecycleState = string(state)
|
||||||
|
if state == "maintenance" {
|
||||||
|
ourMetadata.MaintenanceTTL = ttl
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.adjustSelfAdvertisedAddresses(ourMetadata) {
|
if c.adjustSelfAdvertisedAddresses(ourMetadata) {
|
||||||
@ -272,7 +283,7 @@ func (c *ClusterDiscoveryService) getPeersJSONUnlocked() []map[string]interface{
|
|||||||
}
|
}
|
||||||
|
|
||||||
// computeVoterSet returns the set of raft addresses that should be voters.
|
// computeVoterSet returns the set of raft addresses that should be voters.
|
||||||
// It sorts addresses by their IP component and selects the first maxVoters.
|
// It sorts addresses by their numeric IP and selects the first maxVoters.
|
||||||
// This is deterministic — all nodes compute the same voter set from the same peer list.
|
// This is deterministic — all nodes compute the same voter set from the same peer list.
|
||||||
func computeVoterSet(raftAddrs []string, maxVoters int) map[string]struct{} {
|
func computeVoterSet(raftAddrs []string, maxVoters int) map[string]struct{} {
|
||||||
sorted := make([]string, len(raftAddrs))
|
sorted := make([]string, len(raftAddrs))
|
||||||
@ -281,7 +292,7 @@ func computeVoterSet(raftAddrs []string, maxVoters int) map[string]struct{} {
|
|||||||
sort.Slice(sorted, func(i, j int) bool {
|
sort.Slice(sorted, func(i, j int) bool {
|
||||||
ipI := extractIPForSort(sorted[i])
|
ipI := extractIPForSort(sorted[i])
|
||||||
ipJ := extractIPForSort(sorted[j])
|
ipJ := extractIPForSort(sorted[j])
|
||||||
return ipI < ipJ
|
return compareIPs(ipI, ipJ)
|
||||||
})
|
})
|
||||||
|
|
||||||
voters := make(map[string]struct{})
|
voters := make(map[string]struct{})
|
||||||
@ -303,6 +314,31 @@ func extractIPForSort(raftAddr string) string {
|
|||||||
return host
|
return host
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// compareIPs compares two IP strings numerically (not alphabetically).
|
||||||
|
// Alphabetical sort gives wrong results: "10.0.0.10" < "10.0.0.2" alphabetically,
|
||||||
|
// but numerically 10.0.0.2 < 10.0.0.10. This was causing wrong nodes to be
|
||||||
|
// selected as voters (e.g., 10.0.0.1, 10.0.0.10, 10.0.0.11 instead of 10.0.0.1-5).
|
||||||
|
func compareIPs(a, b string) bool {
|
||||||
|
ipA := net.ParseIP(a)
|
||||||
|
ipB := net.ParseIP(b)
|
||||||
|
|
||||||
|
// Fallback to string comparison if parsing fails
|
||||||
|
if ipA == nil || ipB == nil {
|
||||||
|
return a < b
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normalize to 16-byte representation for consistent comparison
|
||||||
|
ipA = ipA.To16()
|
||||||
|
ipB = ipB.To16()
|
||||||
|
|
||||||
|
for i := range ipA {
|
||||||
|
if ipA[i] != ipB[i] {
|
||||||
|
return ipA[i] < ipB[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// IsVoter returns true if the given raft address is in the voter set
|
// IsVoter returns true if the given raft address is in the voter set
|
||||||
// based on the current known peers. Must be called with c.mu held.
|
// based on the current known peers. Must be called with c.mu held.
|
||||||
func (c *ClusterDiscoveryService) IsVoterLocked(raftAddress string) bool {
|
func (c *ClusterDiscoveryService) IsVoterLocked(raftAddress string) bool {
|
||||||
@ -328,6 +364,14 @@ func (c *ClusterDiscoveryService) writePeersJSON() error {
|
|||||||
return c.writePeersJSONWithData(peers)
|
return c.writePeersJSONWithData(peers)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// writePeersJSONWithData writes the discovery peers file to a SAFE location
|
||||||
|
// outside the raft directory. This is critical: rqlite v8 treats any
|
||||||
|
// peers.json inside <dataDir>/raft/ as a recovery signal and RESETS
|
||||||
|
// the Raft configuration on startup. Writing there on every periodic sync
|
||||||
|
// caused split-brain on every node restart.
|
||||||
|
//
|
||||||
|
// Safe location: <dataDir>/rqlite/discovery-peers.json
|
||||||
|
// Dangerous location: <dataDir>/rqlite/raft/peers.json (only for explicit recovery)
|
||||||
func (c *ClusterDiscoveryService) writePeersJSONWithData(peers []map[string]interface{}) error {
|
func (c *ClusterDiscoveryService) writePeersJSONWithData(peers []map[string]interface{}) error {
|
||||||
dataDir := os.ExpandEnv(c.dataDir)
|
dataDir := os.ExpandEnv(c.dataDir)
|
||||||
if strings.HasPrefix(dataDir, "~") {
|
if strings.HasPrefix(dataDir, "~") {
|
||||||
@ -338,30 +382,25 @@ func (c *ClusterDiscoveryService) writePeersJSONWithData(peers []map[string]inte
|
|||||||
dataDir = filepath.Join(home, dataDir[1:])
|
dataDir = filepath.Join(home, dataDir[1:])
|
||||||
}
|
}
|
||||||
|
|
||||||
rqliteDir := filepath.Join(dataDir, "rqlite", "raft")
|
// Write to <dataDir>/rqlite/ — NOT inside raft/ subdirectory.
|
||||||
|
// rqlite v8 auto-recovers from raft/peers.json on every startup,
|
||||||
|
// which resets the Raft config and causes split-brain.
|
||||||
|
rqliteDir := filepath.Join(dataDir, "rqlite")
|
||||||
|
|
||||||
if err := os.MkdirAll(rqliteDir, 0755); err != nil {
|
if err := os.MkdirAll(rqliteDir, 0755); err != nil {
|
||||||
return fmt.Errorf("failed to create raft directory %s: %w", rqliteDir, err)
|
return fmt.Errorf("failed to create rqlite directory %s: %w", rqliteDir, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
peersFile := filepath.Join(rqliteDir, "peers.json")
|
peersFile := filepath.Join(rqliteDir, "discovery-peers.json")
|
||||||
backupFile := filepath.Join(rqliteDir, "peers.json.backup")
|
|
||||||
|
|
||||||
if _, err := os.Stat(peersFile); err == nil {
|
|
||||||
data, err := os.ReadFile(peersFile)
|
|
||||||
if err == nil {
|
|
||||||
_ = os.WriteFile(backupFile, data, 0644)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := json.MarshalIndent(peers, "", " ")
|
data, err := json.MarshalIndent(peers, "", " ")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to marshal peers.json: %w", err)
|
return fmt.Errorf("failed to marshal discovery-peers.json: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
tempFile := peersFile + ".tmp"
|
tempFile := peersFile + ".tmp"
|
||||||
if err := os.WriteFile(tempFile, data, 0644); err != nil {
|
if err := os.WriteFile(tempFile, data, 0644); err != nil {
|
||||||
return fmt.Errorf("failed to write temp peers.json %s: %w", tempFile, err)
|
return fmt.Errorf("failed to write temp discovery-peers.json %s: %w", tempFile, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := os.Rename(tempFile, peersFile); err != nil {
|
if err := os.Rename(tempFile, peersFile); err != nil {
|
||||||
@ -375,7 +414,57 @@ func (c *ClusterDiscoveryService) writePeersJSONWithData(peers []map[string]inte
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
c.logger.Info("peers.json written",
|
c.logger.Debug("discovery-peers.json written",
|
||||||
|
zap.Int("peers", len(peers)),
|
||||||
|
zap.Strings("nodes", nodeIDs))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeRecoveryPeersJSON writes peers.json to the raft directory for
|
||||||
|
// INTENTIONAL cluster recovery only. rqlite v8 will read this file on
|
||||||
|
// startup and reset the Raft configuration accordingly. Only call this
|
||||||
|
// when you explicitly want to trigger Raft recovery.
|
||||||
|
func (c *ClusterDiscoveryService) writeRecoveryPeersJSON(peers []map[string]interface{}) error {
|
||||||
|
dataDir := os.ExpandEnv(c.dataDir)
|
||||||
|
if strings.HasPrefix(dataDir, "~") {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to determine home directory: %w", err)
|
||||||
|
}
|
||||||
|
dataDir = filepath.Join(home, dataDir[1:])
|
||||||
|
}
|
||||||
|
|
||||||
|
raftDir := filepath.Join(dataDir, "rqlite", "raft")
|
||||||
|
|
||||||
|
if err := os.MkdirAll(raftDir, 0755); err != nil {
|
||||||
|
return fmt.Errorf("failed to create raft directory %s: %w", raftDir, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
peersFile := filepath.Join(raftDir, "peers.json")
|
||||||
|
|
||||||
|
data, err := json.MarshalIndent(peers, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal recovery peers.json: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tempFile := peersFile + ".tmp"
|
||||||
|
if err := os.WriteFile(tempFile, data, 0644); err != nil {
|
||||||
|
return fmt.Errorf("failed to write temp recovery peers.json %s: %w", tempFile, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.Rename(tempFile, peersFile); err != nil {
|
||||||
|
return fmt.Errorf("failed to rename %s to %s: %w", tempFile, peersFile, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
nodeIDs := make([]string, 0, len(peers))
|
||||||
|
for _, p := range peers {
|
||||||
|
if id, ok := p["id"].(string); ok {
|
||||||
|
nodeIDs = append(nodeIDs, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.logger.Warn("RECOVERY peers.json written to raft directory — rqlited will reset Raft config on next startup",
|
||||||
zap.Int("peers", len(peers)),
|
zap.Int("peers", len(peers)),
|
||||||
zap.Strings("nodes", nodeIDs))
|
zap.Strings("nodes", nodeIDs))
|
||||||
|
|
||||||
|
|||||||
@ -128,9 +128,12 @@ func (c *ClusterDiscoveryService) TriggerSync() {
|
|||||||
c.updateClusterMembership()
|
c.updateClusterMembership()
|
||||||
}
|
}
|
||||||
|
|
||||||
// ForceWritePeersJSON forces writing peers.json regardless of membership changes
|
// ForceWritePeersJSON writes peers.json to the RAFT directory for intentional
|
||||||
|
// cluster recovery. rqlite v8 will read this on startup and reset its Raft
|
||||||
|
// configuration. Only call this when you explicitly want Raft recovery
|
||||||
|
// (e.g., after clearing raft state or during split-brain recovery).
|
||||||
func (c *ClusterDiscoveryService) ForceWritePeersJSON() error {
|
func (c *ClusterDiscoveryService) ForceWritePeersJSON() error {
|
||||||
c.logger.Info("Force writing peers.json")
|
c.logger.Info("Force writing recovery peers.json to raft directory")
|
||||||
|
|
||||||
metadata := c.collectPeerMetadata()
|
metadata := c.collectPeerMetadata()
|
||||||
|
|
||||||
@ -153,16 +156,17 @@ func (c *ClusterDiscoveryService) ForceWritePeersJSON() error {
|
|||||||
peers := c.getPeersJSONUnlocked()
|
peers := c.getPeersJSONUnlocked()
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
|
|
||||||
if err := c.writePeersJSONWithData(peers); err != nil {
|
// Write to RAFT directory — this is intentional recovery
|
||||||
c.logger.Error("Failed to force write peers.json",
|
if err := c.writeRecoveryPeersJSON(peers); err != nil {
|
||||||
|
c.logger.Error("Failed to force write recovery peers.json",
|
||||||
zap.Error(err),
|
zap.Error(err),
|
||||||
zap.String("data_dir", c.dataDir),
|
zap.String("data_dir", c.dataDir),
|
||||||
zap.Int("peers", len(peers)))
|
zap.Int("peers", len(peers)))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
c.logger.Info("peers.json written",
|
// Also update discovery location
|
||||||
zap.Int("peers", len(peers)))
|
_ = c.writePeersJSONWithData(peers)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -179,7 +183,9 @@ func (c *ClusterDiscoveryService) TriggerPeerExchange(ctx context.Context) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateOwnMetadata updates our own RQLite metadata in the peerstore
|
// UpdateOwnMetadata updates our own RQLite metadata in the peerstore.
|
||||||
|
// This is called periodically and after significant state changes (lifecycle
|
||||||
|
// transitions, service status updates) to ensure peers have current info.
|
||||||
func (c *ClusterDiscoveryService) UpdateOwnMetadata() {
|
func (c *ClusterDiscoveryService) UpdateOwnMetadata() {
|
||||||
c.mu.RLock()
|
c.mu.RLock()
|
||||||
currentRaftAddr := c.raftAddress
|
currentRaftAddr := c.raftAddress
|
||||||
@ -194,6 +200,17 @@ func (c *ClusterDiscoveryService) UpdateOwnMetadata() {
|
|||||||
RaftLogIndex: c.rqliteManager.getRaftLogIndex(),
|
RaftLogIndex: c.rqliteManager.getRaftLogIndex(),
|
||||||
LastSeen: time.Now(),
|
LastSeen: time.Now(),
|
||||||
ClusterVersion: "1.0",
|
ClusterVersion: "1.0",
|
||||||
|
PeerID: c.host.ID().String(),
|
||||||
|
WireGuardIP: c.wireGuardIP,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Populate lifecycle state from the lifecycle manager
|
||||||
|
if c.lifecycle != nil {
|
||||||
|
state, ttl := c.lifecycle.Snapshot()
|
||||||
|
metadata.LifecycleState = string(state)
|
||||||
|
if state == "maintenance" {
|
||||||
|
metadata.MaintenanceTTL = ttl
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.adjustSelfAdvertisedAddresses(metadata) {
|
if c.adjustSelfAdvertisedAddresses(metadata) {
|
||||||
@ -215,7 +232,41 @@ func (c *ClusterDiscoveryService) UpdateOwnMetadata() {
|
|||||||
|
|
||||||
c.logger.Debug("Metadata updated",
|
c.logger.Debug("Metadata updated",
|
||||||
zap.String("node", metadata.NodeID),
|
zap.String("node", metadata.NodeID),
|
||||||
zap.Uint64("log_index", metadata.RaftLogIndex))
|
zap.Uint64("log_index", metadata.RaftLogIndex),
|
||||||
|
zap.String("lifecycle", metadata.LifecycleState))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProvideMetadata builds and returns the current node metadata without storing it.
|
||||||
|
// Implements discovery.MetadataProvider so the MetadataPublisher can call this
|
||||||
|
// on a regular interval and store the result in the peerstore.
|
||||||
|
func (c *ClusterDiscoveryService) ProvideMetadata() *discovery.RQLiteNodeMetadata {
|
||||||
|
c.mu.RLock()
|
||||||
|
currentRaftAddr := c.raftAddress
|
||||||
|
currentHTTPAddr := c.httpAddress
|
||||||
|
c.mu.RUnlock()
|
||||||
|
|
||||||
|
metadata := &discovery.RQLiteNodeMetadata{
|
||||||
|
NodeID: currentRaftAddr,
|
||||||
|
RaftAddress: currentRaftAddr,
|
||||||
|
HTTPAddress: currentHTTPAddr,
|
||||||
|
NodeType: c.nodeType,
|
||||||
|
RaftLogIndex: c.rqliteManager.getRaftLogIndex(),
|
||||||
|
LastSeen: time.Now(),
|
||||||
|
ClusterVersion: "1.0",
|
||||||
|
PeerID: c.host.ID().String(),
|
||||||
|
WireGuardIP: c.wireGuardIP,
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.lifecycle != nil {
|
||||||
|
state, ttl := c.lifecycle.Snapshot()
|
||||||
|
metadata.LifecycleState = string(state)
|
||||||
|
if state == "maintenance" {
|
||||||
|
metadata.MaintenanceTTL = ttl
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.adjustSelfAdvertisedAddresses(metadata)
|
||||||
|
return metadata
|
||||||
}
|
}
|
||||||
|
|
||||||
// StoreRemotePeerMetadata stores metadata received from a remote peer
|
// StoreRemotePeerMetadata stores metadata received from a remote peer
|
||||||
|
|||||||
@ -83,6 +83,14 @@ func (is *InstanceSpawner) SpawnInstance(ctx context.Context, cfg InstanceConfig
|
|||||||
"-raft-adv-addr", cfg.RaftAdvAddress,
|
"-raft-adv-addr", cfg.RaftAdvAddress,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Raft tuning — match the global node's tuning for consistency
|
||||||
|
args = append(args,
|
||||||
|
"-raft-election-timeout", "5s",
|
||||||
|
"-raft-heartbeat-timeout", "2s",
|
||||||
|
"-raft-apply-timeout", "30s",
|
||||||
|
"-raft-leader-lease-timeout", "5s",
|
||||||
|
)
|
||||||
|
|
||||||
// Add join addresses if not the leader (must be before data directory)
|
// Add join addresses if not the leader (must be before data directory)
|
||||||
if !cfg.IsLeader && len(cfg.JoinAddresses) > 0 {
|
if !cfg.IsLeader && len(cfg.JoinAddresses) > 0 {
|
||||||
for _, addr := range cfg.JoinAddresses {
|
for _, addr := range cfg.JoinAddresses {
|
||||||
|
|||||||
131
pkg/rqlite/leadership.go
Normal file
131
pkg/rqlite/leadership.go
Normal file
@ -0,0 +1,131 @@
|
|||||||
|
package rqlite
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TransferLeadership attempts to transfer Raft leadership to another voter.
|
||||||
|
// Used by both the RQLiteManager (on Stop) and the CLI (pre-upgrade).
|
||||||
|
// Returns nil if this node is not the leader or if transfer succeeds.
|
||||||
|
func TransferLeadership(port int, logger *zap.Logger) error {
|
||||||
|
client := &http.Client{Timeout: 5 * time.Second}
|
||||||
|
|
||||||
|
// 1. Check if we're the leader
|
||||||
|
statusURL := fmt.Sprintf("http://localhost:%d/status", port)
|
||||||
|
resp, err := client.Get(statusURL)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to query status: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read status: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var status RQLiteStatus
|
||||||
|
if err := json.Unmarshal(body, &status); err != nil {
|
||||||
|
return fmt.Errorf("failed to parse status: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if status.Store.Raft.State != "Leader" {
|
||||||
|
logger.Debug("Not the leader, skipping transfer", zap.Int("port", port))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("This node is the Raft leader, attempting leadership transfer",
|
||||||
|
zap.Int("port", port),
|
||||||
|
zap.String("leader_id", status.Store.Raft.LeaderID))
|
||||||
|
|
||||||
|
// 2. Find an eligible voter to transfer to
|
||||||
|
nodesURL := fmt.Sprintf("http://localhost:%d/nodes?nonvoters&ver=2&timeout=5s", port)
|
||||||
|
nodesResp, err := client.Get(nodesURL)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to query nodes: %w", err)
|
||||||
|
}
|
||||||
|
defer nodesResp.Body.Close()
|
||||||
|
|
||||||
|
nodesBody, err := io.ReadAll(nodesResp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read nodes: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try ver=2 wrapped format, fall back to plain array
|
||||||
|
var nodes RQLiteNodes
|
||||||
|
var wrapped struct {
|
||||||
|
Nodes RQLiteNodes `json:"nodes"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(nodesBody, &wrapped); err == nil && wrapped.Nodes != nil {
|
||||||
|
nodes = wrapped.Nodes
|
||||||
|
} else {
|
||||||
|
_ = json.Unmarshal(nodesBody, &nodes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find a reachable voter that is NOT us
|
||||||
|
var targetID string
|
||||||
|
for _, n := range nodes {
|
||||||
|
if n.Voter && n.Reachable && n.ID != status.Store.Raft.LeaderID {
|
||||||
|
targetID = n.ID
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if targetID == "" {
|
||||||
|
logger.Warn("No eligible voter found for leadership transfer — will rely on SIGTERM graceful step-down",
|
||||||
|
zap.Int("port", port))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Attempt transfer via rqlite v8+ API
|
||||||
|
// POST /nodes/<target>/transfer-leadership
|
||||||
|
// If the API doesn't exist (404), fall back to relying on SIGTERM.
|
||||||
|
transferURL := fmt.Sprintf("http://localhost:%d/nodes/%s/transfer-leadership", port, targetID)
|
||||||
|
transferResp, err := client.Post(transferURL, "application/json", nil)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Leadership transfer request failed, relying on SIGTERM",
|
||||||
|
zap.Error(err))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
transferResp.Body.Close()
|
||||||
|
|
||||||
|
if transferResp.StatusCode == http.StatusNotFound {
|
||||||
|
logger.Info("Leadership transfer API not available (rqlite version), relying on SIGTERM")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if transferResp.StatusCode != http.StatusOK {
|
||||||
|
logger.Warn("Leadership transfer returned unexpected status",
|
||||||
|
zap.Int("status", transferResp.StatusCode))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Verify transfer
|
||||||
|
time.Sleep(2 * time.Second)
|
||||||
|
verifyResp, err := client.Get(statusURL)
|
||||||
|
if err != nil {
|
||||||
|
logger.Info("Could not verify transfer (node may have already stepped down)")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
defer verifyResp.Body.Close()
|
||||||
|
|
||||||
|
verifyBody, _ := io.ReadAll(verifyResp.Body)
|
||||||
|
var newStatus RQLiteStatus
|
||||||
|
if err := json.Unmarshal(verifyBody, &newStatus); err == nil {
|
||||||
|
if newStatus.Store.Raft.State != "Leader" {
|
||||||
|
logger.Info("Leadership transferred successfully",
|
||||||
|
zap.String("new_leader", newStatus.Store.Raft.LeaderID),
|
||||||
|
zap.Int("port", port))
|
||||||
|
} else {
|
||||||
|
logger.Warn("Still leader after transfer attempt — will rely on SIGTERM",
|
||||||
|
zap.Int("port", port))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@ -66,6 +66,29 @@ func (r *RQLiteManager) launchProcess(ctx context.Context, rqliteDataDir string)
|
|||||||
// Kill any orphaned rqlited from a previous crash
|
// Kill any orphaned rqlited from a previous crash
|
||||||
r.killOrphanedRQLite()
|
r.killOrphanedRQLite()
|
||||||
|
|
||||||
|
// Remove stale peers.json from the raft directory to prevent rqlite v8
|
||||||
|
// from triggering automatic Raft recovery on normal restarts.
|
||||||
|
//
|
||||||
|
// Only delete when raft.db EXISTS (normal restart). If raft.db does NOT
|
||||||
|
// exist, peers.json was likely placed intentionally by ForceWritePeersJSON()
|
||||||
|
// as part of a recovery flow (clearRaftState + ForceWritePeersJSON + launch).
|
||||||
|
stalePeersPath := filepath.Join(rqliteDataDir, "raft", "peers.json")
|
||||||
|
raftDBPath := filepath.Join(rqliteDataDir, "raft.db")
|
||||||
|
if _, err := os.Stat(stalePeersPath); err == nil {
|
||||||
|
if _, err := os.Stat(raftDBPath); err == nil {
|
||||||
|
// raft.db exists → this is a normal restart, peers.json is stale
|
||||||
|
r.logger.Warn("Removing stale peers.json from raft directory to prevent accidental recovery",
|
||||||
|
zap.String("path", stalePeersPath))
|
||||||
|
_ = os.Remove(stalePeersPath)
|
||||||
|
_ = os.Remove(stalePeersPath + ".backup")
|
||||||
|
_ = os.Remove(stalePeersPath + ".tmp")
|
||||||
|
} else {
|
||||||
|
// raft.db missing → intentional recovery, keep peers.json for rqlited
|
||||||
|
r.logger.Info("Keeping peers.json in raft directory for intentional cluster recovery",
|
||||||
|
zap.String("path", stalePeersPath))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Build RQLite command
|
// Build RQLite command
|
||||||
args := []string{
|
args := []string{
|
||||||
"-http-addr", fmt.Sprintf("0.0.0.0:%d", r.config.RQLitePort),
|
"-http-addr", fmt.Sprintf("0.0.0.0:%d", r.config.RQLitePort),
|
||||||
@ -90,6 +113,30 @@ func (r *RQLiteManager) launchProcess(ctx context.Context, rqliteDataDir string)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Raft tuning — higher timeouts suit WireGuard latency
|
||||||
|
raftElection := r.config.RaftElectionTimeout
|
||||||
|
if raftElection == 0 {
|
||||||
|
raftElection = 5 * time.Second
|
||||||
|
}
|
||||||
|
raftHeartbeat := r.config.RaftHeartbeatTimeout
|
||||||
|
if raftHeartbeat == 0 {
|
||||||
|
raftHeartbeat = 2 * time.Second
|
||||||
|
}
|
||||||
|
raftApply := r.config.RaftApplyTimeout
|
||||||
|
if raftApply == 0 {
|
||||||
|
raftApply = 30 * time.Second
|
||||||
|
}
|
||||||
|
raftLeaderLease := r.config.RaftLeaderLeaseTimeout
|
||||||
|
if raftLeaderLease == 0 {
|
||||||
|
raftLeaderLease = 5 * time.Second
|
||||||
|
}
|
||||||
|
args = append(args,
|
||||||
|
"-raft-election-timeout", raftElection.String(),
|
||||||
|
"-raft-heartbeat-timeout", raftHeartbeat.String(),
|
||||||
|
"-raft-apply-timeout", raftApply.String(),
|
||||||
|
"-raft-leader-lease-timeout", raftLeaderLease.String(),
|
||||||
|
)
|
||||||
|
|
||||||
if r.config.RQLiteJoinAddress != "" && !r.hasExistingState(rqliteDataDir) {
|
if r.config.RQLiteJoinAddress != "" && !r.hasExistingState(rqliteDataDir) {
|
||||||
r.logger.Info("First-time join to RQLite cluster", zap.String("join_address", r.config.RQLiteJoinAddress))
|
r.logger.Info("First-time join to RQLite cluster", zap.String("join_address", r.config.RQLiteJoinAddress))
|
||||||
|
|
||||||
|
|||||||
@ -70,6 +70,7 @@ func (r *RQLiteManager) Start(ctx context.Context) error {
|
|||||||
if r.discoveryService != nil {
|
if r.discoveryService != nil {
|
||||||
go r.startHealthMonitoring(ctx)
|
go r.startHealthMonitoring(ctx)
|
||||||
go r.startVoterReconciliation(ctx)
|
go r.startVoterReconciliation(ctx)
|
||||||
|
go r.startOrphanedNodeRecovery(ctx) // C1 fix: recover nodes orphaned by failed voter changes
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start child process watchdog to detect and recover from crashes
|
// Start child process watchdog to detect and recover from crashes
|
||||||
@ -138,21 +139,9 @@ func (r *RQLiteManager) Stop() error {
|
|||||||
// transferLeadershipIfLeader checks if this node is the Raft leader and
|
// transferLeadershipIfLeader checks if this node is the Raft leader and
|
||||||
// requests a leadership transfer to minimize election disruption.
|
// requests a leadership transfer to minimize election disruption.
|
||||||
func (r *RQLiteManager) transferLeadershipIfLeader() {
|
func (r *RQLiteManager) transferLeadershipIfLeader() {
|
||||||
status, err := r.getRQLiteStatus()
|
if err := TransferLeadership(r.config.RQLitePort, r.logger); err != nil {
|
||||||
if err != nil {
|
r.logger.Warn("Leadership transfer failed, relying on SIGTERM", zap.Error(err))
|
||||||
return
|
|
||||||
}
|
}
|
||||||
if status.Store.Raft.State != "Leader" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
r.logger.Info("This node is the Raft leader, requesting leadership transfer before shutdown")
|
|
||||||
|
|
||||||
// RQLite doesn't have a direct leadership transfer API, but we can
|
|
||||||
// signal readiness to step down. The fastest approach is to let the
|
|
||||||
// SIGTERM handler in rqlited handle this — rqlite v8 gracefully
|
|
||||||
// steps down on SIGTERM when possible. We log the state for visibility.
|
|
||||||
r.logger.Info("Leader will transfer on SIGTERM (rqlite built-in behavior)")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// cleanupPIDFile removes the PID file on shutdown
|
// cleanupPIDFile removes the PID file on shutdown
|
||||||
|
|||||||
@ -7,15 +7,32 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// voterChangeCooldown is how long to wait after a failed voter change
|
||||||
|
// before retrying the same node.
|
||||||
|
voterChangeCooldown = 10 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
// voterReconciler holds voter change cooldown state.
|
||||||
|
type voterReconciler struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
cooldowns map[string]time.Time // nodeID → earliest next attempt
|
||||||
|
}
|
||||||
|
|
||||||
// startVoterReconciliation periodically checks and corrects voter/non-voter
|
// startVoterReconciliation periodically checks and corrects voter/non-voter
|
||||||
// assignments. Only takes effect on the leader node. Corrects at most one
|
// assignments. Only takes effect on the leader node. Corrects at most one
|
||||||
// node per cycle to minimize disruption.
|
// node per cycle to minimize disruption.
|
||||||
func (r *RQLiteManager) startVoterReconciliation(ctx context.Context) {
|
func (r *RQLiteManager) startVoterReconciliation(ctx context.Context) {
|
||||||
|
reconciler := &voterReconciler{
|
||||||
|
cooldowns: make(map[string]time.Time),
|
||||||
|
}
|
||||||
|
|
||||||
// Wait for cluster to stabilize after startup
|
// Wait for cluster to stabilize after startup
|
||||||
time.Sleep(3 * time.Minute)
|
time.Sleep(3 * time.Minute)
|
||||||
|
|
||||||
@ -27,21 +44,104 @@ func (r *RQLiteManager) startVoterReconciliation(ctx context.Context) {
|
|||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
if err := r.reconcileVoters(); err != nil {
|
if err := r.reconcileVoters(reconciler); err != nil {
|
||||||
r.logger.Debug("Voter reconciliation skipped", zap.Error(err))
|
r.logger.Debug("Voter reconciliation skipped", zap.Error(err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// startOrphanedNodeRecovery runs every 5 minutes on the leader. It scans for
|
||||||
|
// nodes that appear in the discovery peer list but NOT in the Raft cluster
|
||||||
|
// (orphaned by a failed remove+rejoin during voter reconciliation). For each
|
||||||
|
// orphaned node, it re-adds them via POST /join. (C1 fix)
|
||||||
|
func (r *RQLiteManager) startOrphanedNodeRecovery(ctx context.Context) {
|
||||||
|
// Wait for cluster to stabilize
|
||||||
|
time.Sleep(5 * time.Minute)
|
||||||
|
|
||||||
|
ticker := time.NewTicker(5 * time.Minute)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
r.recoverOrphanedNodes()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// recoverOrphanedNodes finds nodes known to discovery but missing from the
|
||||||
|
// Raft cluster and re-adds them.
|
||||||
|
func (r *RQLiteManager) recoverOrphanedNodes() {
|
||||||
|
if r.discoveryService == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only the leader runs orphan recovery
|
||||||
|
status, err := r.getRQLiteStatus()
|
||||||
|
if err != nil || status.Store.Raft.State != "Leader" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get all Raft cluster members
|
||||||
|
raftNodes, err := r.getAllClusterNodes()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
raftNodeSet := make(map[string]bool, len(raftNodes))
|
||||||
|
for _, n := range raftNodes {
|
||||||
|
raftNodeSet[n.ID] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get all discovery peers
|
||||||
|
discoveryPeers := r.discoveryService.GetAllPeers()
|
||||||
|
|
||||||
|
for _, peer := range discoveryPeers {
|
||||||
|
if peer.RaftAddress == r.discoverConfig.RaftAdvAddress {
|
||||||
|
continue // skip self
|
||||||
|
}
|
||||||
|
if raftNodeSet[peer.RaftAddress] {
|
||||||
|
continue // already in cluster
|
||||||
|
}
|
||||||
|
|
||||||
|
// This peer is in discovery but not in Raft — it's orphaned
|
||||||
|
r.logger.Warn("Found orphaned node (in discovery but not in Raft cluster), re-adding",
|
||||||
|
zap.String("node_raft_addr", peer.RaftAddress),
|
||||||
|
zap.String("node_id", peer.NodeID))
|
||||||
|
|
||||||
|
// Determine voter status
|
||||||
|
raftAddrs := make([]string, 0, len(discoveryPeers))
|
||||||
|
for _, p := range discoveryPeers {
|
||||||
|
raftAddrs = append(raftAddrs, p.RaftAddress)
|
||||||
|
}
|
||||||
|
voters := computeVoterSet(raftAddrs, MaxDefaultVoters)
|
||||||
|
_, shouldBeVoter := voters[peer.RaftAddress]
|
||||||
|
|
||||||
|
if err := r.joinClusterNode(peer.RaftAddress, peer.RaftAddress, shouldBeVoter); err != nil {
|
||||||
|
r.logger.Error("Failed to re-add orphaned node",
|
||||||
|
zap.String("node", peer.RaftAddress),
|
||||||
|
zap.Bool("voter", shouldBeVoter),
|
||||||
|
zap.Error(err))
|
||||||
|
} else {
|
||||||
|
r.logger.Info("Successfully re-added orphaned node to Raft cluster",
|
||||||
|
zap.String("node", peer.RaftAddress),
|
||||||
|
zap.Bool("voter", shouldBeVoter))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// reconcileVoters compares actual cluster voter assignments (from RQLite's
|
// reconcileVoters compares actual cluster voter assignments (from RQLite's
|
||||||
// /nodes endpoint) against the deterministic desired set (computeVoterSet)
|
// /nodes endpoint) against the deterministic desired set (computeVoterSet)
|
||||||
// and corrects mismatches. Uses remove + re-join since RQLite's /join
|
// and corrects mismatches.
|
||||||
// ignores voter flag changes for existing members.
|
|
||||||
//
|
//
|
||||||
// Safety: only runs on the leader, only when all nodes are reachable,
|
// Improvements over original:
|
||||||
// never demotes the leader, and fixes at most one node per cycle.
|
// - Promotion: tries direct POST /join with voter=true first (no remove needed)
|
||||||
func (r *RQLiteManager) reconcileVoters() error {
|
// - Leader stability: verifies leader is stable before demotion
|
||||||
|
// - Cooldown: skips nodes that recently failed a voter change
|
||||||
|
// - Fixes at most one node per cycle
|
||||||
|
func (r *RQLiteManager) reconcileVoters(reconciler *voterReconciler) error {
|
||||||
// 1. Only the leader reconciles
|
// 1. Only the leader reconciles
|
||||||
status, err := r.getRQLiteStatus()
|
status, err := r.getRQLiteStatus()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -68,14 +168,21 @@ func (r *RQLiteManager) reconcileVoters() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. Compute desired voter set from raft addresses
|
// 4. Leader stability: verify term hasn't changed recently
|
||||||
|
// (Re-check status to confirm we're still the stable leader)
|
||||||
|
status2, err := r.getRQLiteStatus()
|
||||||
|
if err != nil || status2.Store.Raft.State != "Leader" || status2.Store.Raft.Term != status.Store.Raft.Term {
|
||||||
|
return fmt.Errorf("leader state changed during reconciliation check")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. Compute desired voter set from raft addresses
|
||||||
raftAddrs := make([]string, 0, len(nodes))
|
raftAddrs := make([]string, 0, len(nodes))
|
||||||
for _, n := range nodes {
|
for _, n := range nodes {
|
||||||
raftAddrs = append(raftAddrs, n.ID)
|
raftAddrs = append(raftAddrs, n.ID)
|
||||||
}
|
}
|
||||||
desiredVoters := computeVoterSet(raftAddrs, MaxDefaultVoters)
|
desiredVoters := computeVoterSet(raftAddrs, MaxDefaultVoters)
|
||||||
|
|
||||||
// 5. Safety: never demote ourselves (the current leader)
|
// 6. Safety: never demote ourselves (the current leader)
|
||||||
myRaftAddr := status.Store.Raft.LeaderID
|
myRaftAddr := status.Store.Raft.LeaderID
|
||||||
if _, shouldBeVoter := desiredVoters[myRaftAddr]; !shouldBeVoter {
|
if _, shouldBeVoter := desiredVoters[myRaftAddr]; !shouldBeVoter {
|
||||||
r.logger.Warn("Leader is not in computed voter set — skipping reconciliation",
|
r.logger.Warn("Leader is not in computed voter set — skipping reconciliation",
|
||||||
@ -83,10 +190,19 @@ func (r *RQLiteManager) reconcileVoters() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 6. Find one mismatch to fix (one change per cycle)
|
// 7. Find one mismatch to fix (one change per cycle)
|
||||||
for _, n := range nodes {
|
for _, n := range nodes {
|
||||||
_, shouldBeVoter := desiredVoters[n.ID]
|
_, shouldBeVoter := desiredVoters[n.ID]
|
||||||
|
|
||||||
|
// Check cooldown
|
||||||
|
reconciler.mu.Lock()
|
||||||
|
cooldownUntil, hasCooldown := reconciler.cooldowns[n.ID]
|
||||||
|
if hasCooldown && time.Now().Before(cooldownUntil) {
|
||||||
|
reconciler.mu.Unlock()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
reconciler.mu.Unlock()
|
||||||
|
|
||||||
if n.Voter && !shouldBeVoter {
|
if n.Voter && !shouldBeVoter {
|
||||||
// Skip if this is the leader
|
// Skip if this is the leader
|
||||||
if n.ID == myRaftAddr {
|
if n.ID == myRaftAddr {
|
||||||
@ -100,6 +216,9 @@ func (r *RQLiteManager) reconcileVoters() error {
|
|||||||
r.logger.Warn("Failed to demote voter",
|
r.logger.Warn("Failed to demote voter",
|
||||||
zap.String("node_id", n.ID),
|
zap.String("node_id", n.ID),
|
||||||
zap.Error(err))
|
zap.Error(err))
|
||||||
|
reconciler.mu.Lock()
|
||||||
|
reconciler.cooldowns[n.ID] = time.Now().Add(voterChangeCooldown)
|
||||||
|
reconciler.mu.Unlock()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -112,10 +231,24 @@ func (r *RQLiteManager) reconcileVoters() error {
|
|||||||
r.logger.Info("Promoting non-voter to voter",
|
r.logger.Info("Promoting non-voter to voter",
|
||||||
zap.String("node_id", n.ID))
|
zap.String("node_id", n.ID))
|
||||||
|
|
||||||
|
// Try direct promotion first (POST /join with voter=true)
|
||||||
|
if err := r.joinClusterNode(n.ID, n.ID, true); err == nil {
|
||||||
|
r.logger.Info("Successfully promoted non-voter to voter (direct join)",
|
||||||
|
zap.String("node_id", n.ID))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Direct join didn't change voter status, fall back to remove+rejoin
|
||||||
|
r.logger.Info("Direct promotion didn't work, trying remove+rejoin",
|
||||||
|
zap.String("node_id", n.ID))
|
||||||
|
|
||||||
if err := r.changeNodeVoterStatus(n.ID, true); err != nil {
|
if err := r.changeNodeVoterStatus(n.ID, true); err != nil {
|
||||||
r.logger.Warn("Failed to promote non-voter",
|
r.logger.Warn("Failed to promote non-voter",
|
||||||
zap.String("node_id", n.ID),
|
zap.String("node_id", n.ID),
|
||||||
zap.Error(err))
|
zap.Error(err))
|
||||||
|
reconciler.mu.Lock()
|
||||||
|
reconciler.cooldowns[n.ID] = time.Now().Add(voterChangeCooldown)
|
||||||
|
reconciler.mu.Unlock()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -130,13 +263,12 @@ func (r *RQLiteManager) reconcileVoters() error {
|
|||||||
|
|
||||||
// changeNodeVoterStatus changes a node's voter status by removing it from the
|
// changeNodeVoterStatus changes a node's voter status by removing it from the
|
||||||
// cluster and immediately re-adding it with the desired voter flag.
|
// cluster and immediately re-adding it with the desired voter flag.
|
||||||
// This is necessary because RQLite's /join endpoint ignores voter flag changes
|
|
||||||
// for nodes that are already cluster members with the same ID and address.
|
|
||||||
//
|
//
|
||||||
// Safety improvements:
|
// Safety improvements:
|
||||||
// - Pre-check: verify quorum would survive the temporary removal
|
// - Pre-check: verify quorum would survive the temporary removal
|
||||||
// - Rollback: if rejoin fails, attempt to re-add with original status
|
// - Pre-check: verify target node is still reachable
|
||||||
// - Retry: attempt rejoin up to 3 times with backoff
|
// - Rollback: if rejoin fails, attempt to re-add with original status
|
||||||
|
// - Retry: 5 attempts with exponential backoff (2s, 4s, 8s, 15s, 30s)
|
||||||
func (r *RQLiteManager) changeNodeVoterStatus(nodeID string, voter bool) error {
|
func (r *RQLiteManager) changeNodeVoterStatus(nodeID string, voter bool) error {
|
||||||
// Pre-check: if demoting a voter, verify quorum safety
|
// Pre-check: if demoting a voter, verify quorum safety
|
||||||
if !voter {
|
if !voter {
|
||||||
@ -145,34 +277,53 @@ func (r *RQLiteManager) changeNodeVoterStatus(nodeID string, voter bool) error {
|
|||||||
return fmt.Errorf("quorum pre-check: %w", err)
|
return fmt.Errorf("quorum pre-check: %w", err)
|
||||||
}
|
}
|
||||||
voterCount := 0
|
voterCount := 0
|
||||||
|
targetReachable := false
|
||||||
for _, n := range nodes {
|
for _, n := range nodes {
|
||||||
if n.Voter && n.Reachable {
|
if n.Voter && n.Reachable {
|
||||||
voterCount++
|
voterCount++
|
||||||
}
|
}
|
||||||
|
if n.ID == nodeID && n.Reachable {
|
||||||
|
targetReachable = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !targetReachable {
|
||||||
|
return fmt.Errorf("target node %s is not reachable, skipping voter change", nodeID)
|
||||||
}
|
}
|
||||||
// After removing this voter, we need (voterCount-1)/2 + 1 for quorum
|
// After removing this voter, we need (voterCount-1)/2 + 1 for quorum
|
||||||
// which means voterCount-1 > (voterCount-1)/2, i.e., voterCount >= 3
|
|
||||||
if voterCount <= 2 {
|
if voterCount <= 2 {
|
||||||
return fmt.Errorf("cannot remove voter: only %d reachable voters, quorum would be lost", voterCount)
|
return fmt.Errorf("cannot remove voter: only %d reachable voters, quorum would be lost", voterCount)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Fresh quorum check immediately before removal
|
||||||
|
nodes, err := r.getAllClusterNodes()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("fresh quorum check: %w", err)
|
||||||
|
}
|
||||||
|
for _, n := range nodes {
|
||||||
|
if !n.Reachable {
|
||||||
|
return fmt.Errorf("node %s is unreachable, aborting voter change", n.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Step 1: Remove the node from the cluster
|
// Step 1: Remove the node from the cluster
|
||||||
if err := r.removeClusterNode(nodeID); err != nil {
|
if err := r.removeClusterNode(nodeID); err != nil {
|
||||||
return fmt.Errorf("remove node: %w", err)
|
return fmt.Errorf("remove node: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait for Raft to commit the configuration change, then rejoin with retries
|
// Wait for Raft to commit the configuration change, then rejoin with retries
|
||||||
|
// Exponential backoff: 2s, 4s, 8s, 15s, 30s
|
||||||
|
backoffs := []time.Duration{2 * time.Second, 4 * time.Second, 8 * time.Second, 15 * time.Second, 30 * time.Second}
|
||||||
var lastErr error
|
var lastErr error
|
||||||
for attempt := 0; attempt < 3; attempt++ {
|
for attempt, wait := range backoffs {
|
||||||
waitTime := time.Duration(2+attempt*2) * time.Second // 2s, 4s, 6s
|
time.Sleep(wait)
|
||||||
time.Sleep(waitTime)
|
|
||||||
|
|
||||||
if err := r.joinClusterNode(nodeID, nodeID, voter); err != nil {
|
if err := r.joinClusterNode(nodeID, nodeID, voter); err != nil {
|
||||||
lastErr = err
|
lastErr = err
|
||||||
r.logger.Warn("Rejoin attempt failed, retrying",
|
r.logger.Warn("Rejoin attempt failed, retrying",
|
||||||
zap.String("node_id", nodeID),
|
zap.String("node_id", nodeID),
|
||||||
zap.Int("attempt", attempt+1),
|
zap.Int("attempt", attempt+1),
|
||||||
|
zap.Int("max_attempts", len(backoffs)),
|
||||||
zap.Error(err))
|
zap.Error(err))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -187,12 +338,12 @@ func (r *RQLiteManager) changeNodeVoterStatus(nodeID string, voter bool) error {
|
|||||||
|
|
||||||
originalVoter := !voter
|
originalVoter := !voter
|
||||||
if err := r.joinClusterNode(nodeID, nodeID, originalVoter); err != nil {
|
if err := r.joinClusterNode(nodeID, nodeID, originalVoter); err != nil {
|
||||||
r.logger.Error("Rollback also failed — node may be orphaned from cluster",
|
r.logger.Error("Rollback also failed — node may be orphaned (orphan recovery will re-add it)",
|
||||||
zap.String("node_id", nodeID),
|
zap.String("node_id", nodeID),
|
||||||
zap.Error(err))
|
zap.Error(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Errorf("rejoin node after 3 attempts: %w", lastErr)
|
return fmt.Errorf("rejoin node after %d attempts: %w", len(backoffs), lastErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getAllClusterNodes queries /nodes?nonvoters&ver=2 to get all cluster members
|
// getAllClusterNodes queries /nodes?nonvoters&ver=2 to get all cluster members
|
||||||
|
|||||||
@ -10,13 +10,34 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
watchdogInterval = 30 * time.Second
|
// watchdogInterval is how often we check if rqlited is alive.
|
||||||
|
watchdogInterval = 30 * time.Second
|
||||||
|
|
||||||
|
// watchdogMaxRestart is the maximum number of restart attempts before giving up.
|
||||||
watchdogMaxRestart = 3
|
watchdogMaxRestart = 3
|
||||||
|
|
||||||
|
// watchdogGracePeriod is how long to wait after a restart before
|
||||||
|
// the watchdog starts checking. This gives rqlited time to rejoin
|
||||||
|
// the Raft cluster — Raft election timeouts + log replay can take
|
||||||
|
// 60-120 seconds after a restart.
|
||||||
|
watchdogGracePeriod = 120 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
// startProcessWatchdog monitors the RQLite child process and restarts it if it crashes.
|
// startProcessWatchdog monitors the RQLite child process and restarts it if it crashes.
|
||||||
// It checks both process liveness and HTTP responsiveness.
|
// It only restarts when the process has actually DIED (exited). It does NOT kill
|
||||||
|
// rqlited for being slow to find a leader — that's normal during cluster rejoin.
|
||||||
func (r *RQLiteManager) startProcessWatchdog(ctx context.Context) {
|
func (r *RQLiteManager) startProcessWatchdog(ctx context.Context) {
|
||||||
|
// Wait for the grace period before starting to monitor.
|
||||||
|
// rqlited needs time to:
|
||||||
|
// 1. Open the raft log and snapshots
|
||||||
|
// 2. Reconnect to existing Raft peers
|
||||||
|
// 3. Either rejoin as follower or participate in a new election
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-time.After(watchdogGracePeriod):
|
||||||
|
}
|
||||||
|
|
||||||
ticker := time.NewTicker(watchdogInterval)
|
ticker := time.NewTicker(watchdogInterval)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
@ -46,10 +67,21 @@ func (r *RQLiteManager) startProcessWatchdog(ctx context.Context) {
|
|||||||
restartCount++
|
restartCount++
|
||||||
r.logger.Info("RQLite process restarted by watchdog",
|
r.logger.Info("RQLite process restarted by watchdog",
|
||||||
zap.Int("restart_count", restartCount))
|
zap.Int("restart_count", restartCount))
|
||||||
|
|
||||||
|
// Give the restarted process time to stabilize before checking again
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-time.After(watchdogGracePeriod):
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// Process is alive — check HTTP responsiveness
|
// Process is alive — reset restart counter on sustained health
|
||||||
if !r.isHTTPResponsive() {
|
if r.isHTTPResponsive() {
|
||||||
r.logger.Warn("RQLite process is alive but not responding to HTTP")
|
if restartCount > 0 {
|
||||||
|
r.logger.Info("RQLite process has stabilized, resetting restart counter",
|
||||||
|
zap.Int("previous_restart_count", restartCount))
|
||||||
|
restartCount = 0
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
208
scripts/clean-testnet.sh
Executable file
208
scripts/clean-testnet.sh
Executable file
@ -0,0 +1,208 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
#
|
||||||
|
# Clean all testnet nodes for fresh reinstall.
|
||||||
|
# Preserves Anyone relay keys (/var/lib/anon/) for --anyone-migrate.
|
||||||
|
# DOES NOT TOUCH DEVNET NODES.
|
||||||
|
#
|
||||||
|
# Usage: scripts/clean-testnet.sh [--nuclear]
|
||||||
|
# --nuclear Also remove shared binaries (rqlited, ipfs, coredns, caddy, etc.)
|
||||||
|
#
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
|
||||||
|
CONF="$ROOT_DIR/scripts/remote-nodes.conf"
|
||||||
|
|
||||||
|
[[ -f "$CONF" ]] || { echo "ERROR: Missing $CONF"; exit 1; }
|
||||||
|
command -v sshpass >/dev/null 2>&1 || { echo "ERROR: sshpass not installed (brew install sshpass / apt install sshpass)"; exit 1; }
|
||||||
|
|
||||||
|
NUCLEAR=false
|
||||||
|
[[ "${1:-}" == "--nuclear" ]] && NUCLEAR=true
|
||||||
|
|
||||||
|
SSH_OPTS=(-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -o ConnectTimeout=10 -o LogLevel=ERROR -o PubkeyAuthentication=no)
|
||||||
|
|
||||||
|
# ── Cleanup script (runs as root on each remote node) ─────────────────────
|
||||||
|
# Uses a quoted heredoc so NO local variable expansion happens.
|
||||||
|
# This script is uploaded to /tmp/orama-clean.sh and executed remotely.
|
||||||
|
CLEANUP_SCRIPT=$(cat <<'SCRIPT_END'
|
||||||
|
#!/bin/bash
|
||||||
|
set -e
|
||||||
|
export DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
|
echo " Stopping services..."
|
||||||
|
systemctl stop debros-node debros-gateway debros-ipfs debros-ipfs-cluster debros-olric debros-anyone-relay debros-anyone-client coredns caddy 2>/dev/null || true
|
||||||
|
systemctl disable debros-node debros-gateway debros-ipfs debros-ipfs-cluster debros-olric debros-anyone-relay debros-anyone-client coredns caddy 2>/dev/null || true
|
||||||
|
|
||||||
|
echo " Removing systemd service files..."
|
||||||
|
rm -f /etc/systemd/system/debros-*.service
|
||||||
|
rm -f /etc/systemd/system/coredns.service
|
||||||
|
rm -f /etc/systemd/system/caddy.service
|
||||||
|
rm -f /etc/systemd/system/orama-deploy-*.service
|
||||||
|
systemctl daemon-reload
|
||||||
|
|
||||||
|
echo " Tearing down WireGuard..."
|
||||||
|
systemctl stop wg-quick@wg0 2>/dev/null || true
|
||||||
|
wg-quick down wg0 2>/dev/null || true
|
||||||
|
systemctl disable wg-quick@wg0 2>/dev/null || true
|
||||||
|
rm -f /etc/wireguard/wg0.conf
|
||||||
|
|
||||||
|
echo " Resetting UFW firewall..."
|
||||||
|
ufw --force reset
|
||||||
|
ufw allow 22/tcp
|
||||||
|
ufw --force enable
|
||||||
|
|
||||||
|
echo " Killing debros processes..."
|
||||||
|
pkill -u debros 2>/dev/null || true
|
||||||
|
sleep 1
|
||||||
|
|
||||||
|
echo " Removing debros user and data..."
|
||||||
|
userdel -r debros 2>/dev/null || true
|
||||||
|
rm -rf /home/debros
|
||||||
|
|
||||||
|
echo " Removing sudoers files..."
|
||||||
|
rm -f /etc/sudoers.d/debros-access
|
||||||
|
rm -f /etc/sudoers.d/debros-deployments
|
||||||
|
rm -f /etc/sudoers.d/debros-wireguard
|
||||||
|
|
||||||
|
echo " Removing CoreDNS and Caddy configs..."
|
||||||
|
rm -rf /etc/coredns
|
||||||
|
rm -rf /etc/caddy
|
||||||
|
rm -rf /var/lib/caddy
|
||||||
|
|
||||||
|
echo " Cleaning temp files..."
|
||||||
|
rm -f /tmp/orama /tmp/network-source.tar.gz /tmp/network-source.zip
|
||||||
|
rm -rf /tmp/network-extract /tmp/coredns-build /tmp/caddy-build
|
||||||
|
|
||||||
|
# Nuclear: also remove shared binaries
|
||||||
|
if [ "${1:-}" = "--nuclear" ]; then
|
||||||
|
echo " Removing shared binaries (nuclear)..."
|
||||||
|
rm -f /usr/local/bin/rqlited
|
||||||
|
rm -f /usr/local/bin/ipfs
|
||||||
|
rm -f /usr/local/bin/ipfs-cluster-service
|
||||||
|
rm -f /usr/local/bin/olric-server
|
||||||
|
rm -f /usr/local/bin/coredns
|
||||||
|
rm -f /usr/local/bin/xcaddy
|
||||||
|
rm -f /usr/bin/caddy
|
||||||
|
rm -f /usr/local/bin/orama
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Verify Anyone relay keys are preserved
|
||||||
|
if [ -d /var/lib/anon/keys ]; then
|
||||||
|
echo " Anyone relay keys PRESERVED at /var/lib/anon/keys"
|
||||||
|
if [ -f /var/lib/anon/fingerprint ]; then
|
||||||
|
fp=$(cat /var/lib/anon/fingerprint 2>/dev/null || true)
|
||||||
|
echo " Relay fingerprint: $fp"
|
||||||
|
fi
|
||||||
|
if [ -f /var/lib/anon/wallet ]; then
|
||||||
|
wallet=$(cat /var/lib/anon/wallet 2>/dev/null || true)
|
||||||
|
echo " Relay wallet: $wallet"
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
echo " WARNING: No Anyone relay keys found at /var/lib/anon/"
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo " DONE"
|
||||||
|
SCRIPT_END
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Parse testnet nodes only ──────────────────────────────────────────────
|
||||||
|
hosts=()
|
||||||
|
passes=()
|
||||||
|
users=()
|
||||||
|
|
||||||
|
while IFS='|' read -r env hostspec pass role key; do
|
||||||
|
[[ -z "$env" || "$env" == \#* ]] && continue
|
||||||
|
env="${env%%#*}"
|
||||||
|
env="$(echo "$env" | xargs)"
|
||||||
|
[[ "$env" != "testnet" ]] && continue
|
||||||
|
|
||||||
|
hosts+=("$hostspec")
|
||||||
|
passes+=("$pass")
|
||||||
|
users+=("${hostspec%%@*}")
|
||||||
|
done < "$CONF"
|
||||||
|
|
||||||
|
if [[ ${#hosts[@]} -eq 0 ]]; then
|
||||||
|
echo "ERROR: No testnet nodes found in $CONF"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "== clean-testnet.sh — ${#hosts[@]} testnet nodes =="
|
||||||
|
for i in "${!hosts[@]}"; do
|
||||||
|
echo " [$((i+1))] ${hosts[$i]}"
|
||||||
|
done
|
||||||
|
echo ""
|
||||||
|
echo "This will CLEAN all testnet nodes (stop services, remove data)."
|
||||||
|
echo "Anyone relay keys (/var/lib/anon/) will be PRESERVED."
|
||||||
|
echo "Devnet nodes will NOT be touched."
|
||||||
|
$NUCLEAR && echo "Nuclear mode: shared binaries will also be removed."
|
||||||
|
echo ""
|
||||||
|
read -rp "Type 'yes' to continue: " confirm
|
||||||
|
if [[ "$confirm" != "yes" ]]; then
|
||||||
|
echo "Aborted."
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
# ── Execute cleanup on each node ──────────────────────────────────────────
|
||||||
|
failed=()
|
||||||
|
succeeded=0
|
||||||
|
NUCLEAR_FLAG=""
|
||||||
|
$NUCLEAR && NUCLEAR_FLAG="--nuclear"
|
||||||
|
|
||||||
|
for i in "${!hosts[@]}"; do
|
||||||
|
h="${hosts[$i]}"
|
||||||
|
p="${passes[$i]}"
|
||||||
|
u="${users[$i]}"
|
||||||
|
echo ""
|
||||||
|
echo "== [$((i+1))/${#hosts[@]}] Cleaning $h =="
|
||||||
|
|
||||||
|
# Step 1: Upload cleanup script
|
||||||
|
# No -n flag here — we're piping the script content via stdin
|
||||||
|
if ! echo "$CLEANUP_SCRIPT" | sshpass -p "$p" ssh "${SSH_OPTS[@]}" "$h" \
|
||||||
|
"cat > /tmp/orama-clean.sh && chmod +x /tmp/orama-clean.sh" 2>&1; then
|
||||||
|
echo " !! FAILED to upload script to $h"
|
||||||
|
failed+=("$h")
|
||||||
|
continue
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Step 2: Execute the cleanup script as root
|
||||||
|
if [[ "$u" == "root" ]]; then
|
||||||
|
# Root: run directly
|
||||||
|
if ! sshpass -p "$p" ssh -n "${SSH_OPTS[@]}" "$h" \
|
||||||
|
"bash /tmp/orama-clean.sh $NUCLEAR_FLAG; rm -f /tmp/orama-clean.sh" 2>&1; then
|
||||||
|
echo " !! FAILED: $h"
|
||||||
|
failed+=("$h")
|
||||||
|
continue
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
# Non-root: escape password for single-quote embedding, pipe to sudo -S
|
||||||
|
escaped_p=$(printf '%s' "$p" | sed "s/'/'\\\\''/g")
|
||||||
|
if ! sshpass -p "$p" ssh -n "${SSH_OPTS[@]}" "$h" \
|
||||||
|
"printf '%s\n' '${escaped_p}' | sudo -S bash /tmp/orama-clean.sh $NUCLEAR_FLAG; rm -f /tmp/orama-clean.sh" 2>&1; then
|
||||||
|
echo " !! FAILED: $h"
|
||||||
|
failed+=("$h")
|
||||||
|
continue
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo " OK: $h cleaned"
|
||||||
|
((succeeded++)) || true
|
||||||
|
done
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "========================================"
|
||||||
|
echo "Cleanup complete: $succeeded succeeded, ${#failed[@]} failed"
|
||||||
|
if [[ ${#failed[@]} -gt 0 ]]; then
|
||||||
|
echo ""
|
||||||
|
echo "Failed nodes:"
|
||||||
|
for f in "${failed[@]}"; do
|
||||||
|
echo " - $f"
|
||||||
|
done
|
||||||
|
echo ""
|
||||||
|
echo "Troubleshooting:"
|
||||||
|
echo " 1. Check connectivity: ssh <user>@<host>"
|
||||||
|
echo " 2. Check password in remote-nodes.conf"
|
||||||
|
echo " 3. Try cleaning manually: docs/CLEAN_NODE.md"
|
||||||
|
fi
|
||||||
|
echo ""
|
||||||
|
echo "Anyone relay keys preserved at /var/lib/anon/ on all nodes."
|
||||||
|
echo "Use --anyone-migrate during install to reuse existing relay identity."
|
||||||
|
echo "========================================"
|
||||||
289
scripts/recover-rqlite.sh
Normal file
289
scripts/recover-rqlite.sh
Normal file
@ -0,0 +1,289 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
#
|
||||||
|
# Recover RQLite cluster from split-brain.
|
||||||
|
#
|
||||||
|
# Strategy:
|
||||||
|
# 1. Stop debros-node on ALL nodes simultaneously
|
||||||
|
# 2. Keep raft/ data ONLY on the node with the highest commit index (leader candidate)
|
||||||
|
# 3. Delete raft/ on all other nodes (they'll join fresh via -join)
|
||||||
|
# 4. Start the leader candidate first, wait for it to become Leader
|
||||||
|
# 5. Start all other nodes — they discover the leader via LibP2P and join
|
||||||
|
# 6. Verify cluster health
|
||||||
|
#
|
||||||
|
# Usage:
|
||||||
|
# scripts/recover-rqlite.sh --devnet --leader 57.129.7.232
|
||||||
|
# scripts/recover-rqlite.sh --testnet --leader <ip>
|
||||||
|
#
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
# ── Parse flags ──────────────────────────────────────────────────────────────
|
||||||
|
ENV=""
|
||||||
|
LEADER_HOST=""
|
||||||
|
|
||||||
|
for arg in "$@"; do
|
||||||
|
case "$arg" in
|
||||||
|
--devnet) ENV="devnet" ;;
|
||||||
|
--testnet) ENV="testnet" ;;
|
||||||
|
--leader=*) LEADER_HOST="${arg#--leader=}" ;;
|
||||||
|
-h|--help)
|
||||||
|
echo "Usage: scripts/recover-rqlite.sh --devnet|--testnet --leader=<public_ip_or_user@host>"
|
||||||
|
exit 0
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
echo "Unknown flag: $arg" >&2
|
||||||
|
exit 1
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
done
|
||||||
|
|
||||||
|
if [[ -z "$ENV" ]]; then
|
||||||
|
echo "ERROR: specify --devnet or --testnet" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
if [[ -z "$LEADER_HOST" ]]; then
|
||||||
|
echo "ERROR: specify --leader=<host> (the node with highest commit index)" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# ── Paths ────────────────────────────────────────────────────────────────────
|
||||||
|
ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
|
||||||
|
CONF="$ROOT_DIR/scripts/remote-nodes.conf"
|
||||||
|
|
||||||
|
die() { echo "ERROR: $*" >&2; exit 1; }
|
||||||
|
[[ -f "$CONF" ]] || die "Missing $CONF"
|
||||||
|
|
||||||
|
# ── Load nodes from conf ────────────────────────────────────────────────────
|
||||||
|
HOSTS=()
|
||||||
|
PASSES=()
|
||||||
|
ROLES=()
|
||||||
|
SSH_KEYS=()
|
||||||
|
|
||||||
|
while IFS='|' read -r env host pass role key; do
|
||||||
|
[[ -z "$env" || "$env" == \#* ]] && continue
|
||||||
|
env="${env%%#*}"
|
||||||
|
env="$(echo "$env" | xargs)"
|
||||||
|
[[ "$env" != "$ENV" ]] && continue
|
||||||
|
|
||||||
|
HOSTS+=("$host")
|
||||||
|
PASSES+=("$pass")
|
||||||
|
ROLES+=("${role:-node}")
|
||||||
|
SSH_KEYS+=("${key:-}")
|
||||||
|
done < "$CONF"
|
||||||
|
|
||||||
|
if [[ ${#HOSTS[@]} -eq 0 ]]; then
|
||||||
|
die "No nodes found for environment '$ENV' in $CONF"
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "== recover-rqlite.sh ($ENV) — ${#HOSTS[@]} nodes =="
|
||||||
|
echo "Leader candidate: $LEADER_HOST"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# Find leader index
|
||||||
|
LEADER_IDX=-1
|
||||||
|
for i in "${!HOSTS[@]}"; do
|
||||||
|
if [[ "${HOSTS[$i]}" == *"$LEADER_HOST"* ]]; then
|
||||||
|
LEADER_IDX=$i
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
if [[ $LEADER_IDX -eq -1 ]]; then
|
||||||
|
die "Leader host '$LEADER_HOST' not found in node list"
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Nodes:"
|
||||||
|
for i in "${!HOSTS[@]}"; do
|
||||||
|
marker=""
|
||||||
|
[[ $i -eq $LEADER_IDX ]] && marker=" ← LEADER (keep data)"
|
||||||
|
echo " [$i] ${HOSTS[$i]} (${ROLES[$i]})$marker"
|
||||||
|
done
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# ── SSH helpers ──────────────────────────────────────────────────────────────
|
||||||
|
SSH_OPTS=(-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -o ConnectTimeout=10)
|
||||||
|
|
||||||
|
node_ssh() {
|
||||||
|
local idx="$1"
|
||||||
|
shift
|
||||||
|
local h="${HOSTS[$idx]}"
|
||||||
|
local p="${PASSES[$idx]}"
|
||||||
|
local k="${SSH_KEYS[$idx]:-}"
|
||||||
|
|
||||||
|
if [[ -n "$k" ]]; then
|
||||||
|
local expanded_key="${k/#\~/$HOME}"
|
||||||
|
if [[ -f "$expanded_key" ]]; then
|
||||||
|
ssh -i "$expanded_key" "${SSH_OPTS[@]}" "$h" "$@" 2>/dev/null
|
||||||
|
return $?
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
sshpass -p "$p" ssh -n "${SSH_OPTS[@]}" "$h" "$@" 2>/dev/null
|
||||||
|
}
|
||||||
|
|
||||||
|
# ── Confirmation ─────────────────────────────────────────────────────────────
|
||||||
|
echo "⚠️ THIS WILL:"
|
||||||
|
echo " 1. Stop debros-node on ALL ${#HOSTS[@]} nodes"
|
||||||
|
echo " 2. DELETE raft/ data on ${#HOSTS[@]}-1 nodes (backup to /tmp/rqlite-raft-backup/)"
|
||||||
|
echo " 3. Keep raft/ data ONLY on ${HOSTS[$LEADER_IDX]} (leader candidate)"
|
||||||
|
echo " 4. Restart all nodes to reform the cluster"
|
||||||
|
echo ""
|
||||||
|
read -r -p "Continue? [y/N] " confirm
|
||||||
|
if [[ "$confirm" != "y" && "$confirm" != "Y" ]]; then
|
||||||
|
echo "Aborted."
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
RAFT_DIR="/home/debros/.orama/data/rqlite/raft"
|
||||||
|
BACKUP_DIR="/tmp/rqlite-raft-backup"
|
||||||
|
|
||||||
|
# ── Phase 1: Stop debros-node on ALL nodes ───────────────────────────────────
|
||||||
|
echo "== Phase 1: Stopping debros-node on all ${#HOSTS[@]} nodes =="
|
||||||
|
failed=()
|
||||||
|
for i in "${!HOSTS[@]}"; do
|
||||||
|
h="${HOSTS[$i]}"
|
||||||
|
p="${PASSES[$i]}"
|
||||||
|
echo -n " Stopping $h ... "
|
||||||
|
if node_ssh "$i" "printf '%s\n' '$p' | sudo -S systemctl stop debros-node 2>&1 && echo STOPPED"; then
|
||||||
|
echo ""
|
||||||
|
else
|
||||||
|
echo "FAILED"
|
||||||
|
failed+=("$h")
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
if [[ ${#failed[@]} -gt 0 ]]; then
|
||||||
|
echo ""
|
||||||
|
echo "⚠️ ${#failed[@]} nodes failed to stop. Attempting kill..."
|
||||||
|
for i in "${!HOSTS[@]}"; do
|
||||||
|
h="${HOSTS[$i]}"
|
||||||
|
p="${PASSES[$i]}"
|
||||||
|
for fh in "${failed[@]}"; do
|
||||||
|
if [[ "$h" == "$fh" ]]; then
|
||||||
|
node_ssh "$i" "printf '%s\n' '$p' | sudo -S killall -9 orama-node rqlited 2>/dev/null; echo KILLED" || true
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "Waiting 5s for processes to fully stop..."
|
||||||
|
sleep 5
|
||||||
|
|
||||||
|
# ── Phase 2: Backup and delete raft/ on non-leader nodes ────────────────────
|
||||||
|
echo "== Phase 2: Clearing raft state on non-leader nodes =="
|
||||||
|
for i in "${!HOSTS[@]}"; do
|
||||||
|
[[ $i -eq $LEADER_IDX ]] && continue
|
||||||
|
|
||||||
|
h="${HOSTS[$i]}"
|
||||||
|
p="${PASSES[$i]}"
|
||||||
|
echo -n " Clearing $h ... "
|
||||||
|
if node_ssh "$i" "
|
||||||
|
printf '%s\n' '$p' | sudo -S bash -c '
|
||||||
|
rm -rf $BACKUP_DIR
|
||||||
|
if [ -d $RAFT_DIR ]; then
|
||||||
|
cp -r $RAFT_DIR $BACKUP_DIR 2>/dev/null || true
|
||||||
|
rm -rf $RAFT_DIR
|
||||||
|
echo \"CLEARED (backup at $BACKUP_DIR)\"
|
||||||
|
else
|
||||||
|
echo \"NO_RAFT_DIR (nothing to clear)\"
|
||||||
|
fi
|
||||||
|
'
|
||||||
|
"; then
|
||||||
|
true
|
||||||
|
else
|
||||||
|
echo "FAILED"
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "Leader node ${HOSTS[$LEADER_IDX]} raft/ data preserved."
|
||||||
|
|
||||||
|
# ── Phase 3: Start leader node ──────────────────────────────────────────────
|
||||||
|
echo ""
|
||||||
|
echo "== Phase 3: Starting leader node (${HOSTS[$LEADER_IDX]}) =="
|
||||||
|
lp="${PASSES[$LEADER_IDX]}"
|
||||||
|
node_ssh "$LEADER_IDX" "printf '%s\n' '$lp' | sudo -S systemctl start debros-node" || die "Failed to start leader node"
|
||||||
|
|
||||||
|
echo " Waiting for leader to become Leader..."
|
||||||
|
max_wait=120
|
||||||
|
elapsed=0
|
||||||
|
while [[ $elapsed -lt $max_wait ]]; do
|
||||||
|
state=$(node_ssh "$LEADER_IDX" "curl -s --max-time 3 http://localhost:5001/status 2>/dev/null | python3 -c \"import sys,json; d=json.load(sys.stdin); print(d.get('store',{}).get('raft',{}).get('state',''))\" 2>/dev/null" || echo "")
|
||||||
|
if [[ "$state" == "Leader" ]]; then
|
||||||
|
echo " ✓ Leader node is Leader after ${elapsed}s"
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
echo " ... state=$state (${elapsed}s / ${max_wait}s)"
|
||||||
|
sleep 5
|
||||||
|
((elapsed+=5))
|
||||||
|
done
|
||||||
|
|
||||||
|
if [[ "$state" != "Leader" ]]; then
|
||||||
|
echo " ⚠️ Leader did not become Leader within ${max_wait}s (state=$state)"
|
||||||
|
echo " The node may need more time. Continuing anyway..."
|
||||||
|
fi
|
||||||
|
|
||||||
|
# ── Phase 4: Start all other nodes ──────────────────────────────────────────
|
||||||
|
echo ""
|
||||||
|
echo "== Phase 4: Starting remaining nodes =="
|
||||||
|
|
||||||
|
# Start non-leader nodes in batches of 3 with 15s between batches
|
||||||
|
batch_size=3
|
||||||
|
batch_count=0
|
||||||
|
for i in "${!HOSTS[@]}"; do
|
||||||
|
[[ $i -eq $LEADER_IDX ]] && continue
|
||||||
|
|
||||||
|
h="${HOSTS[$i]}"
|
||||||
|
p="${PASSES[$i]}"
|
||||||
|
echo -n " Starting $h ... "
|
||||||
|
if node_ssh "$i" "printf '%s\n' '$p' | sudo -S systemctl start debros-node && echo STARTED"; then
|
||||||
|
true
|
||||||
|
else
|
||||||
|
echo "FAILED"
|
||||||
|
fi
|
||||||
|
|
||||||
|
((batch_count++))
|
||||||
|
if [[ $((batch_count % batch_size)) -eq 0 ]]; then
|
||||||
|
echo " (waiting 15s between batches for cluster stability)"
|
||||||
|
sleep 15
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
# ── Phase 5: Wait and verify ────────────────────────────────────────────────
|
||||||
|
echo ""
|
||||||
|
echo "== Phase 5: Waiting for cluster to form (120s) =="
|
||||||
|
sleep 30
|
||||||
|
echo " ... 30s"
|
||||||
|
sleep 30
|
||||||
|
echo " ... 60s"
|
||||||
|
sleep 30
|
||||||
|
echo " ... 90s"
|
||||||
|
sleep 30
|
||||||
|
echo " ... 120s"
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "== Cluster status =="
|
||||||
|
for i in "${!HOSTS[@]}"; do
|
||||||
|
h="${HOSTS[$i]}"
|
||||||
|
result=$(node_ssh "$i" "curl -s --max-time 5 http://localhost:5001/status 2>/dev/null | python3 -c \"
|
||||||
|
import sys,json
|
||||||
|
try:
|
||||||
|
d=json.load(sys.stdin)
|
||||||
|
r=d.get('store',{}).get('raft',{})
|
||||||
|
n=d.get('store',{}).get('num_nodes','?')
|
||||||
|
print(f'state={r.get(\"state\",\"?\")} commit={r.get(\"commit_index\",\"?\")} leader={r.get(\"leader\",{}).get(\"node_id\",\"?\")} nodes={n}')
|
||||||
|
except:
|
||||||
|
print('NO_RESPONSE')
|
||||||
|
\" 2>/dev/null" || echo "SSH_FAILED")
|
||||||
|
marker=""
|
||||||
|
[[ $i -eq $LEADER_IDX ]] && marker=" ← LEADER"
|
||||||
|
echo " ${HOSTS[$i]}: $result$marker"
|
||||||
|
done
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "== Recovery complete =="
|
||||||
|
echo ""
|
||||||
|
echo "Next steps:"
|
||||||
|
echo " 1. Run 'scripts/inspect.sh --devnet' to verify full cluster health"
|
||||||
|
echo " 2. If some nodes show Candidate state, give them more time (up to 5 min)"
|
||||||
|
echo " 3. If nodes fail to join, check /home/debros/.orama/logs/rqlite-node.log on the node"
|
||||||
@ -3,18 +3,19 @@
|
|||||||
# Redeploy to all nodes in a given environment (devnet or testnet).
|
# Redeploy to all nodes in a given environment (devnet or testnet).
|
||||||
# Reads node credentials from scripts/remote-nodes.conf.
|
# Reads node credentials from scripts/remote-nodes.conf.
|
||||||
#
|
#
|
||||||
# Flow (per docs/DEV_DEPLOY.md):
|
# Flow:
|
||||||
# 1) make build-linux
|
# 1) make build-linux-all
|
||||||
# 2) scripts/generate-source-archive.sh -> /tmp/network-source.tar.gz
|
# 2) scripts/generate-source-archive.sh -> /tmp/network-source.tar.gz
|
||||||
# 3) scp archive + extract-deploy.sh + conf to hub node
|
# 3) scp archive + extract-deploy.sh + conf to hub node
|
||||||
# 4) from hub: sshpass scp to all other nodes + sudo bash /tmp/extract-deploy.sh
|
# 4) from hub: sshpass scp to all other nodes + sudo bash /tmp/extract-deploy.sh
|
||||||
# 5) rolling: upgrade followers one-by-one, leader last
|
# 5) rolling upgrade: followers first, leader last
|
||||||
|
# per node: pre-upgrade -> stop -> extract binary -> post-upgrade
|
||||||
#
|
#
|
||||||
# Usage:
|
# Usage:
|
||||||
# scripts/redeploy.sh --devnet
|
# scripts/redeploy.sh --devnet
|
||||||
# scripts/redeploy.sh --testnet
|
# scripts/redeploy.sh --testnet
|
||||||
# scripts/redeploy.sh --devnet --no-build
|
# scripts/redeploy.sh --devnet --no-build
|
||||||
# scripts/redeploy.sh --testnet --no-build
|
# scripts/redeploy.sh --devnet --skip-build
|
||||||
#
|
#
|
||||||
set -euo pipefail
|
set -euo pipefail
|
||||||
|
|
||||||
@ -26,14 +27,14 @@ for arg in "$@"; do
|
|||||||
case "$arg" in
|
case "$arg" in
|
||||||
--devnet) ENV="devnet" ;;
|
--devnet) ENV="devnet" ;;
|
||||||
--testnet) ENV="testnet" ;;
|
--testnet) ENV="testnet" ;;
|
||||||
--no-build) NO_BUILD=1 ;;
|
--no-build|--skip-build) NO_BUILD=1 ;;
|
||||||
-h|--help)
|
-h|--help)
|
||||||
echo "Usage: scripts/redeploy.sh --devnet|--testnet [--no-build]"
|
echo "Usage: scripts/redeploy.sh --devnet|--testnet [--no-build|--skip-build]"
|
||||||
exit 0
|
exit 0
|
||||||
;;
|
;;
|
||||||
*)
|
*)
|
||||||
echo "Unknown flag: $arg" >&2
|
echo "Unknown flag: $arg" >&2
|
||||||
echo "Usage: scripts/redeploy.sh --devnet|--testnet [--no-build]" >&2
|
echo "Usage: scripts/redeploy.sh --devnet|--testnet [--no-build|--skip-build]" >&2
|
||||||
exit 1
|
exit 1
|
||||||
;;
|
;;
|
||||||
esac
|
esac
|
||||||
@ -106,9 +107,9 @@ echo "Hub: $HUB_HOST (idx=$HUB_IDX, key=${HUB_KEY:-none})"
|
|||||||
|
|
||||||
# ── Build ────────────────────────────────────────────────────────────────────
|
# ── Build ────────────────────────────────────────────────────────────────────
|
||||||
if [[ "$NO_BUILD" -eq 0 ]]; then
|
if [[ "$NO_BUILD" -eq 0 ]]; then
|
||||||
echo "== build-linux =="
|
echo "== build-linux-all =="
|
||||||
(cd "$ROOT_DIR" && make build-linux) || {
|
(cd "$ROOT_DIR" && make build-linux-all) || {
|
||||||
echo "WARN: make build-linux failed; continuing if existing bin-linux is acceptable."
|
echo "WARN: make build-linux-all failed; continuing if existing bin-linux is acceptable."
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
echo "== skipping build (--no-build) =="
|
echo "== skipping build (--no-build) =="
|
||||||
@ -192,12 +193,16 @@ if [[ ${#hosts[@]} -gt 0 ]] && ! command -v sshpass >/dev/null 2>&1; then
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
echo "== fan-out: upload to ${#hosts[@]} nodes =="
|
echo "== fan-out: upload to ${#hosts[@]} nodes =="
|
||||||
|
upload_failed=()
|
||||||
for i in "${!hosts[@]}"; do
|
for i in "${!hosts[@]}"; do
|
||||||
h="${hosts[$i]}"
|
h="${hosts[$i]}"
|
||||||
p="${passes[$i]}"
|
p="${passes[$i]}"
|
||||||
echo " -> $h"
|
echo " -> $h"
|
||||||
sshpass -p "$p" scp -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \
|
if ! sshpass -p "$p" scp -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \
|
||||||
"$TAR" "$EX" "$h":/tmp/
|
"$TAR" "$EX" "$h":/tmp/; then
|
||||||
|
echo " !! UPLOAD FAILED: $h"
|
||||||
|
upload_failed+=("$h")
|
||||||
|
fi
|
||||||
done
|
done
|
||||||
|
|
||||||
echo "== extract on all fan-out nodes =="
|
echo "== extract on all fan-out nodes =="
|
||||||
@ -205,10 +210,22 @@ for i in "${!hosts[@]}"; do
|
|||||||
h="${hosts[$i]}"
|
h="${hosts[$i]}"
|
||||||
p="${passes[$i]}"
|
p="${passes[$i]}"
|
||||||
echo " -> $h"
|
echo " -> $h"
|
||||||
sshpass -p "$p" ssh -n -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \
|
if ! sshpass -p "$p" ssh -n -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \
|
||||||
"$h" "printf '%s\n' '$p' | sudo -S bash /tmp/extract-deploy.sh >/tmp/extract.log 2>&1 && echo OK"
|
"$h" "printf '%s\n' '$p' | sudo -S bash /tmp/extract-deploy.sh >/tmp/extract.log 2>&1 && echo OK"; then
|
||||||
|
echo " !! EXTRACT FAILED: $h"
|
||||||
|
upload_failed+=("$h")
|
||||||
|
fi
|
||||||
done
|
done
|
||||||
|
|
||||||
|
if [[ ${#upload_failed[@]} -gt 0 ]]; then
|
||||||
|
echo ""
|
||||||
|
echo "WARNING: ${#upload_failed[@]} nodes had upload/extract failures:"
|
||||||
|
for uf in "${upload_failed[@]}"; do
|
||||||
|
echo " - $uf"
|
||||||
|
done
|
||||||
|
echo "Continuing with rolling restart..."
|
||||||
|
fi
|
||||||
|
|
||||||
echo "== extract on hub =="
|
echo "== extract on hub =="
|
||||||
printf '%s\n' "$hub_pass" | sudo -S bash "$EX" >/tmp/extract.log 2>&1
|
printf '%s\n' "$hub_pass" | sudo -S bash "$EX" >/tmp/extract.log 2>&1
|
||||||
|
|
||||||
@ -253,44 +270,131 @@ if [[ -z "$leader" ]]; then
|
|||||||
fi
|
fi
|
||||||
echo "Leader: $leader"
|
echo "Leader: $leader"
|
||||||
|
|
||||||
|
failed_nodes=()
|
||||||
|
|
||||||
|
# ── Per-node upgrade flow ──
|
||||||
|
# Uses pre-upgrade (maintenance + leadership transfer + propagation wait)
|
||||||
|
# then stops, deploys binary, and post-upgrade (start + health verification).
|
||||||
upgrade_one() {
|
upgrade_one() {
|
||||||
local h="$1" p="$2"
|
local h="$1" p="$2"
|
||||||
echo "== upgrade $h =="
|
echo "== upgrade $h =="
|
||||||
sshpass -p "$p" ssh -n -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \
|
|
||||||
"$h" "printf '%s\n' '$p' | sudo -S orama prod stop && printf '%s\n' '$p' | sudo -S orama upgrade --no-pull --pre-built --restart"
|
# 1. Pre-upgrade: enter maintenance, transfer leadership, wait for propagation
|
||||||
|
echo " [1/4] pre-upgrade..."
|
||||||
|
if ! sshpass -p "$p" ssh -n -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \
|
||||||
|
"$h" "printf '%s\n' '$p' | sudo -S orama prod pre-upgrade" 2>&1; then
|
||||||
|
echo " !! pre-upgrade failed on $h (continuing with stop)"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# 2. Stop all services
|
||||||
|
echo " [2/4] stopping services..."
|
||||||
|
if ! sshpass -p "$p" ssh -n -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \
|
||||||
|
"$h" "printf '%s\n' '$p' | sudo -S systemctl stop 'debros-*'" 2>&1; then
|
||||||
|
echo " !! stop failed on $h"
|
||||||
|
failed_nodes+=("$h")
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# 3. Deploy new binary
|
||||||
|
echo " [3/4] deploying binary..."
|
||||||
|
if ! sshpass -p "$p" ssh -n -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \
|
||||||
|
"$h" "printf '%s\n' '$p' | sudo -S bash /tmp/extract-deploy.sh >/tmp/extract.log 2>&1 && echo OK" 2>&1; then
|
||||||
|
echo " !! extract failed on $h"
|
||||||
|
failed_nodes+=("$h")
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# 4. Post-upgrade: start services, verify health, exit maintenance
|
||||||
|
echo " [4/4] post-upgrade..."
|
||||||
|
if ! sshpass -p "$p" ssh -n -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \
|
||||||
|
"$h" "printf '%s\n' '$p' | sudo -S orama prod post-upgrade" 2>&1; then
|
||||||
|
echo " !! post-upgrade failed on $h"
|
||||||
|
failed_nodes+=("$h")
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo " OK: $h"
|
||||||
}
|
}
|
||||||
|
|
||||||
upgrade_hub() {
|
upgrade_hub() {
|
||||||
echo "== upgrade hub (localhost) =="
|
echo "== upgrade hub (localhost) =="
|
||||||
printf '%s\n' "$hub_pass" | sudo -S orama prod stop
|
|
||||||
printf '%s\n' "$hub_pass" | sudo -S orama upgrade --no-pull --pre-built --restart
|
# 1. Pre-upgrade
|
||||||
|
echo " [1/4] pre-upgrade..."
|
||||||
|
if ! (printf '%s\n' "$hub_pass" | sudo -S orama prod pre-upgrade) 2>&1; then
|
||||||
|
echo " !! pre-upgrade failed on hub (continuing with stop)"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# 2. Stop all services
|
||||||
|
echo " [2/4] stopping services..."
|
||||||
|
if ! (printf '%s\n' "$hub_pass" | sudo -S systemctl stop 'debros-*') 2>&1; then
|
||||||
|
echo " !! stop failed on hub ($hub_host)"
|
||||||
|
failed_nodes+=("$hub_host (hub)")
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# 3. Deploy new binary
|
||||||
|
echo " [3/4] deploying binary..."
|
||||||
|
if ! (printf '%s\n' "$hub_pass" | sudo -S bash "$EX" >/tmp/extract.log 2>&1); then
|
||||||
|
echo " !! extract failed on hub ($hub_host)"
|
||||||
|
failed_nodes+=("$hub_host (hub)")
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# 4. Post-upgrade
|
||||||
|
echo " [4/4] post-upgrade..."
|
||||||
|
if ! (printf '%s\n' "$hub_pass" | sudo -S orama prod post-upgrade) 2>&1; then
|
||||||
|
echo " !! post-upgrade failed on hub ($hub_host)"
|
||||||
|
failed_nodes+=("$hub_host (hub)")
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo " OK: hub ($hub_host)"
|
||||||
}
|
}
|
||||||
|
|
||||||
echo "== rolling upgrade (followers first) =="
|
echo "== rolling upgrade (followers first, leader last) =="
|
||||||
for i in "${!hosts[@]}"; do
|
for i in "${!hosts[@]}"; do
|
||||||
h="${hosts[$i]}"
|
h="${hosts[$i]}"
|
||||||
p="${passes[$i]}"
|
p="${passes[$i]}"
|
||||||
[[ "$h" == "$leader" ]] && continue
|
[[ "$h" == "$leader" ]] && continue
|
||||||
upgrade_one "$h" "$p"
|
upgrade_one "$h" "$p" || true
|
||||||
done
|
done
|
||||||
|
|
||||||
# Upgrade hub if not the leader
|
# Upgrade hub if not the leader
|
||||||
if [[ "$leader" != "HUB" ]]; then
|
if [[ "$leader" != "HUB" ]]; then
|
||||||
upgrade_hub
|
upgrade_hub || true
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Upgrade leader last
|
# Upgrade leader last
|
||||||
echo "== upgrade leader last =="
|
echo "== upgrade leader last =="
|
||||||
if [[ "$leader" == "HUB" ]]; then
|
if [[ "$leader" == "HUB" ]]; then
|
||||||
upgrade_hub
|
upgrade_hub || true
|
||||||
else
|
else
|
||||||
upgrade_one "$leader" "$leader_pass"
|
upgrade_one "$leader" "$leader_pass" || true
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Clean up conf from hub
|
# Clean up conf from hub
|
||||||
rm -f "$CONF"
|
rm -f "$CONF"
|
||||||
|
|
||||||
echo "Done."
|
# ── Report results ──
|
||||||
|
echo ""
|
||||||
|
echo "========================================"
|
||||||
|
if [[ ${#failed_nodes[@]} -gt 0 ]]; then
|
||||||
|
echo "UPGRADE COMPLETED WITH FAILURES (${#failed_nodes[@]} nodes failed):"
|
||||||
|
for fn in "${failed_nodes[@]}"; do
|
||||||
|
echo " FAILED: $fn"
|
||||||
|
done
|
||||||
|
echo ""
|
||||||
|
echo "Recommended actions:"
|
||||||
|
echo " 1. SSH into the failed node(s)"
|
||||||
|
echo " 2. Check logs: sudo orama prod logs node --follow"
|
||||||
|
echo " 3. Manually run: sudo orama prod post-upgrade"
|
||||||
|
echo "========================================"
|
||||||
|
exit 1
|
||||||
|
else
|
||||||
|
echo "All nodes upgraded successfully."
|
||||||
|
echo "========================================"
|
||||||
|
fi
|
||||||
REMOTE
|
REMOTE
|
||||||
|
|
||||||
echo "== complete =="
|
echo "== complete =="
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user