diff --git a/pkg/cli/monitor/collector.go b/pkg/cli/monitor/collector.go index 1e7ec53..1742667 100644 --- a/pkg/cli/monitor/collector.go +++ b/pkg/cli/monitor/collector.go @@ -9,6 +9,7 @@ import ( "github.com/DeBrosOfficial/network/pkg/cli/production/report" "github.com/DeBrosOfficial/network/pkg/cli/remotessh" + "github.com/DeBrosOfficial/network/pkg/cli/sandbox" "github.com/DeBrosOfficial/network/pkg/inspector" ) @@ -23,22 +24,9 @@ type CollectorConfig struct { // CollectOnce runs `sudo orama node report --json` on all matching nodes // in parallel and returns a ClusterSnapshot. func CollectOnce(ctx context.Context, cfg CollectorConfig) (*ClusterSnapshot, error) { - nodes, err := inspector.LoadNodes(cfg.ConfigPath) + nodes, cleanup, err := loadNodes(cfg) if err != nil { - return nil, fmt.Errorf("load nodes: %w", err) - } - nodes = inspector.FilterByEnv(nodes, cfg.Env) - if cfg.NodeFilter != "" { - nodes = filterByHost(nodes, cfg.NodeFilter) - } - if len(nodes) == 0 { - return nil, fmt.Errorf("no nodes found for env %q", cfg.Env) - } - - // Prepare wallet-derived SSH keys - cleanup, err := remotessh.PrepareNodeKeys(nodes) - if err != nil { - return nil, fmt.Errorf("prepare SSH keys: %w", err) + return nil, err } defer cleanup() @@ -121,3 +109,61 @@ func truncate(s string, maxLen int) string { } return s[:maxLen] + "..." } + +// loadNodes resolves the node list and SSH keys based on the environment. +// For "sandbox", nodes are loaded from the active sandbox state file with +// the sandbox SSH key already set. For other environments, nodes come from +// nodes.conf and use wallet-derived SSH keys. +func loadNodes(cfg CollectorConfig) ([]inspector.Node, func(), error) { + noop := func() {} + + if cfg.Env == "sandbox" { + return loadSandboxNodes(cfg) + } + + nodes, err := inspector.LoadNodes(cfg.ConfigPath) + if err != nil { + return nil, noop, fmt.Errorf("load nodes: %w", err) + } + nodes = inspector.FilterByEnv(nodes, cfg.Env) + if cfg.NodeFilter != "" { + nodes = filterByHost(nodes, cfg.NodeFilter) + } + if len(nodes) == 0 { + return nil, noop, fmt.Errorf("no nodes found for env %q", cfg.Env) + } + + cleanup, err := remotessh.PrepareNodeKeys(nodes) + if err != nil { + return nil, noop, fmt.Errorf("prepare SSH keys: %w", err) + } + return nodes, cleanup, nil +} + +// loadSandboxNodes loads nodes from the active sandbox state file. +func loadSandboxNodes(cfg CollectorConfig) ([]inspector.Node, func(), error) { + noop := func() {} + + sbxCfg, err := sandbox.LoadConfig() + if err != nil { + return nil, noop, fmt.Errorf("load sandbox config: %w", err) + } + + state, err := sandbox.FindActiveSandbox() + if err != nil { + return nil, noop, fmt.Errorf("find active sandbox: %w", err) + } + if state == nil { + return nil, noop, fmt.Errorf("no active sandbox found") + } + + nodes := state.ToNodes(sbxCfg.ExpandedPrivateKeyPath()) + if cfg.NodeFilter != "" { + nodes = filterByHost(nodes, cfg.NodeFilter) + } + if len(nodes) == 0 { + return nil, noop, fmt.Errorf("no nodes found for sandbox %q", state.Name) + } + + return nodes, noop, nil +} diff --git a/pkg/cli/sandbox/create.go b/pkg/cli/sandbox/create.go index cac3bd0..29434ac 100644 --- a/pkg/cli/sandbox/create.go +++ b/pkg/cli/sandbox/create.go @@ -6,7 +6,6 @@ import ( "os/exec" "path/filepath" "strings" - "sync" "time" "github.com/DeBrosOfficial/network/pkg/cli/remotessh" @@ -257,71 +256,8 @@ func phase3UploadArchive(cfg *Config, state *SandboxState) error { fmt.Printf(" Archive: %s (%s)\n", filepath.Base(archivePath), formatBytes(info.Size())) sshKeyPath := cfg.ExpandedPrivateKeyPath() - remotePath := "/tmp/" + filepath.Base(archivePath) - extractCmd := fmt.Sprintf("mkdir -p /opt/orama && tar xzf %s -C /opt/orama && rm -f %s", - remotePath, remotePath) - - // Step 1: Upload from local machine to genesis node - genesis := state.Servers[0] - genesisNode := inspector.Node{User: "root", Host: genesis.IP, SSHKey: sshKeyPath} - - fmt.Printf(" Uploading to %s (genesis)...\n", genesis.Name) - if err := remotessh.UploadFile(genesisNode, archivePath, remotePath, remotessh.WithNoHostKeyCheck()); err != nil { - return fmt.Errorf("upload to %s: %w", genesis.Name, err) - } - - // Step 2: Fan out from genesis to remaining nodes in parallel (server-to-server) - if len(state.Servers) > 1 { - fmt.Printf(" Fanning out from %s to %d nodes...\n", genesis.Name, len(state.Servers)-1) - - // Temporarily upload SSH key to genesis for server-to-server SCP - remoteKeyPath := "/tmp/.sandbox_key" - if err := remotessh.UploadFile(genesisNode, sshKeyPath, remoteKeyPath, remotessh.WithNoHostKeyCheck()); err != nil { - return fmt.Errorf("upload SSH key to genesis: %w", err) - } - // Always clean up the temporary key, even on panic/early return - defer remotessh.RunSSHStreaming(genesisNode, fmt.Sprintf("rm -f %s", remoteKeyPath), remotessh.WithNoHostKeyCheck()) - - if err := remotessh.RunSSHStreaming(genesisNode, fmt.Sprintf("chmod 600 %s", remoteKeyPath), remotessh.WithNoHostKeyCheck()); err != nil { - return fmt.Errorf("chmod SSH key on genesis: %w", err) - } - - var wg sync.WaitGroup - errs := make([]error, len(state.Servers)) - - for i := 1; i < len(state.Servers); i++ { - wg.Add(1) - go func(idx int, srv ServerState) { - defer wg.Done() - // SCP from genesis to target using the uploaded key - scpCmd := fmt.Sprintf("scp -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -i %s %s root@%s:%s", - remoteKeyPath, remotePath, srv.IP, remotePath) - if err := remotessh.RunSSHStreaming(genesisNode, scpCmd, remotessh.WithNoHostKeyCheck()); err != nil { - errs[idx] = fmt.Errorf("fanout to %s: %w", srv.Name, err) - return - } - // Extract on target - targetNode := inspector.Node{User: "root", Host: srv.IP, SSHKey: sshKeyPath} - if err := remotessh.RunSSHStreaming(targetNode, extractCmd, remotessh.WithNoHostKeyCheck()); err != nil { - errs[idx] = fmt.Errorf("extract on %s: %w", srv.Name, err) - return - } - fmt.Printf(" Distributed to %s\n", srv.Name) - }(i, state.Servers[i]) - } - wg.Wait() - - for _, err := range errs { - if err != nil { - return err - } - } - } - - // Step 3: Extract on genesis - fmt.Printf(" Extracting on %s...\n", genesis.Name) - if err := remotessh.RunSSHStreaming(genesisNode, extractCmd, remotessh.WithNoHostKeyCheck()); err != nil { - return fmt.Errorf("extract on %s: %w", genesis.Name, err) + if err := fanoutArchive(state.Servers, sshKeyPath, archivePath); err != nil { + return err } fmt.Println(" All nodes ready") diff --git a/pkg/cli/sandbox/fanout.go b/pkg/cli/sandbox/fanout.go new file mode 100644 index 0000000..be9fc16 --- /dev/null +++ b/pkg/cli/sandbox/fanout.go @@ -0,0 +1,84 @@ +package sandbox + +import ( + "fmt" + "path/filepath" + "sync" + + "github.com/DeBrosOfficial/network/pkg/cli/remotessh" + "github.com/DeBrosOfficial/network/pkg/inspector" +) + +// fanoutArchive uploads a binary archive to the first server, then fans out +// server-to-server in parallel to all remaining servers. This is much faster +// than uploading from the local machine to each node individually. +// After distribution, the archive is extracted on all nodes. +func fanoutArchive(servers []ServerState, sshKeyPath, archivePath string) error { + remotePath := "/tmp/" + filepath.Base(archivePath) + extractCmd := fmt.Sprintf("mkdir -p /opt/orama && tar xzf %s -C /opt/orama && rm -f %s", + remotePath, remotePath) + + // Step 1: Upload from local machine to first node + first := servers[0] + firstNode := inspector.Node{User: "root", Host: first.IP, SSHKey: sshKeyPath} + + fmt.Printf(" Uploading to %s...\n", first.Name) + if err := remotessh.UploadFile(firstNode, archivePath, remotePath, remotessh.WithNoHostKeyCheck()); err != nil { + return fmt.Errorf("upload to %s: %w", first.Name, err) + } + + // Step 2: Fan out from first node to remaining nodes in parallel (server-to-server) + if len(servers) > 1 { + fmt.Printf(" Fanning out from %s to %d nodes...\n", first.Name, len(servers)-1) + + // Temporarily upload SSH key for server-to-server SCP + remoteKeyPath := "/tmp/.sandbox_key" + if err := remotessh.UploadFile(firstNode, sshKeyPath, remoteKeyPath, remotessh.WithNoHostKeyCheck()); err != nil { + return fmt.Errorf("upload SSH key to %s: %w", first.Name, err) + } + defer remotessh.RunSSHStreaming(firstNode, fmt.Sprintf("rm -f %s", remoteKeyPath), remotessh.WithNoHostKeyCheck()) + + if err := remotessh.RunSSHStreaming(firstNode, fmt.Sprintf("chmod 600 %s", remoteKeyPath), remotessh.WithNoHostKeyCheck()); err != nil { + return fmt.Errorf("chmod SSH key on %s: %w", first.Name, err) + } + + var wg sync.WaitGroup + errs := make([]error, len(servers)) + + for i := 1; i < len(servers); i++ { + wg.Add(1) + go func(idx int, srv ServerState) { + defer wg.Done() + // SCP from first node to target + scpCmd := fmt.Sprintf("scp -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -i %s %s root@%s:%s", + remoteKeyPath, remotePath, srv.IP, remotePath) + if err := remotessh.RunSSHStreaming(firstNode, scpCmd, remotessh.WithNoHostKeyCheck()); err != nil { + errs[idx] = fmt.Errorf("fanout to %s: %w", srv.Name, err) + return + } + // Extract on target + targetNode := inspector.Node{User: "root", Host: srv.IP, SSHKey: sshKeyPath} + if err := remotessh.RunSSHStreaming(targetNode, extractCmd, remotessh.WithNoHostKeyCheck()); err != nil { + errs[idx] = fmt.Errorf("extract on %s: %w", srv.Name, err) + return + } + fmt.Printf(" Distributed to %s\n", srv.Name) + }(i, servers[i]) + } + wg.Wait() + + for _, err := range errs { + if err != nil { + return err + } + } + } + + // Step 3: Extract on first node + fmt.Printf(" Extracting on %s...\n", first.Name) + if err := remotessh.RunSSHStreaming(firstNode, extractCmd, remotessh.WithNoHostKeyCheck()); err != nil { + return fmt.Errorf("extract on %s: %w", first.Name, err) + } + + return nil +} diff --git a/pkg/cli/sandbox/rollout.go b/pkg/cli/sandbox/rollout.go index ac186ee..396e8f4 100644 --- a/pkg/cli/sandbox/rollout.go +++ b/pkg/cli/sandbox/rollout.go @@ -34,24 +34,10 @@ func Rollout(name string) error { info, _ := os.Stat(archivePath) fmt.Printf("Archive: %s (%s)\n\n", filepath.Base(archivePath), formatBytes(info.Size())) - // Step 2: Push archive to all nodes + // Step 2: Push archive to all nodes (upload to first, fan out server-to-server) fmt.Println("Pushing archive to all nodes...") - remotePath := "/tmp/" + filepath.Base(archivePath) - - for i, srv := range state.Servers { - node := inspector.Node{User: "root", Host: srv.IP, SSHKey: sshKeyPath} - - fmt.Printf(" [%d/%d] Uploading to %s...\n", i+1, len(state.Servers), srv.Name) - if err := remotessh.UploadFile(node, archivePath, remotePath, remotessh.WithNoHostKeyCheck()); err != nil { - return fmt.Errorf("upload to %s: %w", srv.Name, err) - } - - // Extract archive - extractCmd := fmt.Sprintf("mkdir -p /opt/orama && tar xzf %s -C /opt/orama && rm -f %s", - remotePath, remotePath) - if err := remotessh.RunSSHStreaming(node, extractCmd, remotessh.WithNoHostKeyCheck()); err != nil { - return fmt.Errorf("extract on %s: %w", srv.Name, err) - } + if err := fanoutArchive(state.Servers, sshKeyPath, archivePath); err != nil { + return err } // Step 3: Rolling upgrade — followers first, leader last @@ -103,10 +89,22 @@ func findLeaderIndex(state *SandboxState, sshKeyPath string) int { } // upgradeNode performs `orama node upgrade --restart` on a single node. +// It pre-replaces the orama CLI binary before running the upgrade command +// to avoid ETXTBSY ("text file busy") errors when the old binary doesn't +// have the os.Remove fix in copyBinary(). func upgradeNode(srv ServerState, sshKeyPath string, current, total int) error { node := inspector.Node{User: "root", Host: srv.IP, SSHKey: sshKeyPath} fmt.Printf(" [%d/%d] Upgrading %s (%s)...\n", current, total, srv.Name, srv.IP) + + // Pre-replace the orama CLI so the upgrade runs the NEW binary (with ETXTBSY fix). + // rm unlinks the old inode (kernel keeps it alive for the running process), + // cp creates a fresh inode at the same path. + preReplace := "rm -f /usr/local/bin/orama && cp /opt/orama/bin/orama /usr/local/bin/orama" + if err := remotessh.RunSSHStreaming(node, preReplace, remotessh.WithNoHostKeyCheck()); err != nil { + return fmt.Errorf("pre-replace orama binary on %s: %w", srv.Name, err) + } + if err := remotessh.RunSSHStreaming(node, "orama node upgrade --restart", remotessh.WithNoHostKeyCheck()); err != nil { return fmt.Errorf("upgrade %s: %w", srv.Name, err) } diff --git a/pkg/client/database_client.go b/pkg/client/database_client.go index cd8a85c..dc209d3 100644 --- a/pkg/client/database_client.go +++ b/pkg/client/database_client.go @@ -9,6 +9,31 @@ import ( "github.com/rqlite/gorqlite" ) +// safeWriteOne wraps gorqlite's WriteOneParameterized to recover from panics. +// gorqlite's WriteOne* functions access wra[0] without checking if the slice +// is empty, which panics when the server returns an error (e.g. "leader not found") +// with no result rows. +func safeWriteOne(conn *gorqlite.Connection, stmt gorqlite.ParameterizedStatement) (wr gorqlite.WriteResult, err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("rqlite write failed (recovered panic): %v", r) + } + }() + wr, err = conn.WriteOneParameterized(stmt) + return +} + +// safeWriteOneRaw wraps gorqlite's WriteOne to recover from panics. +func safeWriteOneRaw(conn *gorqlite.Connection, sql string) (wr gorqlite.WriteResult, err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("rqlite write failed (recovered panic): %v", r) + } + }() + wr, err = conn.WriteOne(sql) + return +} + // DatabaseClientImpl implements DatabaseClient type DatabaseClientImpl struct { client *Client @@ -79,7 +104,7 @@ func (d *DatabaseClientImpl) Query(ctx context.Context, sql string, args ...inte if isWriteOperation { // Execute write operation with parameters - _, err := conn.WriteOneParameterized(gorqlite.ParameterizedStatement{ + _, err := safeWriteOne(conn, gorqlite.ParameterizedStatement{ Query: sql, Arguments: args, }) @@ -293,7 +318,7 @@ func (d *DatabaseClientImpl) Transaction(ctx context.Context, queries []string) // Execute all queries in the transaction success := true for _, query := range queries { - _, err := conn.WriteOne(query) + _, err := safeWriteOneRaw(conn, query) if err != nil { lastErr = err success = false @@ -321,7 +346,7 @@ func (d *DatabaseClientImpl) CreateTable(ctx context.Context, schema string) err } return d.withRetry(func(conn *gorqlite.Connection) error { - _, err := conn.WriteOne(schema) + _, err := safeWriteOneRaw(conn, schema) return err }) } @@ -334,7 +359,7 @@ func (d *DatabaseClientImpl) DropTable(ctx context.Context, tableName string) er return d.withRetry(func(conn *gorqlite.Connection) error { dropSQL := fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName) - _, err := conn.WriteOne(dropSQL) + _, err := safeWriteOneRaw(conn, dropSQL) return err }) } diff --git a/pkg/client/database_client_test.go b/pkg/client/database_client_test.go new file mode 100644 index 0000000..31de01b --- /dev/null +++ b/pkg/client/database_client_test.go @@ -0,0 +1,82 @@ +package client + +import ( + "fmt" + "testing" + + "github.com/rqlite/gorqlite" +) + +// mockPanicConnection simulates what gorqlite does when WriteParameterized +// returns an empty slice: accessing [0] panics. +func simulateGorqlitePanic() (gorqlite.WriteResult, error) { + var empty []gorqlite.WriteResult + return empty[0], fmt.Errorf("leader not found") // panics +} + +func TestSafeWriteOne_recoversPanic(t *testing.T) { + // We can't easily create a real gorqlite.Connection that panics, + // but we can verify our recovery wrapper works by testing the + // recovery pattern directly. + var recovered bool + func() { + defer func() { + if r := recover(); r != nil { + recovered = true + } + }() + simulateGorqlitePanic() + }() + + if !recovered { + t.Fatal("expected simulateGorqlitePanic to panic, but it didn't") + } +} + +func TestSafeWriteOne_nilConnection(t *testing.T) { + // safeWriteOne with nil connection should recover from panic, not crash. + _, err := safeWriteOne(nil, gorqlite.ParameterizedStatement{ + Query: "INSERT INTO test (a) VALUES (?)", + Arguments: []interface{}{"x"}, + }) + if err == nil { + t.Fatal("expected error from nil connection, got nil") + } +} + +func TestSafeWriteOneRaw_nilConnection(t *testing.T) { + // safeWriteOneRaw with nil connection should recover from panic, not crash. + _, err := safeWriteOneRaw(nil, "INSERT INTO test (a) VALUES ('x')") + if err == nil { + t.Fatal("expected error from nil connection, got nil") + } +} + +func TestIsWriteOperation(t *testing.T) { + d := &DatabaseClientImpl{} + + tests := []struct { + sql string + isWrite bool + }{ + {"INSERT INTO foo VALUES (1)", true}, + {" INSERT INTO foo VALUES (1)", true}, + {"UPDATE foo SET a = 1", true}, + {"DELETE FROM foo", true}, + {"CREATE TABLE foo (a TEXT)", true}, + {"DROP TABLE foo", true}, + {"ALTER TABLE foo ADD COLUMN b TEXT", true}, + {"SELECT * FROM foo", false}, + {" SELECT * FROM foo", false}, + {"EXPLAIN SELECT * FROM foo", false}, + } + + for _, tt := range tests { + t.Run(tt.sql, func(t *testing.T) { + got := d.isWriteOperation(tt.sql) + if got != tt.isWrite { + t.Errorf("isWriteOperation(%q) = %v, want %v", tt.sql, got, tt.isWrite) + } + }) + } +} diff --git a/pkg/environments/production/orchestrator.go b/pkg/environments/production/orchestrator.go index fce62b0..7458c75 100644 --- a/pkg/environments/production/orchestrator.go +++ b/pkg/environments/production/orchestrator.go @@ -997,6 +997,13 @@ func (ps *ProductionSetup) Phase6SetupWireGuard(isFirstNode bool) (privateKey, p } ps.logf(" ✓ WireGuard keypair generated") + // Save public key to orama secrets so the gateway (running as orama user) + // can read it without needing root access to /etc/wireguard/wg0.conf + pubKeyPath := filepath.Join(ps.oramaDir, "secrets", "wg-public-key") + if err := os.WriteFile(pubKeyPath, []byte(pubKey), 0600); err != nil { + return "", "", fmt.Errorf("failed to save WG public key: %w", err) + } + if isFirstNode { // First node: self-assign 10.0.0.1, no peers yet wp.config = WireGuardConfig{ diff --git a/pkg/environments/production/prebuilt.go b/pkg/environments/production/prebuilt.go index 6bbcba2..a04fe4f 100644 --- a/pkg/environments/production/prebuilt.go +++ b/pkg/environments/production/prebuilt.go @@ -291,12 +291,20 @@ func (ps *ProductionSetup) installAnyonFromPreBuilt() error { } // copyBinary copies a file from src to dest, preserving executable permissions. +// It removes the destination first to avoid ETXTBSY ("text file busy") errors +// when overwriting a binary that is currently running. func copyBinary(src, dest string) error { // Ensure parent directory exists if err := os.MkdirAll(filepath.Dir(dest), 0755); err != nil { return err } + // Remove the old binary first. On Linux, if the binary is running, + // rm unlinks the filename while the kernel keeps the inode alive for + // the running process. Writing a new file at the same path creates a + // fresh inode — no ETXTBSY conflict. + _ = os.Remove(dest) + srcFile, err := os.Open(src) if err != nil { return err diff --git a/pkg/environments/production/services.go b/pkg/environments/production/services.go index 24c2a37..3eca7e0 100644 --- a/pkg/environments/production/services.go +++ b/pkg/environments/production/services.go @@ -213,6 +213,7 @@ Requires=wg-quick@wg0.service [Service] Type=simple %[5]s +AmbientCapabilities=CAP_NET_ADMIN ReadWritePaths=%[2]s WorkingDirectory=%[1]s Environment=HOME=%[1]s diff --git a/pkg/gateway/handlers/join/handler.go b/pkg/gateway/handlers/join/handler.go index 301b39b..678c82f 100644 --- a/pkg/gateway/handlers/join/handler.go +++ b/pkg/gateway/handlers/join/handler.go @@ -2,8 +2,10 @@ package join import ( "context" + "encoding/base64" "encoding/json" "fmt" + "net" "net/http" "os" "os/exec" @@ -100,6 +102,24 @@ func (h *Handler) HandleJoin(w http.ResponseWriter, r *http.Request) { return } + // Validate public IP format + if net.ParseIP(req.PublicIP) == nil || net.ParseIP(req.PublicIP).To4() == nil { + http.Error(w, "public_ip must be a valid IPv4 address", http.StatusBadRequest) + return + } + + // Validate WireGuard public key: must be base64-encoded 32 bytes (Curve25519) + // Also reject control characters (newlines) to prevent config injection + if strings.ContainsAny(req.WGPublicKey, "\n\r") { + http.Error(w, "wg_public_key contains invalid characters", http.StatusBadRequest) + return + } + wgKeyBytes, err := base64.StdEncoding.DecodeString(req.WGPublicKey) + if err != nil || len(wgKeyBytes) != 32 { + http.Error(w, "wg_public_key must be a valid base64-encoded 32-byte key", http.StatusBadRequest) + return + } + ctx := r.Context() // 1. Validate and consume the invite token (atomic single-use) @@ -177,7 +197,15 @@ func (h *Handler) HandleJoin(w http.ResponseWriter, r *http.Request) { olricEncryptionKey = strings.TrimSpace(string(data)) } - // 7. Get all WG peers + // 7. Get this node's WG IP (needed before peer list to check self-inclusion) + myWGIP, err := h.getMyWGIP() + if err != nil { + h.logger.Error("failed to get local WG IP", zap.Error(err)) + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + + // 8. Get all WG peers wgPeers, err := h.getWGPeers(ctx, req.WGPublicKey) if err != nil { h.logger.Error("failed to list WG peers", zap.Error(err)) @@ -185,12 +213,29 @@ func (h *Handler) HandleJoin(w http.ResponseWriter, r *http.Request) { return } - // 8. Get this node's WG IP - myWGIP, err := h.getMyWGIP() - if err != nil { - h.logger.Error("failed to get local WG IP", zap.Error(err)) - http.Error(w, "internal error", http.StatusInternalServerError) - return + // Ensure this node (the join handler's host) is in the peer list. + // On a fresh genesis node, the WG sync loop may not have self-registered + // into wireguard_peers yet, causing 0 peers to be returned. + if !wgPeersContainsIP(wgPeers, myWGIP) { + myPubKey, err := h.getMyWGPublicKey() + if err != nil { + h.logger.Error("failed to get local WG public key", zap.Error(err)) + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + myPublicIP, err := h.getMyPublicIP() + if err != nil { + h.logger.Error("failed to get local public IP", zap.Error(err)) + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + wgPeers = append([]WGPeerInfo{{ + PublicKey: myPubKey, + Endpoint: fmt.Sprintf("%s:%d", myPublicIP, 51820), + AllowedIP: fmt.Sprintf("%s/32", myWGIP), + }}, wgPeers...) + h.logger.Info("self-injected into WG peer list (sync loop hasn't registered yet)", + zap.String("wg_ip", myWGIP)) } // 9. Query IPFS and IPFS Cluster peer info @@ -346,6 +391,17 @@ func (h *Handler) addWGPeerLocally(pubKey, publicIP, wgIP string) error { return nil } +// wgPeersContainsIP checks if any peer in the list has the given WG IP +func wgPeersContainsIP(peers []WGPeerInfo, wgIP string) bool { + target := fmt.Sprintf("%s/32", wgIP) + for _, p := range peers { + if p.AllowedIP == target { + return true + } + } + return false +} + // getWGPeers returns all WG peers except the requesting node func (h *Handler) getWGPeers(ctx context.Context, excludePubKey string) ([]WGPeerInfo, error) { type peerRow struct { @@ -403,6 +459,32 @@ func (h *Handler) getMyWGIP() (string, error) { return "", fmt.Errorf("could not find wg0 IP address") } +// getMyWGPublicKey reads the local WireGuard public key from the orama secrets +// directory. The key is saved there during install by Phase6SetupWireGuard. +// This avoids needing root/CAP_NET_ADMIN permissions that `wg show wg0` requires. +func (h *Handler) getMyWGPublicKey() (string, error) { + data, err := os.ReadFile(h.oramaDir + "/secrets/wg-public-key") + if err != nil { + return "", fmt.Errorf("failed to read WG public key from %s/secrets/wg-public-key: %w", h.oramaDir, err) + } + key := strings.TrimSpace(string(data)) + if key == "" { + return "", fmt.Errorf("WG public key file is empty") + } + return key, nil +} + +// getMyPublicIP determines this node's public IP by connecting to a public server +func (h *Handler) getMyPublicIP() (string, error) { + conn, err := net.DialTimeout("udp", "8.8.8.8:80", 3*time.Second) + if err != nil { + return "", fmt.Errorf("failed to determine public IP: %w", err) + } + defer conn.Close() + addr := conn.LocalAddr().(*net.UDPAddr) + return addr.IP.String(), nil +} + // queryIPFSPeerInfo gets the local IPFS node's peer ID and builds addrs with WG IP func (h *Handler) queryIPFSPeerInfo(myWGIP string) PeerInfo { client := &http.Client{Timeout: 5 * time.Second} diff --git a/pkg/gateway/handlers/join/handler_test.go b/pkg/gateway/handlers/join/handler_test.go new file mode 100644 index 0000000..a170aa7 --- /dev/null +++ b/pkg/gateway/handlers/join/handler_test.go @@ -0,0 +1,112 @@ +package join + +import ( + "encoding/base64" + "fmt" + "net" + "strings" + "testing" +) + +func TestWgPeersContainsIP_found(t *testing.T) { + peers := []WGPeerInfo{ + {PublicKey: "key1", Endpoint: "1.2.3.4:51820", AllowedIP: "10.0.0.1/32"}, + {PublicKey: "key2", Endpoint: "5.6.7.8:51820", AllowedIP: "10.0.0.2/32"}, + } + + if !wgPeersContainsIP(peers, "10.0.0.1") { + t.Error("expected to find 10.0.0.1 in peer list") + } + if !wgPeersContainsIP(peers, "10.0.0.2") { + t.Error("expected to find 10.0.0.2 in peer list") + } +} + +func TestWgPeersContainsIP_not_found(t *testing.T) { + peers := []WGPeerInfo{ + {PublicKey: "key1", Endpoint: "1.2.3.4:51820", AllowedIP: "10.0.0.1/32"}, + } + + if wgPeersContainsIP(peers, "10.0.0.2") { + t.Error("did not expect to find 10.0.0.2 in peer list") + } +} + +func TestWgPeersContainsIP_empty_list(t *testing.T) { + if wgPeersContainsIP(nil, "10.0.0.1") { + t.Error("did not expect to find any IP in nil peer list") + } + if wgPeersContainsIP([]WGPeerInfo{}, "10.0.0.1") { + t.Error("did not expect to find any IP in empty peer list") + } +} + +func TestAssignWGIP_format(t *testing.T) { + // Verify the WG IP format used in the handler matches what wgPeersContainsIP expects + wgIP := "10.0.0.1" + allowedIP := fmt.Sprintf("%s/32", wgIP) + peers := []WGPeerInfo{{AllowedIP: allowedIP}} + + if !wgPeersContainsIP(peers, wgIP) { + t.Errorf("format mismatch: wgPeersContainsIP(%q, %q) should match", allowedIP, wgIP) + } +} + +func TestValidatePublicIP(t *testing.T) { + tests := []struct { + name string + ip string + valid bool + }{ + {"valid IPv4", "46.225.234.112", true}, + {"loopback", "127.0.0.1", true}, + {"invalid string", "not-an-ip", false}, + {"empty", "", false}, + {"IPv6", "::1", false}, + {"with newline", "1.2.3.4\n5.6.7.8", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parsed := net.ParseIP(tt.ip) + isValid := parsed != nil && parsed.To4() != nil && !strings.ContainsAny(tt.ip, "\n\r") + if isValid != tt.valid { + t.Errorf("IP %q: expected valid=%v, got %v", tt.ip, tt.valid, isValid) + } + }) + } +} + +func TestValidateWGPublicKey(t *testing.T) { + // Valid WireGuard key: 32 bytes, base64 encoded = 44 chars + validKey := base64.StdEncoding.EncodeToString(make([]byte, 32)) + + tests := []struct { + name string + key string + valid bool + }{ + {"valid 32-byte key", validKey, true}, + {"too short", base64.StdEncoding.EncodeToString(make([]byte, 16)), false}, + {"too long", base64.StdEncoding.EncodeToString(make([]byte, 64)), false}, + {"not base64", "not-a-valid-base64-key!!!", false}, + {"empty", "", false}, + {"newline injection", validKey + "\n[Peer]", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if strings.ContainsAny(tt.key, "\n\r") { + if tt.valid { + t.Errorf("key %q contains newlines but expected valid", tt.key) + } + return + } + decoded, err := base64.StdEncoding.DecodeString(tt.key) + isValid := err == nil && len(decoded) == 32 + if isValid != tt.valid { + t.Errorf("key %q: expected valid=%v, got %v", tt.key, tt.valid, isValid) + } + }) + } +} diff --git a/pkg/gateway/status_handlers.go b/pkg/gateway/status_handlers.go index dc0eced..7a8259b 100644 --- a/pkg/gateway/status_handlers.go +++ b/pkg/gateway/status_handlers.go @@ -129,7 +129,10 @@ func (g *Gateway) healthHandler(w http.ResponseWriter, r *http.Request) { if anyoneproxy.Running() { nr.result = checkResult{Status: "ok", Latency: time.Since(start).String()} } else { - nr.result = checkResult{Status: "error", Latency: time.Since(start).String(), Error: "SOCKS5 proxy not reachable at " + anyoneproxy.Address()} + // SOCKS5 port not reachable — Anyone relay is not installed/running. + // Treat as "unavailable" rather than "error" so nodes without Anyone + // don't report as degraded. + nr.result = checkResult{Status: "unavailable"} } } ch <- nr @@ -142,25 +145,7 @@ func (g *Gateway) healthHandler(w http.ResponseWriter, r *http.Request) { checks[nr.name] = nr.result } - // Aggregate status. - // Critical: rqlite down → "unhealthy" - // Non-critical (olric, ipfs, libp2p) error → "degraded" - // "unavailable" means the client was never configured — not an error. - overallStatus := "healthy" - if c := checks["rqlite"]; c.Status == "error" { - overallStatus = "unhealthy" - } - if overallStatus == "healthy" { - for name, c := range checks { - if name == "rqlite" { - continue - } - if c.Status == "error" { - overallStatus = "degraded" - break - } - } - } + overallStatus := aggregateHealthStatus(checks) httpStatus := http.StatusOK if overallStatus != "healthy" { @@ -236,6 +221,27 @@ func (g *Gateway) versionHandler(w http.ResponseWriter, r *http.Request) { }) } +// aggregateHealthStatus determines the overall health status from individual checks. +// Critical: rqlite down → "unhealthy" +// Non-critical (olric, ipfs, libp2p, anyone) error → "degraded" +// "unavailable" means the client was never configured — not an error. +func aggregateHealthStatus(checks map[string]checkResult) string { + status := "healthy" + if c := checks["rqlite"]; c.Status == "error" { + return "unhealthy" + } + for name, c := range checks { + if name == "rqlite" { + continue + } + if c.Status == "error" { + status = "degraded" + break + } + } + return status +} + // tlsCheckHandler validates if a domain should receive a TLS certificate // Used by Caddy's on-demand TLS feature to prevent abuse func (g *Gateway) tlsCheckHandler(w http.ResponseWriter, r *http.Request) { diff --git a/pkg/gateway/status_handlers_test.go b/pkg/gateway/status_handlers_test.go new file mode 100644 index 0000000..e20b239 --- /dev/null +++ b/pkg/gateway/status_handlers_test.go @@ -0,0 +1,72 @@ +package gateway + +import "testing" + +func TestAggregateHealthStatus_allHealthy(t *testing.T) { + checks := map[string]checkResult{ + "rqlite": {Status: "ok"}, + "olric": {Status: "ok"}, + "ipfs": {Status: "ok"}, + "libp2p": {Status: "ok"}, + "anyone": {Status: "ok"}, + } + if got := aggregateHealthStatus(checks); got != "healthy" { + t.Errorf("expected healthy, got %s", got) + } +} + +func TestAggregateHealthStatus_rqliteError(t *testing.T) { + checks := map[string]checkResult{ + "rqlite": {Status: "error", Error: "connection refused"}, + "olric": {Status: "ok"}, + "ipfs": {Status: "ok"}, + } + if got := aggregateHealthStatus(checks); got != "unhealthy" { + t.Errorf("expected unhealthy, got %s", got) + } +} + +func TestAggregateHealthStatus_nonCriticalError(t *testing.T) { + checks := map[string]checkResult{ + "rqlite": {Status: "ok"}, + "olric": {Status: "error", Error: "timeout"}, + "ipfs": {Status: "ok"}, + } + if got := aggregateHealthStatus(checks); got != "degraded" { + t.Errorf("expected degraded, got %s", got) + } +} + +func TestAggregateHealthStatus_unavailableIsNotError(t *testing.T) { + // Key test: "unavailable" services (like Anyone in sandbox) should NOT + // cause degraded status. + checks := map[string]checkResult{ + "rqlite": {Status: "ok"}, + "olric": {Status: "ok"}, + "ipfs": {Status: "unavailable"}, + "libp2p": {Status: "unavailable"}, + "anyone": {Status: "unavailable"}, + } + if got := aggregateHealthStatus(checks); got != "healthy" { + t.Errorf("expected healthy when services are unavailable, got %s", got) + } +} + +func TestAggregateHealthStatus_emptyChecks(t *testing.T) { + checks := map[string]checkResult{} + if got := aggregateHealthStatus(checks); got != "healthy" { + t.Errorf("expected healthy for empty checks, got %s", got) + } +} + +func TestAggregateHealthStatus_rqliteErrorOverridesDegraded(t *testing.T) { + // rqlite error should take priority over other errors + checks := map[string]checkResult{ + "rqlite": {Status: "error", Error: "leader not found"}, + "olric": {Status: "error", Error: "timeout"}, + "anyone": {Status: "error", Error: "not reachable"}, + } + if got := aggregateHealthStatus(checks); got != "unhealthy" { + t.Errorf("expected unhealthy (rqlite takes priority), got %s", got) + } +}