mirror of
https://github.com/DeBrosOfficial/orama.git
synced 2026-03-17 04:33:00 +00:00
feat(monitor): add sandbox environment support
- load nodes from active sandbox state for env=sandbox - extract fanoutArchive for efficient server-to-server distribution
This commit is contained in:
parent
6468019136
commit
78d876e71b
@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/DeBrosOfficial/network/pkg/cli/production/report"
|
||||
"github.com/DeBrosOfficial/network/pkg/cli/remotessh"
|
||||
"github.com/DeBrosOfficial/network/pkg/cli/sandbox"
|
||||
"github.com/DeBrosOfficial/network/pkg/inspector"
|
||||
)
|
||||
|
||||
@ -23,22 +24,9 @@ type CollectorConfig struct {
|
||||
// CollectOnce runs `sudo orama node report --json` on all matching nodes
|
||||
// in parallel and returns a ClusterSnapshot.
|
||||
func CollectOnce(ctx context.Context, cfg CollectorConfig) (*ClusterSnapshot, error) {
|
||||
nodes, err := inspector.LoadNodes(cfg.ConfigPath)
|
||||
nodes, cleanup, err := loadNodes(cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load nodes: %w", err)
|
||||
}
|
||||
nodes = inspector.FilterByEnv(nodes, cfg.Env)
|
||||
if cfg.NodeFilter != "" {
|
||||
nodes = filterByHost(nodes, cfg.NodeFilter)
|
||||
}
|
||||
if len(nodes) == 0 {
|
||||
return nil, fmt.Errorf("no nodes found for env %q", cfg.Env)
|
||||
}
|
||||
|
||||
// Prepare wallet-derived SSH keys
|
||||
cleanup, err := remotessh.PrepareNodeKeys(nodes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("prepare SSH keys: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
@ -121,3 +109,61 @@ func truncate(s string, maxLen int) string {
|
||||
}
|
||||
return s[:maxLen] + "..."
|
||||
}
|
||||
|
||||
// loadNodes resolves the node list and SSH keys based on the environment.
|
||||
// For "sandbox", nodes are loaded from the active sandbox state file with
|
||||
// the sandbox SSH key already set. For other environments, nodes come from
|
||||
// nodes.conf and use wallet-derived SSH keys.
|
||||
func loadNodes(cfg CollectorConfig) ([]inspector.Node, func(), error) {
|
||||
noop := func() {}
|
||||
|
||||
if cfg.Env == "sandbox" {
|
||||
return loadSandboxNodes(cfg)
|
||||
}
|
||||
|
||||
nodes, err := inspector.LoadNodes(cfg.ConfigPath)
|
||||
if err != nil {
|
||||
return nil, noop, fmt.Errorf("load nodes: %w", err)
|
||||
}
|
||||
nodes = inspector.FilterByEnv(nodes, cfg.Env)
|
||||
if cfg.NodeFilter != "" {
|
||||
nodes = filterByHost(nodes, cfg.NodeFilter)
|
||||
}
|
||||
if len(nodes) == 0 {
|
||||
return nil, noop, fmt.Errorf("no nodes found for env %q", cfg.Env)
|
||||
}
|
||||
|
||||
cleanup, err := remotessh.PrepareNodeKeys(nodes)
|
||||
if err != nil {
|
||||
return nil, noop, fmt.Errorf("prepare SSH keys: %w", err)
|
||||
}
|
||||
return nodes, cleanup, nil
|
||||
}
|
||||
|
||||
// loadSandboxNodes loads nodes from the active sandbox state file.
|
||||
func loadSandboxNodes(cfg CollectorConfig) ([]inspector.Node, func(), error) {
|
||||
noop := func() {}
|
||||
|
||||
sbxCfg, err := sandbox.LoadConfig()
|
||||
if err != nil {
|
||||
return nil, noop, fmt.Errorf("load sandbox config: %w", err)
|
||||
}
|
||||
|
||||
state, err := sandbox.FindActiveSandbox()
|
||||
if err != nil {
|
||||
return nil, noop, fmt.Errorf("find active sandbox: %w", err)
|
||||
}
|
||||
if state == nil {
|
||||
return nil, noop, fmt.Errorf("no active sandbox found")
|
||||
}
|
||||
|
||||
nodes := state.ToNodes(sbxCfg.ExpandedPrivateKeyPath())
|
||||
if cfg.NodeFilter != "" {
|
||||
nodes = filterByHost(nodes, cfg.NodeFilter)
|
||||
}
|
||||
if len(nodes) == 0 {
|
||||
return nil, noop, fmt.Errorf("no nodes found for sandbox %q", state.Name)
|
||||
}
|
||||
|
||||
return nodes, noop, nil
|
||||
}
|
||||
|
||||
@ -6,7 +6,6 @@ import (
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/DeBrosOfficial/network/pkg/cli/remotessh"
|
||||
@ -257,71 +256,8 @@ func phase3UploadArchive(cfg *Config, state *SandboxState) error {
|
||||
fmt.Printf(" Archive: %s (%s)\n", filepath.Base(archivePath), formatBytes(info.Size()))
|
||||
|
||||
sshKeyPath := cfg.ExpandedPrivateKeyPath()
|
||||
remotePath := "/tmp/" + filepath.Base(archivePath)
|
||||
extractCmd := fmt.Sprintf("mkdir -p /opt/orama && tar xzf %s -C /opt/orama && rm -f %s",
|
||||
remotePath, remotePath)
|
||||
|
||||
// Step 1: Upload from local machine to genesis node
|
||||
genesis := state.Servers[0]
|
||||
genesisNode := inspector.Node{User: "root", Host: genesis.IP, SSHKey: sshKeyPath}
|
||||
|
||||
fmt.Printf(" Uploading to %s (genesis)...\n", genesis.Name)
|
||||
if err := remotessh.UploadFile(genesisNode, archivePath, remotePath, remotessh.WithNoHostKeyCheck()); err != nil {
|
||||
return fmt.Errorf("upload to %s: %w", genesis.Name, err)
|
||||
}
|
||||
|
||||
// Step 2: Fan out from genesis to remaining nodes in parallel (server-to-server)
|
||||
if len(state.Servers) > 1 {
|
||||
fmt.Printf(" Fanning out from %s to %d nodes...\n", genesis.Name, len(state.Servers)-1)
|
||||
|
||||
// Temporarily upload SSH key to genesis for server-to-server SCP
|
||||
remoteKeyPath := "/tmp/.sandbox_key"
|
||||
if err := remotessh.UploadFile(genesisNode, sshKeyPath, remoteKeyPath, remotessh.WithNoHostKeyCheck()); err != nil {
|
||||
return fmt.Errorf("upload SSH key to genesis: %w", err)
|
||||
}
|
||||
// Always clean up the temporary key, even on panic/early return
|
||||
defer remotessh.RunSSHStreaming(genesisNode, fmt.Sprintf("rm -f %s", remoteKeyPath), remotessh.WithNoHostKeyCheck())
|
||||
|
||||
if err := remotessh.RunSSHStreaming(genesisNode, fmt.Sprintf("chmod 600 %s", remoteKeyPath), remotessh.WithNoHostKeyCheck()); err != nil {
|
||||
return fmt.Errorf("chmod SSH key on genesis: %w", err)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errs := make([]error, len(state.Servers))
|
||||
|
||||
for i := 1; i < len(state.Servers); i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int, srv ServerState) {
|
||||
defer wg.Done()
|
||||
// SCP from genesis to target using the uploaded key
|
||||
scpCmd := fmt.Sprintf("scp -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -i %s %s root@%s:%s",
|
||||
remoteKeyPath, remotePath, srv.IP, remotePath)
|
||||
if err := remotessh.RunSSHStreaming(genesisNode, scpCmd, remotessh.WithNoHostKeyCheck()); err != nil {
|
||||
errs[idx] = fmt.Errorf("fanout to %s: %w", srv.Name, err)
|
||||
return
|
||||
}
|
||||
// Extract on target
|
||||
targetNode := inspector.Node{User: "root", Host: srv.IP, SSHKey: sshKeyPath}
|
||||
if err := remotessh.RunSSHStreaming(targetNode, extractCmd, remotessh.WithNoHostKeyCheck()); err != nil {
|
||||
errs[idx] = fmt.Errorf("extract on %s: %w", srv.Name, err)
|
||||
return
|
||||
}
|
||||
fmt.Printf(" Distributed to %s\n", srv.Name)
|
||||
}(i, state.Servers[i])
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
for _, err := range errs {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: Extract on genesis
|
||||
fmt.Printf(" Extracting on %s...\n", genesis.Name)
|
||||
if err := remotessh.RunSSHStreaming(genesisNode, extractCmd, remotessh.WithNoHostKeyCheck()); err != nil {
|
||||
return fmt.Errorf("extract on %s: %w", genesis.Name, err)
|
||||
if err := fanoutArchive(state.Servers, sshKeyPath, archivePath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Println(" All nodes ready")
|
||||
|
||||
84
pkg/cli/sandbox/fanout.go
Normal file
84
pkg/cli/sandbox/fanout.go
Normal file
@ -0,0 +1,84 @@
|
||||
package sandbox
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"github.com/DeBrosOfficial/network/pkg/cli/remotessh"
|
||||
"github.com/DeBrosOfficial/network/pkg/inspector"
|
||||
)
|
||||
|
||||
// fanoutArchive uploads a binary archive to the first server, then fans out
// server-to-server in parallel to all remaining servers. This is much faster
// than uploading from the local machine to each node individually.
// After distribution, the archive is extracted on all nodes.
//
// NOTE(review): assumes len(servers) >= 1 — servers[0] panics on an empty
// slice. Callers pass the sandbox server list, which presumably always has
// at least the genesis node; confirm before reusing elsewhere.
func fanoutArchive(servers []ServerState, sshKeyPath, archivePath string) error {
	// The archive lands in /tmp under its original basename on every node;
	// the extract command removes it after unpacking into /opt/orama.
	remotePath := "/tmp/" + filepath.Base(archivePath)
	extractCmd := fmt.Sprintf("mkdir -p /opt/orama && tar xzf %s -C /opt/orama && rm -f %s",
		remotePath, remotePath)

	// Step 1: Upload from local machine to first node
	first := servers[0]
	firstNode := inspector.Node{User: "root", Host: first.IP, SSHKey: sshKeyPath}

	fmt.Printf(" Uploading to %s...\n", first.Name)
	if err := remotessh.UploadFile(firstNode, archivePath, remotePath, remotessh.WithNoHostKeyCheck()); err != nil {
		return fmt.Errorf("upload to %s: %w", first.Name, err)
	}

	// Step 2: Fan out from first node to remaining nodes in parallel (server-to-server)
	if len(servers) > 1 {
		fmt.Printf(" Fanning out from %s to %d nodes...\n", first.Name, len(servers)-1)

		// Temporarily upload SSH key for server-to-server SCP
		remoteKeyPath := "/tmp/.sandbox_key"
		if err := remotessh.UploadFile(firstNode, sshKeyPath, remoteKeyPath, remotessh.WithNoHostKeyCheck()); err != nil {
			return fmt.Errorf("upload SSH key to %s: %w", first.Name, err)
		}
		// Best-effort removal of the temporary private-key copy on every exit
		// path; the deferred command's error is deliberately ignored.
		defer remotessh.RunSSHStreaming(firstNode, fmt.Sprintf("rm -f %s", remoteKeyPath), remotessh.WithNoHostKeyCheck())

		if err := remotessh.RunSSHStreaming(firstNode, fmt.Sprintf("chmod 600 %s", remoteKeyPath), remotessh.WithNoHostKeyCheck()); err != nil {
			return fmt.Errorf("chmod SSH key on %s: %w", first.Name, err)
		}

		// One error slot per server: goroutines write disjoint indices, so no
		// mutex is needed. Index 0 (the first node) is never written.
		var wg sync.WaitGroup
		errs := make([]error, len(servers))

		for i := 1; i < len(servers); i++ {
			wg.Add(1)
			go func(idx int, srv ServerState) {
				defer wg.Done()
				// SCP from first node to target
				scpCmd := fmt.Sprintf("scp -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -i %s %s root@%s:%s",
					remoteKeyPath, remotePath, srv.IP, remotePath)
				if err := remotessh.RunSSHStreaming(firstNode, scpCmd, remotessh.WithNoHostKeyCheck()); err != nil {
					errs[idx] = fmt.Errorf("fanout to %s: %w", srv.Name, err)
					return
				}
				// Extract on target
				targetNode := inspector.Node{User: "root", Host: srv.IP, SSHKey: sshKeyPath}
				if err := remotessh.RunSSHStreaming(targetNode, extractCmd, remotessh.WithNoHostKeyCheck()); err != nil {
					errs[idx] = fmt.Errorf("extract on %s: %w", srv.Name, err)
					return
				}
				fmt.Printf(" Distributed to %s\n", srv.Name)
			}(i, servers[i])
		}
		wg.Wait()

		// Only the first failure is returned; subsequent errors are dropped.
		for _, err := range errs {
			if err != nil {
				return err
			}
		}
	}

	// Step 3: Extract on first node
	fmt.Printf(" Extracting on %s...\n", first.Name)
	if err := remotessh.RunSSHStreaming(firstNode, extractCmd, remotessh.WithNoHostKeyCheck()); err != nil {
		return fmt.Errorf("extract on %s: %w", first.Name, err)
	}

	return nil
}
|
||||
@ -34,24 +34,10 @@ func Rollout(name string) error {
|
||||
info, _ := os.Stat(archivePath)
|
||||
fmt.Printf("Archive: %s (%s)\n\n", filepath.Base(archivePath), formatBytes(info.Size()))
|
||||
|
||||
// Step 2: Push archive to all nodes
|
||||
// Step 2: Push archive to all nodes (upload to first, fan out server-to-server)
|
||||
fmt.Println("Pushing archive to all nodes...")
|
||||
remotePath := "/tmp/" + filepath.Base(archivePath)
|
||||
|
||||
for i, srv := range state.Servers {
|
||||
node := inspector.Node{User: "root", Host: srv.IP, SSHKey: sshKeyPath}
|
||||
|
||||
fmt.Printf(" [%d/%d] Uploading to %s...\n", i+1, len(state.Servers), srv.Name)
|
||||
if err := remotessh.UploadFile(node, archivePath, remotePath, remotessh.WithNoHostKeyCheck()); err != nil {
|
||||
return fmt.Errorf("upload to %s: %w", srv.Name, err)
|
||||
}
|
||||
|
||||
// Extract archive
|
||||
extractCmd := fmt.Sprintf("mkdir -p /opt/orama && tar xzf %s -C /opt/orama && rm -f %s",
|
||||
remotePath, remotePath)
|
||||
if err := remotessh.RunSSHStreaming(node, extractCmd, remotessh.WithNoHostKeyCheck()); err != nil {
|
||||
return fmt.Errorf("extract on %s: %w", srv.Name, err)
|
||||
}
|
||||
if err := fanoutArchive(state.Servers, sshKeyPath, archivePath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Step 3: Rolling upgrade — followers first, leader last
|
||||
@ -103,10 +89,22 @@ func findLeaderIndex(state *SandboxState, sshKeyPath string) int {
|
||||
}
|
||||
|
||||
// upgradeNode performs `orama node upgrade --restart` on a single node.
|
||||
// It pre-replaces the orama CLI binary before running the upgrade command
|
||||
// to avoid ETXTBSY ("text file busy") errors when the old binary doesn't
|
||||
// have the os.Remove fix in copyBinary().
|
||||
func upgradeNode(srv ServerState, sshKeyPath string, current, total int) error {
|
||||
node := inspector.Node{User: "root", Host: srv.IP, SSHKey: sshKeyPath}
|
||||
|
||||
fmt.Printf(" [%d/%d] Upgrading %s (%s)...\n", current, total, srv.Name, srv.IP)
|
||||
|
||||
// Pre-replace the orama CLI so the upgrade runs the NEW binary (with ETXTBSY fix).
|
||||
// rm unlinks the old inode (kernel keeps it alive for the running process),
|
||||
// cp creates a fresh inode at the same path.
|
||||
preReplace := "rm -f /usr/local/bin/orama && cp /opt/orama/bin/orama /usr/local/bin/orama"
|
||||
if err := remotessh.RunSSHStreaming(node, preReplace, remotessh.WithNoHostKeyCheck()); err != nil {
|
||||
return fmt.Errorf("pre-replace orama binary on %s: %w", srv.Name, err)
|
||||
}
|
||||
|
||||
if err := remotessh.RunSSHStreaming(node, "orama node upgrade --restart", remotessh.WithNoHostKeyCheck()); err != nil {
|
||||
return fmt.Errorf("upgrade %s: %w", srv.Name, err)
|
||||
}
|
||||
|
||||
@ -9,6 +9,31 @@ import (
|
||||
"github.com/rqlite/gorqlite"
|
||||
)
|
||||
|
||||
// safeWriteOne wraps gorqlite's WriteOneParameterized to recover from panics.
// gorqlite's WriteOne* functions access wra[0] without checking if the slice
// is empty, which panics when the server returns an error (e.g. "leader not found")
// with no result rows.
//
// On panic, the deferred recover rewrites the named err result, so the caller
// receives a zero-valued WriteResult plus a descriptive error instead of a crash.
func safeWriteOne(conn *gorqlite.Connection, stmt gorqlite.ParameterizedStatement) (wr gorqlite.WriteResult, err error) {
	// recover only works directly inside a deferred function in the same
	// goroutine, so the handler must be installed before calling into gorqlite.
	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("rqlite write failed (recovered panic): %v", r)
		}
	}()
	wr, err = conn.WriteOneParameterized(stmt)
	return
}
|
||||
|
||||
// safeWriteOneRaw wraps gorqlite's WriteOne to recover from panics.
// Raw (non-parameterized) counterpart of safeWriteOne: a panic inside
// gorqlite is converted into an error via the named err result.
func safeWriteOneRaw(conn *gorqlite.Connection, sql string) (wr gorqlite.WriteResult, err error) {
	// The deferred recover must be installed before the call that may panic.
	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("rqlite write failed (recovered panic): %v", r)
		}
	}()
	wr, err = conn.WriteOne(sql)
	return
}
|
||||
|
||||
// DatabaseClientImpl implements DatabaseClient
|
||||
type DatabaseClientImpl struct {
|
||||
client *Client
|
||||
@ -79,7 +104,7 @@ func (d *DatabaseClientImpl) Query(ctx context.Context, sql string, args ...inte
|
||||
|
||||
if isWriteOperation {
|
||||
// Execute write operation with parameters
|
||||
_, err := conn.WriteOneParameterized(gorqlite.ParameterizedStatement{
|
||||
_, err := safeWriteOne(conn, gorqlite.ParameterizedStatement{
|
||||
Query: sql,
|
||||
Arguments: args,
|
||||
})
|
||||
@ -293,7 +318,7 @@ func (d *DatabaseClientImpl) Transaction(ctx context.Context, queries []string)
|
||||
// Execute all queries in the transaction
|
||||
success := true
|
||||
for _, query := range queries {
|
||||
_, err := conn.WriteOne(query)
|
||||
_, err := safeWriteOneRaw(conn, query)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
success = false
|
||||
@ -321,7 +346,7 @@ func (d *DatabaseClientImpl) CreateTable(ctx context.Context, schema string) err
|
||||
}
|
||||
|
||||
return d.withRetry(func(conn *gorqlite.Connection) error {
|
||||
_, err := conn.WriteOne(schema)
|
||||
_, err := safeWriteOneRaw(conn, schema)
|
||||
return err
|
||||
})
|
||||
}
|
||||
@ -334,7 +359,7 @@ func (d *DatabaseClientImpl) DropTable(ctx context.Context, tableName string) er
|
||||
|
||||
return d.withRetry(func(conn *gorqlite.Connection) error {
|
||||
dropSQL := fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName)
|
||||
_, err := conn.WriteOne(dropSQL)
|
||||
_, err := safeWriteOneRaw(conn, dropSQL)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
82
pkg/client/database_client_test.go
Normal file
82
pkg/client/database_client_test.go
Normal file
@ -0,0 +1,82 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/rqlite/gorqlite"
|
||||
)
|
||||
|
||||
// simulateGorqlitePanic reproduces what gorqlite does when WriteParameterized
// returns an empty slice: accessing [0] panics with a runtime index-out-of-range
// error, so the error return value is never actually reached.
// (Comment previously referred to a nonexistent "mockPanicConnection".)
func simulateGorqlitePanic() (gorqlite.WriteResult, error) {
	var empty []gorqlite.WriteResult
	return empty[0], fmt.Errorf("leader not found") // panics
}
|
||||
|
||||
// TestSafeWriteOne_recoversPanic demonstrates the panic/recover pattern that
// safeWriteOne relies on.
//
// NOTE(review): this test does not call safeWriteOne itself — it only proves
// that simulateGorqlitePanic panics and that a deferred recover catches it.
// The nil-connection tests exercise the real wrappers.
func TestSafeWriteOne_recoversPanic(t *testing.T) {
	// We can't easily create a real gorqlite.Connection that panics,
	// but we can verify our recovery wrapper works by testing the
	// recovery pattern directly.
	var recovered bool
	func() {
		defer func() {
			if r := recover(); r != nil {
				recovered = true
			}
		}()
		simulateGorqlitePanic()
	}()

	if !recovered {
		t.Fatal("expected simulateGorqlitePanic to panic, but it didn't")
	}
}
|
||||
|
||||
func TestSafeWriteOne_nilConnection(t *testing.T) {
|
||||
// safeWriteOne with nil connection should recover from panic, not crash.
|
||||
_, err := safeWriteOne(nil, gorqlite.ParameterizedStatement{
|
||||
Query: "INSERT INTO test (a) VALUES (?)",
|
||||
Arguments: []interface{}{"x"},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error from nil connection, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeWriteOneRaw_nilConnection(t *testing.T) {
|
||||
// safeWriteOneRaw with nil connection should recover from panic, not crash.
|
||||
_, err := safeWriteOneRaw(nil, "INSERT INTO test (a) VALUES ('x')")
|
||||
if err == nil {
|
||||
t.Fatal("expected error from nil connection, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsWriteOperation(t *testing.T) {
|
||||
d := &DatabaseClientImpl{}
|
||||
|
||||
tests := []struct {
|
||||
sql string
|
||||
isWrite bool
|
||||
}{
|
||||
{"INSERT INTO foo VALUES (1)", true},
|
||||
{" INSERT INTO foo VALUES (1)", true},
|
||||
{"UPDATE foo SET a = 1", true},
|
||||
{"DELETE FROM foo", true},
|
||||
{"CREATE TABLE foo (a TEXT)", true},
|
||||
{"DROP TABLE foo", true},
|
||||
{"ALTER TABLE foo ADD COLUMN b TEXT", true},
|
||||
{"SELECT * FROM foo", false},
|
||||
{" SELECT * FROM foo", false},
|
||||
{"EXPLAIN SELECT * FROM foo", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.sql, func(t *testing.T) {
|
||||
got := d.isWriteOperation(tt.sql)
|
||||
if got != tt.isWrite {
|
||||
t.Errorf("isWriteOperation(%q) = %v, want %v", tt.sql, got, tt.isWrite)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -997,6 +997,13 @@ func (ps *ProductionSetup) Phase6SetupWireGuard(isFirstNode bool) (privateKey, p
|
||||
}
|
||||
ps.logf(" ✓ WireGuard keypair generated")
|
||||
|
||||
// Save public key to orama secrets so the gateway (running as orama user)
|
||||
// can read it without needing root access to /etc/wireguard/wg0.conf
|
||||
pubKeyPath := filepath.Join(ps.oramaDir, "secrets", "wg-public-key")
|
||||
if err := os.WriteFile(pubKeyPath, []byte(pubKey), 0600); err != nil {
|
||||
return "", "", fmt.Errorf("failed to save WG public key: %w", err)
|
||||
}
|
||||
|
||||
if isFirstNode {
|
||||
// First node: self-assign 10.0.0.1, no peers yet
|
||||
wp.config = WireGuardConfig{
|
||||
|
||||
@ -291,12 +291,20 @@ func (ps *ProductionSetup) installAnyonFromPreBuilt() error {
|
||||
}
|
||||
|
||||
// copyBinary copies a file from src to dest, preserving executable permissions.
|
||||
// It removes the destination first to avoid ETXTBSY ("text file busy") errors
|
||||
// when overwriting a binary that is currently running.
|
||||
func copyBinary(src, dest string) error {
|
||||
// Ensure parent directory exists
|
||||
if err := os.MkdirAll(filepath.Dir(dest), 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Remove the old binary first. On Linux, if the binary is running,
|
||||
// rm unlinks the filename while the kernel keeps the inode alive for
|
||||
// the running process. Writing a new file at the same path creates a
|
||||
// fresh inode — no ETXTBSY conflict.
|
||||
_ = os.Remove(dest)
|
||||
|
||||
srcFile, err := os.Open(src)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@ -213,6 +213,7 @@ Requires=wg-quick@wg0.service
|
||||
[Service]
|
||||
Type=simple
|
||||
%[5]s
|
||||
AmbientCapabilities=CAP_NET_ADMIN
|
||||
ReadWritePaths=%[2]s
|
||||
WorkingDirectory=%[1]s
|
||||
Environment=HOME=%[1]s
|
||||
|
||||
@ -2,8 +2,10 @@ package join
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
@ -100,6 +102,24 @@ func (h *Handler) HandleJoin(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Validate public IP format
|
||||
if net.ParseIP(req.PublicIP) == nil || net.ParseIP(req.PublicIP).To4() == nil {
|
||||
http.Error(w, "public_ip must be a valid IPv4 address", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate WireGuard public key: must be base64-encoded 32 bytes (Curve25519)
|
||||
// Also reject control characters (newlines) to prevent config injection
|
||||
if strings.ContainsAny(req.WGPublicKey, "\n\r") {
|
||||
http.Error(w, "wg_public_key contains invalid characters", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
wgKeyBytes, err := base64.StdEncoding.DecodeString(req.WGPublicKey)
|
||||
if err != nil || len(wgKeyBytes) != 32 {
|
||||
http.Error(w, "wg_public_key must be a valid base64-encoded 32-byte key", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
// 1. Validate and consume the invite token (atomic single-use)
|
||||
@ -177,7 +197,15 @@ func (h *Handler) HandleJoin(w http.ResponseWriter, r *http.Request) {
|
||||
olricEncryptionKey = strings.TrimSpace(string(data))
|
||||
}
|
||||
|
||||
// 7. Get all WG peers
|
||||
// 7. Get this node's WG IP (needed before peer list to check self-inclusion)
|
||||
myWGIP, err := h.getMyWGIP()
|
||||
if err != nil {
|
||||
h.logger.Error("failed to get local WG IP", zap.Error(err))
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 8. Get all WG peers
|
||||
wgPeers, err := h.getWGPeers(ctx, req.WGPublicKey)
|
||||
if err != nil {
|
||||
h.logger.Error("failed to list WG peers", zap.Error(err))
|
||||
@ -185,12 +213,29 @@ func (h *Handler) HandleJoin(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// 8. Get this node's WG IP
|
||||
myWGIP, err := h.getMyWGIP()
|
||||
if err != nil {
|
||||
h.logger.Error("failed to get local WG IP", zap.Error(err))
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
// Ensure this node (the join handler's host) is in the peer list.
|
||||
// On a fresh genesis node, the WG sync loop may not have self-registered
|
||||
// into wireguard_peers yet, causing 0 peers to be returned.
|
||||
if !wgPeersContainsIP(wgPeers, myWGIP) {
|
||||
myPubKey, err := h.getMyWGPublicKey()
|
||||
if err != nil {
|
||||
h.logger.Error("failed to get local WG public key", zap.Error(err))
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
myPublicIP, err := h.getMyPublicIP()
|
||||
if err != nil {
|
||||
h.logger.Error("failed to get local public IP", zap.Error(err))
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
wgPeers = append([]WGPeerInfo{{
|
||||
PublicKey: myPubKey,
|
||||
Endpoint: fmt.Sprintf("%s:%d", myPublicIP, 51820),
|
||||
AllowedIP: fmt.Sprintf("%s/32", myWGIP),
|
||||
}}, wgPeers...)
|
||||
h.logger.Info("self-injected into WG peer list (sync loop hasn't registered yet)",
|
||||
zap.String("wg_ip", myWGIP))
|
||||
}
|
||||
|
||||
// 9. Query IPFS and IPFS Cluster peer info
|
||||
@ -346,6 +391,17 @@ func (h *Handler) addWGPeerLocally(pubKey, publicIP, wgIP string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// wgPeersContainsIP checks if any peer in the list has the given WG IP
|
||||
func wgPeersContainsIP(peers []WGPeerInfo, wgIP string) bool {
|
||||
target := fmt.Sprintf("%s/32", wgIP)
|
||||
for _, p := range peers {
|
||||
if p.AllowedIP == target {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// getWGPeers returns all WG peers except the requesting node
|
||||
func (h *Handler) getWGPeers(ctx context.Context, excludePubKey string) ([]WGPeerInfo, error) {
|
||||
type peerRow struct {
|
||||
@ -403,6 +459,32 @@ func (h *Handler) getMyWGIP() (string, error) {
|
||||
return "", fmt.Errorf("could not find wg0 IP address")
|
||||
}
|
||||
|
||||
// getMyWGPublicKey reads the local WireGuard public key from the orama secrets
|
||||
// directory. The key is saved there during install by Phase6SetupWireGuard.
|
||||
// This avoids needing root/CAP_NET_ADMIN permissions that `wg show wg0` requires.
|
||||
func (h *Handler) getMyWGPublicKey() (string, error) {
|
||||
data, err := os.ReadFile(h.oramaDir + "/secrets/wg-public-key")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read WG public key from %s/secrets/wg-public-key: %w", h.oramaDir, err)
|
||||
}
|
||||
key := strings.TrimSpace(string(data))
|
||||
if key == "" {
|
||||
return "", fmt.Errorf("WG public key file is empty")
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// getMyPublicIP determines this node's public IP by connecting to a public server
|
||||
func (h *Handler) getMyPublicIP() (string, error) {
|
||||
conn, err := net.DialTimeout("udp", "8.8.8.8:80", 3*time.Second)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to determine public IP: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
addr := conn.LocalAddr().(*net.UDPAddr)
|
||||
return addr.IP.String(), nil
|
||||
}
|
||||
|
||||
// queryIPFSPeerInfo gets the local IPFS node's peer ID and builds addrs with WG IP
|
||||
func (h *Handler) queryIPFSPeerInfo(myWGIP string) PeerInfo {
|
||||
client := &http.Client{Timeout: 5 * time.Second}
|
||||
|
||||
112
pkg/gateway/handlers/join/handler_test.go
Normal file
112
pkg/gateway/handlers/join/handler_test.go
Normal file
@ -0,0 +1,112 @@
|
||||
package join
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWgPeersContainsIP_found(t *testing.T) {
|
||||
peers := []WGPeerInfo{
|
||||
{PublicKey: "key1", Endpoint: "1.2.3.4:51820", AllowedIP: "10.0.0.1/32"},
|
||||
{PublicKey: "key2", Endpoint: "5.6.7.8:51820", AllowedIP: "10.0.0.2/32"},
|
||||
}
|
||||
|
||||
if !wgPeersContainsIP(peers, "10.0.0.1") {
|
||||
t.Error("expected to find 10.0.0.1 in peer list")
|
||||
}
|
||||
if !wgPeersContainsIP(peers, "10.0.0.2") {
|
||||
t.Error("expected to find 10.0.0.2 in peer list")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWgPeersContainsIP_not_found verifies that a WG IP absent from the
// peer list is reported as not found.
func TestWgPeersContainsIP_not_found(t *testing.T) {
	peers := []WGPeerInfo{
		{PublicKey: "key1", Endpoint: "1.2.3.4:51820", AllowedIP: "10.0.0.1/32"},
	}

	if wgPeersContainsIP(peers, "10.0.0.2") {
		t.Error("did not expect to find 10.0.0.2 in peer list")
	}
}
|
||||
|
||||
// TestWgPeersContainsIP_empty_list verifies that nil and empty peer lists
// never match any IP, and in particular that a nil slice does not panic.
func TestWgPeersContainsIP_empty_list(t *testing.T) {
	if wgPeersContainsIP(nil, "10.0.0.1") {
		t.Error("did not expect to find any IP in nil peer list")
	}
	if wgPeersContainsIP([]WGPeerInfo{}, "10.0.0.1") {
		t.Error("did not expect to find any IP in empty peer list")
	}
}
|
||||
|
||||
// TestAssignWGIP_format pins the "<ip>/32" AllowedIP format contract shared
// by the join handler (which builds the string) and wgPeersContainsIP
// (which matches against it).
func TestAssignWGIP_format(t *testing.T) {
	// Verify the WG IP format used in the handler matches what wgPeersContainsIP expects
	wgIP := "10.0.0.1"
	allowedIP := fmt.Sprintf("%s/32", wgIP)
	peers := []WGPeerInfo{{AllowedIP: allowedIP}}

	if !wgPeersContainsIP(peers, wgIP) {
		t.Errorf("format mismatch: wgPeersContainsIP(%q, %q) should match", allowedIP, wgIP)
	}
}
|
||||
|
||||
// TestValidatePublicIP exercises the public-IP validation rules used by the
// join handler: must parse as IPv4 and contain no newline characters.
//
// NOTE(review): this re-implements the handler's checks inline rather than
// calling them; keep the two in sync.
func TestValidatePublicIP(t *testing.T) {
	cases := []struct {
		name  string
		ip    string
		valid bool
	}{
		{"valid IPv4", "46.225.234.112", true},
		{"loopback", "127.0.0.1", true},
		{"invalid string", "not-an-ip", false},
		{"empty", "", false},
		{"IPv6", "::1", false},
		{"with newline", "1.2.3.4\n5.6.7.8", false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			p := net.ParseIP(tc.ip)
			got := p != nil && p.To4() != nil && !strings.ContainsAny(tc.ip, "\n\r")
			if got != tc.valid {
				t.Errorf("IP %q: expected valid=%v, got %v", tc.ip, tc.valid, got)
			}
		})
	}
}
|
||||
|
||||
// TestValidateWGPublicKey exercises the WireGuard public-key validation
// rules: the key must be standard base64 decoding to exactly 32 bytes
// (Curve25519), and must not contain newline characters (config-injection
// guard).
//
// NOTE(review): this re-implements the handler's validation inline rather
// than calling it; keep the two in sync.
func TestValidateWGPublicKey(t *testing.T) {
	// Valid WireGuard key: 32 bytes, base64 encoded = 44 chars
	validKey := base64.StdEncoding.EncodeToString(make([]byte, 32))

	tests := []struct {
		name  string
		key   string
		valid bool
	}{
		{"valid 32-byte key", validKey, true},
		{"too short", base64.StdEncoding.EncodeToString(make([]byte, 16)), false},
		{"too long", base64.StdEncoding.EncodeToString(make([]byte, 64)), false},
		{"not base64", "not-a-valid-base64-key!!!", false},
		{"empty", "", false},
		{"newline injection", validKey + "\n[Peer]", false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Newline check mirrors the handler's first rejection step.
			if strings.ContainsAny(tt.key, "\n\r") {
				if tt.valid {
					t.Errorf("key %q contains newlines but expected valid", tt.key)
				}
				return
			}
			decoded, err := base64.StdEncoding.DecodeString(tt.key)
			isValid := err == nil && len(decoded) == 32
			if isValid != tt.valid {
				t.Errorf("key %q: expected valid=%v, got %v", tt.key, tt.valid, isValid)
			}
		})
	}
}
|
||||
@ -129,7 +129,10 @@ func (g *Gateway) healthHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if anyoneproxy.Running() {
|
||||
nr.result = checkResult{Status: "ok", Latency: time.Since(start).String()}
|
||||
} else {
|
||||
nr.result = checkResult{Status: "error", Latency: time.Since(start).String(), Error: "SOCKS5 proxy not reachable at " + anyoneproxy.Address()}
|
||||
// SOCKS5 port not reachable — Anyone relay is not installed/running.
|
||||
// Treat as "unavailable" rather than "error" so nodes without Anyone
|
||||
// don't report as degraded.
|
||||
nr.result = checkResult{Status: "unavailable"}
|
||||
}
|
||||
}
|
||||
ch <- nr
|
||||
@ -142,25 +145,7 @@ func (g *Gateway) healthHandler(w http.ResponseWriter, r *http.Request) {
|
||||
checks[nr.name] = nr.result
|
||||
}
|
||||
|
||||
// Aggregate status.
|
||||
// Critical: rqlite down → "unhealthy"
|
||||
// Non-critical (olric, ipfs, libp2p) error → "degraded"
|
||||
// "unavailable" means the client was never configured — not an error.
|
||||
overallStatus := "healthy"
|
||||
if c := checks["rqlite"]; c.Status == "error" {
|
||||
overallStatus = "unhealthy"
|
||||
}
|
||||
if overallStatus == "healthy" {
|
||||
for name, c := range checks {
|
||||
if name == "rqlite" {
|
||||
continue
|
||||
}
|
||||
if c.Status == "error" {
|
||||
overallStatus = "degraded"
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
overallStatus := aggregateHealthStatus(checks)
|
||||
|
||||
httpStatus := http.StatusOK
|
||||
if overallStatus != "healthy" {
|
||||
@ -236,6 +221,27 @@ func (g *Gateway) versionHandler(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
}
|
||||
|
||||
// aggregateHealthStatus determines the overall health status from individual checks.
|
||||
// Critical: rqlite down → "unhealthy"
|
||||
// Non-critical (olric, ipfs, libp2p, anyone) error → "degraded"
|
||||
// "unavailable" means the client was never configured — not an error.
|
||||
func aggregateHealthStatus(checks map[string]checkResult) string {
|
||||
status := "healthy"
|
||||
if c := checks["rqlite"]; c.Status == "error" {
|
||||
return "unhealthy"
|
||||
}
|
||||
for name, c := range checks {
|
||||
if name == "rqlite" {
|
||||
continue
|
||||
}
|
||||
if c.Status == "error" {
|
||||
status = "degraded"
|
||||
break
|
||||
}
|
||||
}
|
||||
return status
|
||||
}
|
||||
|
||||
// tlsCheckHandler validates if a domain should receive a TLS certificate
|
||||
// Used by Caddy's on-demand TLS feature to prevent abuse
|
||||
func (g *Gateway) tlsCheckHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
72
pkg/gateway/status_handlers_test.go
Normal file
72
pkg/gateway/status_handlers_test.go
Normal file
@ -0,0 +1,72 @@
|
||||
package gateway
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestAggregateHealthStatus_allHealthy(t *testing.T) {
|
||||
checks := map[string]checkResult{
|
||||
"rqlite": {Status: "ok"},
|
||||
"olric": {Status: "ok"},
|
||||
"ipfs": {Status: "ok"},
|
||||
"libp2p": {Status: "ok"},
|
||||
"anyone": {Status: "ok"},
|
||||
}
|
||||
if got := aggregateHealthStatus(checks); got != "healthy" {
|
||||
t.Errorf("expected healthy, got %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAggregateHealthStatus_rqliteError(t *testing.T) {
|
||||
checks := map[string]checkResult{
|
||||
"rqlite": {Status: "error", Error: "connection refused"},
|
||||
"olric": {Status: "ok"},
|
||||
"ipfs": {Status: "ok"},
|
||||
}
|
||||
if got := aggregateHealthStatus(checks); got != "unhealthy" {
|
||||
t.Errorf("expected unhealthy, got %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAggregateHealthStatus_nonCriticalError(t *testing.T) {
|
||||
checks := map[string]checkResult{
|
||||
"rqlite": {Status: "ok"},
|
||||
"olric": {Status: "error", Error: "timeout"},
|
||||
"ipfs": {Status: "ok"},
|
||||
}
|
||||
if got := aggregateHealthStatus(checks); got != "degraded" {
|
||||
t.Errorf("expected degraded, got %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAggregateHealthStatus_unavailableIsNotError(t *testing.T) {
|
||||
// Key test: "unavailable" services (like Anyone in sandbox) should NOT
|
||||
// cause degraded status.
|
||||
checks := map[string]checkResult{
|
||||
"rqlite": {Status: "ok"},
|
||||
"olric": {Status: "ok"},
|
||||
"ipfs": {Status: "unavailable"},
|
||||
"libp2p": {Status: "unavailable"},
|
||||
"anyone": {Status: "unavailable"},
|
||||
}
|
||||
if got := aggregateHealthStatus(checks); got != "healthy" {
|
||||
t.Errorf("expected healthy when services are unavailable, got %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAggregateHealthStatus_emptyChecks(t *testing.T) {
|
||||
checks := map[string]checkResult{}
|
||||
if got := aggregateHealthStatus(checks); got != "healthy" {
|
||||
t.Errorf("expected healthy for empty checks, got %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAggregateHealthStatus_rqliteErrorOverridesDegraded(t *testing.T) {
|
||||
// rqlite error should take priority over other errors
|
||||
checks := map[string]checkResult{
|
||||
"rqlite": {Status: "error", Error: "leader not found"},
|
||||
"olric": {Status: "error", Error: "timeout"},
|
||||
"anyone": {Status: "error", Error: "not reachable"},
|
||||
}
|
||||
if got := aggregateHealthStatus(checks); got != "unhealthy" {
|
||||
t.Errorf("expected unhealthy (rqlite takes priority), got %s", got)
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user