From 0764ac287eb94b474f6b53face3861b0e73c1144 Mon Sep 17 00:00:00 2001 From: anonpenguin23 Date: Fri, 20 Mar 2026 07:23:10 +0200 Subject: [PATCH] refactor(remotessh): use rwagent directly instead of rw CLI subprocesses - replace `rw vault ssh` calls with `rwagent.Client` in PrepareNodeKeys, LoadAgentKeys, EnsureVaultEntry, ResolveVaultPublicKey - add vaultClient interface, newClient func, and wrapAgentError for testability and improved error messages - prefer pre-built systemd dir in installNamespaceTemplates --- pkg/cli/production/install/orchestrator.go | 6 +- pkg/cli/remotessh/wallet.go | 198 +++++----- pkg/cli/remotessh/wallet_test.go | 349 +++++++++++++++++- pkg/cli/sandbox/create.go | 53 ++- pkg/cli/sandbox/hetzner.go | 6 +- pkg/cli/sandbox/setup.go | 38 +- pkg/cli/sandbox/setup_test.go | 82 ++++ .../production/installers/anyone_relay.go | 3 + .../production/installers/caddy.go | 4 +- .../production/installers/coredns.go | 6 +- .../production/installers/coredns_test.go | 151 ++++++++ pkg/environments/production/services.go | 5 +- pkg/environments/production/services_test.go | 33 ++ pkg/gateway/status_handlers.go | 8 +- pkg/ipfs/cluster_peer.go | 5 +- pkg/ipfs/cluster_peer_test.go | 95 +++++ pkg/rwagent/client.go | 222 +++++++++++ pkg/rwagent/client_test.go | 257 +++++++++++++ pkg/rwagent/errors.go | 57 +++ pkg/rwagent/types.go | 56 +++ 20 files changed, 1476 insertions(+), 158 deletions(-) create mode 100644 pkg/cli/sandbox/setup_test.go create mode 100644 pkg/environments/production/installers/coredns_test.go create mode 100644 pkg/ipfs/cluster_peer_test.go create mode 100644 pkg/rwagent/client.go create mode 100644 pkg/rwagent/client_test.go create mode 100644 pkg/rwagent/errors.go create mode 100644 pkg/rwagent/types.go diff --git a/pkg/cli/production/install/orchestrator.go b/pkg/cli/production/install/orchestrator.go index 7372bbe..04a4054 100644 --- a/pkg/cli/production/install/orchestrator.go +++ b/pkg/cli/production/install/orchestrator.go @@ -597,7 +597,11 @@ func promptForBaseDomain() string { // installNamespaceTemplates installs systemd template unit files for namespace services func (o *Orchestrator) installNamespaceTemplates() error { - sourceDir := filepath.Join(o.oramaHome, "src", "systemd") + // Check pre-built archive path first, fall back to source path + sourceDir := production.OramaSystemdDir + if _, err := os.Stat(sourceDir); os.IsNotExist(err) { + sourceDir = filepath.Join(o.oramaHome, "src", "systemd") + } systemdDir := "/etc/systemd/system" templates := []string{ diff --git a/pkg/cli/remotessh/wallet.go b/pkg/cli/remotessh/wallet.go index 5675110..1dbb2d9 100644 --- a/pkg/cli/remotessh/wallet.go +++ b/pkg/cli/remotessh/wallet.go @@ -1,6 +1,7 @@ package remotessh import ( + "context" "fmt" "os" "os/exec" @@ -8,11 +9,40 @@ import ( "strings" "github.com/DeBrosOfficial/network/pkg/inspector" + "github.com/DeBrosOfficial/network/pkg/rwagent" ) +// vaultClient is the interface used by wallet functions to talk to the agent. +// Defaults to the real rwagent.Client; tests replace it with a mock. +type vaultClient interface { + GetSSHKey(ctx context.Context, host, username, format string) (*rwagent.VaultSSHData, error) + CreateSSHEntry(ctx context.Context, host, username string) (*rwagent.VaultSSHData, error) +} + +// newClient creates the default vaultClient. Package-level var for test injection. +var newClient func() vaultClient = func() vaultClient { + return rwagent.New(os.Getenv("RW_AGENT_SOCK")) +} + +// wrapAgentError wraps rwagent errors with user-friendly messages. +// When the agent is locked, it also triggers the RootWallet desktop app +// to show the unlock dialog via deep link (best-effort, fire-and-forget). +func wrapAgentError(err error, action string) error { + if rwagent.IsNotRunning(err) { + return fmt.Errorf("%s: rootwallet agent is not running — start with: rw agent start && rw agent unlock", action) + } + if rwagent.IsLocked(err) { + return fmt.Errorf("%s: rootwallet agent is locked — unlock timed out after waiting. Unlock via the RootWallet app or run: rw agent unlock", action) + } + if rwagent.IsApprovalDenied(err) { + return fmt.Errorf("%s: rootwallet access denied — approve this app in the RootWallet desktop app", action) + } + return fmt.Errorf("%s: %w", action, err) +} + // PrepareNodeKeys resolves wallet-derived SSH keys for all nodes. -// Calls `rw vault ssh get / --priv` for each unique host/user, -// writes PEMs to temp files, and sets node.SSHKey for each node. +// Retrieves private keys from the rootwallet agent daemon, writes PEMs to +// temp files, and sets node.SSHKey for each node. // // The nodes slice is modified in place — each node.SSHKey is set to // the path of the temporary key file. @@ -20,10 +50,8 @@ import ( // Returns a cleanup function that zero-overwrites and removes all temp files. // Caller must defer cleanup(). func PrepareNodeKeys(nodes []inspector.Node) (cleanup func(), err error) { - rw, err := rwBinary() - if err != nil { - return nil, err - } + client := newClient() + ctx := context.Background() // Create temp dir for all keys tmpDir, err := os.MkdirTemp("", "orama-ssh-") @@ -31,12 +59,11 @@ func PrepareNodeKeys(nodes []inspector.Node) (cleanup func(), err error) { return nil, fmt.Errorf("create temp dir: %w", err) } - // Track resolved keys by host/user to avoid duplicate rw calls + // Track resolved keys by host/user to avoid duplicate agent calls keyPaths := make(map[string]string) // "host/user" → temp file path var allKeyPaths []string for i := range nodes { - // Use VaultTarget if set, otherwise default to Host/User var key string if nodes[i].VaultTarget != "" { key = nodes[i].VaultTarget @@ -48,18 +75,21 @@ func PrepareNodeKeys(nodes []inspector.Node) (cleanup func(), err error) { continue } - // Call rw to get the private key PEM host, user := parseVaultTarget(key) - pem, err := resolveWalletKey(rw, host, user) + data, err := client.GetSSHKey(ctx, host, user, "priv") if err != nil { - // Cleanup any keys already written before returning error cleanupKeys(tmpDir, allKeyPaths) - return nil, fmt.Errorf("resolve key for %s: %w", nodes[i].Name(), err) + return nil, wrapAgentError(err, fmt.Sprintf("resolve key for %s", nodes[i].Name())) + } + + if !strings.Contains(data.PrivateKey, "BEGIN OPENSSH PRIVATE KEY") { + cleanupKeys(tmpDir, allKeyPaths) + return nil, fmt.Errorf("agent returned invalid key for %s", nodes[i].Name()) } // Write PEM to temp file with restrictive perms keyFile := filepath.Join(tmpDir, fmt.Sprintf("id_%d", i)) - if err := os.WriteFile(keyFile, []byte(pem), 0600); err != nil { + if err := os.WriteFile(keyFile, []byte(data.PrivateKey), 0600); err != nil { cleanupKeys(tmpDir, allKeyPaths) return nil, fmt.Errorf("write key for %s: %w", nodes[i].Name(), err) } @@ -77,12 +107,10 @@ func PrepareNodeKeys(nodes []inspector.Node) (cleanup func(), err error) { // LoadAgentKeys loads SSH keys for the given nodes into the system ssh-agent. // Used by push fanout to enable agent forwarding. -// Calls `rw vault ssh agent-load ...` +// Retrieves private keys from the rootwallet agent and pipes them to ssh-add. func LoadAgentKeys(nodes []inspector.Node) error { - rw, err := rwBinary() - if err != nil { - return err - } + client := newClient() + ctx := context.Background() // Deduplicate host/user pairs seen := make(map[string]bool) @@ -105,76 +133,65 @@ func LoadAgentKeys(nodes []inspector.Node) error { return nil } - args := append([]string{"vault", "ssh", "agent-load"}, targets...) - cmd := exec.Command(rw, args...) - cmd.Stderr = os.Stderr - cmd.Stdout = os.Stderr // info messages go to stderr + for _, target := range targets { + host, user := parseVaultTarget(target) + data, err := client.GetSSHKey(ctx, host, user, "priv") + if err != nil { + return wrapAgentError(err, fmt.Sprintf("get key for %s", target)) + } - if err := cmd.Run(); err != nil { - return fmt.Errorf("rw vault ssh agent-load failed: %w", err) + // Pipe private key to ssh-add via stdin + cmd := exec.Command("ssh-add", "-") + cmd.Stdin = strings.NewReader(data.PrivateKey) + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf("ssh-add failed for %s: %w", target, err) + } } + return nil } // EnsureVaultEntry creates a wallet SSH entry if it doesn't already exist. -// Checks existence via `rw vault ssh get --pub`, and if missing, -// runs `rw vault ssh add ` to create it. +// Checks the rootwallet agent for an existing entry, creates one if not found. func EnsureVaultEntry(vaultTarget string) error { - rw, err := rwBinary() - if err != nil { - return err + client := newClient() + ctx := context.Background() + + host, user := parseVaultTarget(vaultTarget) + + // Check if entry already exists + _, err := client.GetSSHKey(ctx, host, user, "pub") + if err == nil { + return nil // entry exists } - // Check if entry exists by trying to get the public key - cmd := exec.Command(rw, "vault", "ssh", "get", vaultTarget, "--pub") - if err := cmd.Run(); err == nil { - return nil // entry already exists - } - - // Entry doesn't exist — try to create it - addCmd := exec.Command(rw, "vault", "ssh", "add", vaultTarget) - addCmd.Stdin = os.Stdin - addCmd.Stdout = os.Stderr - addCmd.Stderr = os.Stderr - if err := addCmd.Run(); err != nil { - if exitErr, ok := err.(*exec.ExitError); ok { - stderr := strings.TrimSpace(string(exitErr.Stderr)) - if strings.Contains(stderr, "not unlocked") || strings.Contains(stderr, "session") { - return fmt.Errorf("wallet is locked — run: rw unlock") - } + // If not found, create it + if rwagent.IsNotFound(err) { + _, createErr := client.CreateSSHEntry(ctx, host, user) + if createErr != nil { + return wrapAgentError(createErr, fmt.Sprintf("create vault entry %s", vaultTarget)) } - return fmt.Errorf("rw vault ssh add %s failed: %w", vaultTarget, err) + return nil } - return nil + + return wrapAgentError(err, fmt.Sprintf("check vault entry %s", vaultTarget)) } // ResolveVaultPublicKey returns the OpenSSH public key string for a vault entry. -// Calls `rw vault ssh get --pub`. func ResolveVaultPublicKey(vaultTarget string) (string, error) { - rw, err := rwBinary() + client := newClient() + ctx := context.Background() + + host, user := parseVaultTarget(vaultTarget) + data, err := client.GetSSHKey(ctx, host, user, "pub") if err != nil { - return "", err + return "", wrapAgentError(err, fmt.Sprintf("get public key for %s", vaultTarget)) } - cmd := exec.Command(rw, "vault", "ssh", "get", vaultTarget, "--pub") - out, err := cmd.Output() - if err != nil { - if exitErr, ok := err.(*exec.ExitError); ok { - stderr := strings.TrimSpace(string(exitErr.Stderr)) - if strings.Contains(stderr, "No SSH entry") { - return "", fmt.Errorf("no vault SSH entry for %s — run: rw vault ssh add %s", vaultTarget, vaultTarget) - } - if strings.Contains(stderr, "not unlocked") || strings.Contains(stderr, "session") { - return "", fmt.Errorf("wallet is locked — run: rw unlock") - } - return "", fmt.Errorf("%s", stderr) - } - return "", fmt.Errorf("rw command failed: %w", err) - } - - pubKey := strings.TrimSpace(string(out)) + pubKey := strings.TrimSpace(data.PublicKey) if !strings.HasPrefix(pubKey, "ssh-") { - return "", fmt.Errorf("rw returned invalid public key for %s", vaultTarget) + return "", fmt.Errorf("agent returned invalid public key for %s", vaultTarget) } return pubKey, nil } @@ -188,49 +205,6 @@ func parseVaultTarget(target string) (host, user string) { return target[:idx], target[idx+1:] } -// resolveWalletKey calls `rw vault ssh get / --priv` -// and returns the PEM string. Requires an active rw session. -func resolveWalletKey(rw string, host, user string) (string, error) { - target := host + "/" + user - cmd := exec.Command(rw, "vault", "ssh", "get", target, "--priv") - out, err := cmd.Output() - if err != nil { - if exitErr, ok := err.(*exec.ExitError); ok { - stderr := strings.TrimSpace(string(exitErr.Stderr)) - if strings.Contains(stderr, "No SSH entry") { - return "", fmt.Errorf("no vault SSH entry for %s — run: rw vault ssh add %s", target, target) - } - if strings.Contains(stderr, "not unlocked") || strings.Contains(stderr, "session") { - return "", fmt.Errorf("wallet is locked — run: rw unlock") - } - return "", fmt.Errorf("%s", stderr) - } - return "", fmt.Errorf("rw command failed: %w", err) - } - pem := string(out) - if !strings.Contains(pem, "BEGIN OPENSSH PRIVATE KEY") { - return "", fmt.Errorf("rw returned invalid key for %s", target) - } - return pem, nil -} - -// rwBinary returns the path to the `rw` binary. -// Checks RW_PATH env var first, then PATH. -func rwBinary() (string, error) { - if p := os.Getenv("RW_PATH"); p != "" { - if _, err := os.Stat(p); err == nil { - return p, nil - } - return "", fmt.Errorf("RW_PATH=%q not found", p) - } - - p, err := exec.LookPath("rw") - if err != nil { - return "", fmt.Errorf("rw not found in PATH — install rootwallet CLI: https://github.com/DeBrosOfficial/rootwallet") - } - return p, nil -} - // cleanupKeys zero-overwrites and removes all key files, then removes the temp dir. func cleanupKeys(tmpDir string, keyPaths []string) { zeros := make([]byte, 512) diff --git a/pkg/cli/remotessh/wallet_test.go b/pkg/cli/remotessh/wallet_test.go index b3fece6..eca1f16 100644 --- a/pkg/cli/remotessh/wallet_test.go +++ b/pkg/cli/remotessh/wallet_test.go @@ -1,6 +1,39 @@ package remotessh -import "testing" +import ( + "context" + "errors" + "os" + "strings" + "testing" + + "github.com/DeBrosOfficial/network/pkg/inspector" + "github.com/DeBrosOfficial/network/pkg/rwagent" +) + +const testPrivateKey = "-----BEGIN OPENSSH PRIVATE KEY-----\nfake-key-data\n-----END OPENSSH PRIVATE KEY-----" + +// mockClient implements vaultClient for testing. +type mockClient struct { + getSSHKey func(ctx context.Context, host, username, format string) (*rwagent.VaultSSHData, error) + createSSHEntry func(ctx context.Context, host, username string) (*rwagent.VaultSSHData, error) +} + +func (m *mockClient) GetSSHKey(ctx context.Context, host, username, format string) (*rwagent.VaultSSHData, error) { + return m.getSSHKey(ctx, host, username, format) +} + +func (m *mockClient) CreateSSHEntry(ctx context.Context, host, username string) (*rwagent.VaultSSHData, error) { + return m.createSSHEntry(ctx, host, username) +} + +// withMockClient replaces newClient for the duration of a test. +func withMockClient(t *testing.T, mock *mockClient) { + t.Helper() + orig := newClient + newClient = func() vaultClient { return mock } + t.Cleanup(func() { newClient = orig }) +} func TestParseVaultTarget(t *testing.T) { tests := []struct { @@ -13,6 +46,7 @@ func TestParseVaultTarget(t *testing.T) { {"my-host/my-user", "my-host", "my-user"}, {"noslash", "noslash", ""}, {"a/b/c", "a", "b/c"}, + {"", "", ""}, } for _, tt := range tests { @@ -27,3 +61,316 @@ func TestParseVaultTarget(t *testing.T) { }) } } + +func TestWrapAgentError_notRunning(t *testing.T) { + err := wrapAgentError(rwagent.ErrAgentNotRunning, "test action") + if !strings.Contains(err.Error(), "not running") { + t.Errorf("expected 'not running' message, got: %s", err) + } + if !strings.Contains(err.Error(), "rw agent start") { + t.Errorf("expected actionable hint, got: %s", err) + } +} + +func TestWrapAgentError_locked(t *testing.T) { + agentErr := &rwagent.AgentError{Code: "AGENT_LOCKED", Message: "agent is locked"} + err := wrapAgentError(agentErr, "test action") + if !strings.Contains(err.Error(), "locked") { + t.Errorf("expected 'locked' message, got: %s", err) + } + if !strings.Contains(err.Error(), "rw agent unlock") { + t.Errorf("expected actionable hint, got: %s", err) + } +} + +func TestWrapAgentError_generic(t *testing.T) { + err := wrapAgentError(errors.New("some error"), "test action") + if !strings.Contains(err.Error(), "test action") { + t.Errorf("expected action context, got: %s", err) + } + if !strings.Contains(err.Error(), "some error") { + t.Errorf("expected wrapped error, got: %s", err) + } +} + +func TestPrepareNodeKeys_success(t *testing.T) { + mock := &mockClient{ + getSSHKey: func(_ context.Context, host, username, format string) (*rwagent.VaultSSHData, error) { + return &rwagent.VaultSSHData{PrivateKey: testPrivateKey}, nil + }, + } + withMockClient(t, mock) + + nodes := []inspector.Node{ + {Host: "10.0.0.1", User: "root"}, + {Host: "10.0.0.2", User: "root"}, + } + + cleanup, err := PrepareNodeKeys(nodes) + if err != nil { + t.Fatalf("PrepareNodeKeys() error = %v", err) + } + defer cleanup() + + for i, n := range nodes { + if n.SSHKey == "" { + t.Errorf("node[%d].SSHKey is empty", i) + continue + } + data, err := os.ReadFile(n.SSHKey) + if err != nil { + t.Errorf("node[%d] key file unreadable: %v", i, err) + continue + } + if !strings.Contains(string(data), "BEGIN OPENSSH PRIVATE KEY") { + t.Errorf("node[%d] key file has wrong content", i) + } + } +} + +func TestPrepareNodeKeys_deduplication(t *testing.T) { + callCount := 0 + mock := &mockClient{ + getSSHKey: func(_ context.Context, host, username, format string) (*rwagent.VaultSSHData, error) { + callCount++ + return &rwagent.VaultSSHData{PrivateKey: testPrivateKey}, nil + }, + } + withMockClient(t, mock) + + nodes := []inspector.Node{ + {Host: "10.0.0.1", User: "root"}, + {Host: "10.0.0.1", User: "root"}, // same host/user + } + + cleanup, err := PrepareNodeKeys(nodes) + if err != nil { + t.Fatalf("PrepareNodeKeys() error = %v", err) + } + defer cleanup() + + if callCount != 1 { + t.Errorf("expected 1 agent call (dedup), got %d", callCount) + } + if nodes[0].SSHKey != nodes[1].SSHKey { + t.Error("expected same key path for deduplicated nodes") + } +} + +func TestPrepareNodeKeys_vaultTarget(t *testing.T) { + var capturedHost, capturedUser string + mock := &mockClient{ + getSSHKey: func(_ context.Context, host, username, format string) (*rwagent.VaultSSHData, error) { + capturedHost = host + capturedUser = username + return &rwagent.VaultSSHData{PrivateKey: testPrivateKey}, nil + }, + } + withMockClient(t, mock) + + nodes := []inspector.Node{ + {Host: "10.0.0.1", User: "root", VaultTarget: "sandbox/admin"}, + } + + cleanup, err := PrepareNodeKeys(nodes) + if err != nil { + t.Fatalf("PrepareNodeKeys() error = %v", err) + } + defer cleanup() + + if capturedHost != "sandbox" || capturedUser != "admin" { + t.Errorf("expected host=sandbox user=admin, got host=%s user=%s", capturedHost, capturedUser) + } +} + +func TestPrepareNodeKeys_agentNotRunning(t *testing.T) { + mock := &mockClient{ + getSSHKey: func(_ context.Context, _, _, _ string) (*rwagent.VaultSSHData, error) { + return nil, rwagent.ErrAgentNotRunning + }, + } + withMockClient(t, mock) + + nodes := []inspector.Node{{Host: "10.0.0.1", User: "root"}} + _, err := PrepareNodeKeys(nodes) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "not running") { + t.Errorf("expected 'not running' error, got: %s", err) + } +} + +func TestPrepareNodeKeys_invalidKey(t *testing.T) { + mock := &mockClient{ + getSSHKey: func(_ context.Context, _, _, _ string) (*rwagent.VaultSSHData, error) { + return &rwagent.VaultSSHData{PrivateKey: "garbage"}, nil + }, + } + withMockClient(t, mock) + + nodes := []inspector.Node{{Host: "10.0.0.1", User: "root"}} + _, err := PrepareNodeKeys(nodes) + if err == nil { + t.Fatal("expected error for invalid key") + } + if !strings.Contains(err.Error(), "invalid key") { + t.Errorf("expected 'invalid key' error, got: %s", err) + } +} + +func TestPrepareNodeKeys_cleanupOnError(t *testing.T) { + callNum := 0 + mock := &mockClient{ + getSSHKey: func(_ context.Context, _, _, _ string) (*rwagent.VaultSSHData, error) { + callNum++ + if callNum == 2 { + return nil, &rwagent.AgentError{Code: "AGENT_LOCKED", Message: "locked"} + } + return &rwagent.VaultSSHData{PrivateKey: testPrivateKey}, nil + }, + } + withMockClient(t, mock) + + nodes := []inspector.Node{ + {Host: "10.0.0.1", User: "root"}, + {Host: "10.0.0.2", User: "root"}, + } + + _, err := PrepareNodeKeys(nodes) + if err == nil { + t.Fatal("expected error") + } + + // First node's temp file should have been cleaned up + if nodes[0].SSHKey != "" { + if _, statErr := os.Stat(nodes[0].SSHKey); statErr == nil { + t.Error("expected temp key file to be cleaned up on error") + } + } +} + +func TestPrepareNodeKeys_emptyNodes(t *testing.T) { + mock := &mockClient{} + withMockClient(t, mock) + + cleanup, err := PrepareNodeKeys(nil) + if err != nil { + t.Fatalf("expected no error for empty nodes, got: %v", err) + } + cleanup() // should not panic +} + +func TestEnsureVaultEntry_exists(t *testing.T) { + mock := &mockClient{ + getSSHKey: func(_ context.Context, _, _, _ string) (*rwagent.VaultSSHData, error) { + return &rwagent.VaultSSHData{PublicKey: "ssh-ed25519 AAAA..."}, nil + }, + } + withMockClient(t, mock) + + if err := EnsureVaultEntry("sandbox/root"); err != nil { + t.Fatalf("EnsureVaultEntry() error = %v", err) + } +} + +func TestEnsureVaultEntry_creates(t *testing.T) { + created := false + mock := &mockClient{ + getSSHKey: func(_ context.Context, _, _, _ string) (*rwagent.VaultSSHData, error) { + return nil, &rwagent.AgentError{Code: "NOT_FOUND", Message: "not found"} + }, + createSSHEntry: func(_ context.Context, host, username string) (*rwagent.VaultSSHData, error) { + created = true + if host != "sandbox" || username != "root" { + t.Errorf("unexpected create args: %s/%s", host, username) + } + return &rwagent.VaultSSHData{PublicKey: "ssh-ed25519 AAAA..."}, nil + }, + } + withMockClient(t, mock) + + if err := EnsureVaultEntry("sandbox/root"); err != nil { + t.Fatalf("EnsureVaultEntry() error = %v", err) + } + if !created { + t.Error("expected CreateSSHEntry to be called") + } +} + +func TestEnsureVaultEntry_locked(t *testing.T) { + mock := &mockClient{ + getSSHKey: func(_ context.Context, _, _, _ string) (*rwagent.VaultSSHData, error) { + return nil, &rwagent.AgentError{Code: "AGENT_LOCKED", Message: "locked"} + }, + } + withMockClient(t, mock) + + err := EnsureVaultEntry("sandbox/root") + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "locked") { + t.Errorf("expected locked error, got: %s", err) + } +} + +func TestResolveVaultPublicKey_success(t *testing.T) { + mock := &mockClient{ + getSSHKey: func(_ context.Context, _, _, format string) (*rwagent.VaultSSHData, error) { + if format != "pub" { + t.Errorf("expected format=pub, got %s", format) + } + return &rwagent.VaultSSHData{PublicKey: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAA..."}, nil + }, + } + withMockClient(t, mock) + + key, err := ResolveVaultPublicKey("sandbox/root") + if err != nil { + t.Fatalf("ResolveVaultPublicKey() error = %v", err) + } + if !strings.HasPrefix(key, "ssh-") { + t.Errorf("expected ssh- prefix, got: %s", key) + } +} + +func TestResolveVaultPublicKey_invalidFormat(t *testing.T) { + mock := &mockClient{ + getSSHKey: func(_ context.Context, _, _, _ string) (*rwagent.VaultSSHData, error) { + return &rwagent.VaultSSHData{PublicKey: "not-a-valid-key"}, nil + }, + } + withMockClient(t, mock) + + _, err := ResolveVaultPublicKey("sandbox/root") + if err == nil { + t.Fatal("expected error for invalid public key") + } + if !strings.Contains(err.Error(), "invalid public key") { + t.Errorf("expected 'invalid public key' error, got: %s", err) + } +} + +func TestResolveVaultPublicKey_notFound(t *testing.T) { + mock := &mockClient{ + getSSHKey: func(_ context.Context, _, _, _ string) (*rwagent.VaultSSHData, error) { + return nil, &rwagent.AgentError{Code: "NOT_FOUND", Message: "not found"} + }, + } + withMockClient(t, mock) + + _, err := ResolveVaultPublicKey("sandbox/root") + if err == nil { + t.Fatal("expected error") + } +} + +func TestLoadAgentKeys_emptyNodes(t *testing.T) { + mock := &mockClient{} + withMockClient(t, mock) + + if err := LoadAgentKeys(nil); err != nil { + t.Fatalf("expected no error for empty nodes, got: %v", err) + } +} diff --git a/pkg/cli/sandbox/create.go b/pkg/cli/sandbox/create.go index 2c26dac..b18d647 100644 --- a/pkg/cli/sandbox/create.go +++ b/pkg/cli/sandbox/create.go @@ -75,8 +75,7 @@ func Create(name string) error { // Phase 4: Install genesis node fmt.Println("\nPhase 4: Installing genesis node...") - tokens, err := phase4InstallGenesis(cfg, state, sshKeyPath) - if err != nil { + if err := phase4InstallGenesis(cfg, state, sshKeyPath); err != nil { state.Status = StatusError SaveState(state) return fmt.Errorf("install genesis: %w", err) @@ -84,7 +83,7 @@ func Create(name string) error { // Phase 5: Join remaining nodes fmt.Println("\nPhase 5: Joining remaining nodes...") - if err := phase5JoinNodes(cfg, state, tokens, sshKeyPath); err != nil { + if err := phase5JoinNodes(cfg, state, sshKeyPath); err != nil { state.Status = StatusError SaveState(state) return fmt.Errorf("join nodes: %w", err) @@ -280,60 +279,52 @@ func phase3UploadArchive(state *SandboxState, sshKeyPath string) error { return nil } -// phase4InstallGenesis installs the genesis node and generates invite tokens. -func phase4InstallGenesis(cfg *Config, state *SandboxState, sshKeyPath string) ([]string, error) { +// phase4InstallGenesis installs the genesis node. +func phase4InstallGenesis(cfg *Config, state *SandboxState, sshKeyPath string) error { genesis := state.GenesisServer() node := inspector.Node{User: "root", Host: genesis.IP, SSHKey: sshKeyPath} // Install genesis - installCmd := fmt.Sprintf("/opt/orama/bin/orama node install --vps-ip %s --domain %s --base-domain %s --nameserver --skip-checks", + installCmd := fmt.Sprintf("/opt/orama/bin/orama node install --vps-ip %s --domain %s --base-domain %s --nameserver --anyone-client --skip-checks", genesis.IP, cfg.Domain, cfg.Domain) fmt.Printf(" Installing on %s (%s)...\n", genesis.Name, genesis.IP) if err := remotessh.RunSSHStreaming(node, installCmd, remotessh.WithNoHostKeyCheck()); err != nil { - return nil, fmt.Errorf("install genesis: %w", err) + return fmt.Errorf("install genesis: %w", err) } // Wait for RQLite leader fmt.Print(" Waiting for RQLite leader...") if err := waitForRQLiteHealth(node, 3*time.Minute); err != nil { - return nil, fmt.Errorf("genesis health: %w", err) + return fmt.Errorf("genesis health: %w", err) } fmt.Println(" OK") - // Generate invite tokens (one per remaining node) - fmt.Print(" Generating invite tokens...") - remaining := len(state.Servers) - 1 - tokens := make([]string, remaining) - - for i := 0; i < remaining; i++ { - token, err := generateInviteToken(node) - if err != nil { - return nil, fmt.Errorf("generate invite token %d: %w", i+1, err) - } - tokens[i] = token - fmt.Print(".") - } - fmt.Println(" OK") - - return tokens, nil + return nil } // phase5JoinNodes joins the remaining 4 nodes to the cluster (serial). -func phase5JoinNodes(cfg *Config, state *SandboxState, tokens []string, sshKeyPath string) error { - genesisIP := state.GenesisServer().IP +// Generates invite tokens just-in-time to avoid expiry during long installs. +func phase5JoinNodes(cfg *Config, state *SandboxState, sshKeyPath string) error { + genesis := state.GenesisServer() + genesisNode := inspector.Node{User: "root", Host: genesis.IP, SSHKey: sshKeyPath} for i := 1; i < len(state.Servers); i++ { srv := state.Servers[i] node := inspector.Node{User: "root", Host: srv.IP, SSHKey: sshKeyPath} - token := tokens[i-1] + + // Generate token just before use to avoid expiry + token, err := generateInviteToken(genesisNode) + if err != nil { + return fmt.Errorf("generate invite token for %s: %w", srv.Name, err) + } var installCmd string if srv.Role == "nameserver" { - installCmd = fmt.Sprintf("/opt/orama/bin/orama node install --join http://%s --token %s --vps-ip %s --domain %s --base-domain %s --nameserver --skip-checks", - genesisIP, token, srv.IP, cfg.Domain, cfg.Domain) + installCmd = fmt.Sprintf("/opt/orama/bin/orama node install --join http://%s --token %s --vps-ip %s --domain %s --base-domain %s --nameserver --anyone-client --skip-checks", + genesis.IP, token, srv.IP, cfg.Domain, cfg.Domain) } else { - installCmd = fmt.Sprintf("/opt/orama/bin/orama node install --join http://%s --token %s --vps-ip %s --base-domain %s --skip-checks", - genesisIP, token, srv.IP, cfg.Domain) + installCmd = fmt.Sprintf("/opt/orama/bin/orama node install --join http://%s --token %s --vps-ip %s --base-domain %s --anyone-client --skip-checks", + genesis.IP, token, srv.IP, cfg.Domain) } fmt.Printf(" [%d/%d] Joining %s (%s, %s)...\n", i, len(state.Servers)-1, srv.Name, srv.IP, srv.Role) diff --git a/pkg/cli/sandbox/hetzner.go b/pkg/cli/sandbox/hetzner.go index 51d62a0..dec4d44 100644 --- a/pkg/cli/sandbox/hetzner.go +++ b/pkg/cli/sandbox/hetzner.go @@ -346,7 +346,11 @@ func (c *HetznerClient) UploadSSHKey(name, publicKey string) (*HetznerSSHKey, er // ListSSHKeysByFingerprint finds SSH keys matching a fingerprint. func (c *HetznerClient) ListSSHKeysByFingerprint(fingerprint string) ([]HetznerSSHKey, error) { - body, err := c.get("/ssh_keys?fingerprint=" + fingerprint) + path := "/ssh_keys" + if fingerprint != "" { + path += "?fingerprint=" + fingerprint + } + body, err := c.get(path) if err != nil { return nil, fmt.Errorf("list SSH keys: %w", err) } diff --git a/pkg/cli/sandbox/setup.go b/pkg/cli/sandbox/setup.go index 9976dbe..16329d1 100644 --- a/pkg/cli/sandbox/setup.go +++ b/pkg/cli/sandbox/setup.go @@ -408,17 +408,38 @@ func setupSSHKey(client *HetznerClient) (SSHKeyConfig, error) { fmt.Print(" Uploading to Hetzner... ") key, err := client.UploadSSHKey("orama-sandbox", pubStr) if err != nil { - // Key may already exist on Hetzner — try to find by fingerprint - existing, listErr := client.ListSSHKeysByFingerprint("") // empty = list all + // Key may already exist on Hetzner — check if it matches the current vault key + existing, listErr := client.ListSSHKeysByFingerprint("") if listErr == nil { for _, k := range existing { - if strings.TrimSpace(k.PublicKey) == pubStr { + if sshKeyDataEqual(k.PublicKey, pubStr) { + // Key data matches — safe to reuse regardless of name fmt.Printf("already exists (ID: %d)\n", k.ID) return SSHKeyConfig{ HetznerID: k.ID, VaultTarget: vaultTarget, }, nil } + if k.Name == "orama-sandbox" { + // Name matches but key data differs — vault key was rotated. + // Delete the stale Hetzner key so we can re-upload the current one. + fmt.Print("stale key detected, replacing... ") + if delErr := client.DeleteSSHKey(k.ID); delErr != nil { + fmt.Println("FAILED") + return SSHKeyConfig{}, fmt.Errorf("delete stale SSH key (ID %d): %w", k.ID, delErr) + } + // Re-upload with current vault key + newKey, uploadErr := client.UploadSSHKey("orama-sandbox", pubStr) + if uploadErr != nil { + fmt.Println("FAILED") + return SSHKeyConfig{}, fmt.Errorf("re-upload SSH key: %w", uploadErr) + } + fmt.Printf("OK (ID: %d)\n", newKey.ID) + return SSHKeyConfig{ + HetznerID: newKey.ID, + VaultTarget: vaultTarget, + }, nil + } } } @@ -433,6 +454,17 @@ func setupSSHKey(client *HetznerClient) (SSHKeyConfig, error) { }, nil } +// sshKeyDataEqual compares two SSH public key strings by their key type and +// data, ignoring the optional comment field. +func sshKeyDataEqual(a, b string) bool { + partsA := strings.Fields(strings.TrimSpace(a)) + partsB := strings.Fields(strings.TrimSpace(b)) + if len(partsA) < 2 || len(partsB) < 2 { + return false + } + return partsA[0] == partsB[0] && partsA[1] == partsB[1] +} + // verifyDNS checks if glue records for the sandbox domain are configured. // // There's a chicken-and-egg problem: NS records can't fully resolve until diff --git a/pkg/cli/sandbox/setup_test.go b/pkg/cli/sandbox/setup_test.go new file mode 100644 index 0000000..3b531b5 --- /dev/null +++ b/pkg/cli/sandbox/setup_test.go @@ -0,0 +1,82 @@ +package sandbox + +import "testing" + +func TestSSHKeyDataEqual(t *testing.T) { + tests := []struct { + name string + a string + b string + expected bool + }{ + { + name: "identical keys", + a: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBtest comment1", + b: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBtest comment1", + expected: true, + }, + { + name: "same key different comments", + a: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBtest vault", + b: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBtest user@host", + expected: true, + }, + { + name: "same key one without comment", + a: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBtest", + b: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBtest vault", + expected: true, + }, + { + name: "different key data", + a: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBoldkey vault", + b: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBnewkey vault", + expected: false, + }, + { + name: "different key types", + a: "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAAB vault", + b: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBtest vault", + expected: false, + }, + { + name: "empty string a", + a: "", + b: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBtest vault", + expected: false, + }, + { + name: "empty string b", + a: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBtest vault", + b: "", + expected: false, + }, + { + name: "both empty", + a: "", + b: "", + expected: false, + }, + { + name: "single field only", + a: "ssh-ed25519", + b: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBtest", + expected: false, + }, + { + name: "whitespace trimming", + a: " ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBtest vault ", + b: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBtest", + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := sshKeyDataEqual(tt.a, tt.b) + if got != tt.expected { + t.Errorf("sshKeyDataEqual(%q, %q) = %v, want %v", tt.a, tt.b, got, tt.expected) + } + }) + } +} diff --git a/pkg/environments/production/installers/anyone_relay.go b/pkg/environments/production/installers/anyone_relay.go index f1e6acb..4809d1b 100644 --- a/pkg/environments/production/installers/anyone_relay.go +++ b/pkg/environments/production/installers/anyone_relay.go @@ -338,6 +338,7 @@ func (ari *AnyoneRelayInstaller) ConfigureClient() error { config := `# Anyone Client Configuration (Managed by Orama Network) # Client-only mode — no relay traffic, no ORPort +AgreeToTerms 1 SocksPort 9050 Log notice file /var/log/anon/notices.log @@ -360,6 +361,8 @@ func (ari *AnyoneRelayInstaller) generateAnonrc() string { sb.WriteString("# Anyone Relay Configuration (Managed by Orama Network)\n") sb.WriteString("# Generated automatically - manual edits may be overwritten\n\n") + sb.WriteString("AgreeToTerms 1\n\n") + // Nickname sb.WriteString(fmt.Sprintf("Nickname %s\n", ari.config.Nickname)) diff --git a/pkg/environments/production/installers/caddy.go b/pkg/environments/production/installers/caddy.go index d8f73e7..5aad389 100644 --- a/pkg/environments/production/installers/caddy.go +++ b/pkg/environments/production/installers/caddy.go @@ -367,15 +367,13 @@ require ( // If baseDomain is provided and different from domain, Caddy also serves // the base domain and its wildcard (e.g., *.dbrs.space alongside *.node1.dbrs.space). func (ci *CaddyInstaller) generateCaddyfile(domain, email, acmeEndpoint, baseDomain string) string { - // Primary: Let's Encrypt via ACME DNS-01 challenge - // Fallback: Caddy's internal CA (self-signed, trust root on clients) + // Let's Encrypt via ACME DNS-01 challenge (no fallback to self-signed) tlsBlock := fmt.Sprintf(` tls { issuer acme { dns orama { endpoint %s } } - issuer internal }`, acmeEndpoint) var sb strings.Builder diff --git a/pkg/environments/production/installers/coredns.go b/pkg/environments/production/installers/coredns.go index b64378f..dcf6d4f 100644 --- a/pkg/environments/production/installers/coredns.go +++ b/pkg/environments/production/installers/coredns.go @@ -356,8 +356,12 @@ func (ci *CoreDNSInstaller) generateCorefile(domain, rqliteDSN string) string { # CoreDNS cache would cache NXDOMAIN and break ACME DNS-01 challenges. } -# Forward all other queries to upstream DNS +# Forward non-authoritative queries to upstream DNS (localhost only). +# The bind directive restricts this block to 127.0.0.1 so the node itself +# can resolve external domains (apt, github, etc.) but external clients +# cannot use this server as an open recursive resolver (BSI/CERT-Bund). . { + bind 127.0.0.1 forward . 8.8.8.8 8.8.4.4 1.1.1.1 cache 300 errors diff --git a/pkg/environments/production/installers/coredns_test.go b/pkg/environments/production/installers/coredns_test.go new file mode 100644 index 0000000..d5ae2e7 --- /dev/null +++ b/pkg/environments/production/installers/coredns_test.go @@ -0,0 +1,151 @@ +package installers + +import ( + "io" + "strings" + "testing" +) + +// newTestCoreDNSInstaller creates a CoreDNSInstaller suitable for unit tests. +// It uses a non-existent oramaHome so generateCorefile won't find a password file +// and will produce output without auth credentials. +func newTestCoreDNSInstaller() *CoreDNSInstaller { + return &CoreDNSInstaller{ + BaseInstaller: NewBaseInstaller("amd64", io.Discard), + version: "1.11.1", + oramaHome: "/nonexistent", + } +} + +func TestGenerateCorefile_ContainsBindLocalhost(t *testing.T) { + ci := newTestCoreDNSInstaller() + corefile := ci.generateCorefile("dbrs.space", "http://localhost:5001") + + if !strings.Contains(corefile, "bind 127.0.0.1") { + t.Fatal("Corefile forward block must contain 'bind 127.0.0.1' to prevent open resolver") + } +} + +func TestGenerateCorefile_ForwardBlockIsLocalhostOnly(t *testing.T) { + ci := newTestCoreDNSInstaller() + corefile := ci.generateCorefile("dbrs.space", "http://localhost:5001") + + // The bind directive must appear inside the catch-all (.) block, + // not inside the authoritative domain block. + // Find the ". {" block and verify bind is inside it. + dotBlockIdx := strings.Index(corefile, ". {") + if dotBlockIdx == -1 { + t.Fatal("Corefile must contain a catch-all '. {' server block") + } + + dotBlock := corefile[dotBlockIdx:] + closingIdx := strings.Index(dotBlock, "}") + if closingIdx == -1 { + t.Fatal("Catch-all block has no closing brace") + } + dotBlock = dotBlock[:closingIdx] + + if !strings.Contains(dotBlock, "bind 127.0.0.1") { + t.Error("bind 127.0.0.1 must be inside the catch-all (.) block, not the domain block") + } + + if !strings.Contains(dotBlock, "forward .") { + t.Error("forward directive must be inside the catch-all (.) block") + } +} + +func TestGenerateCorefile_AuthoritativeBlockNoBindRestriction(t *testing.T) { + ci := newTestCoreDNSInstaller() + corefile := ci.generateCorefile("dbrs.space", "http://localhost:5001") + + // The authoritative domain block should NOT have a bind directive + // (it must listen on all interfaces to serve external DNS queries). + domainBlockStart := strings.Index(corefile, "dbrs.space {") + if domainBlockStart == -1 { + t.Fatal("Corefile must contain 'dbrs.space {' server block") + } + + // Extract the domain block (up to the first closing brace) + domainBlock := corefile[domainBlockStart:] + closingIdx := strings.Index(domainBlock, "}") + if closingIdx == -1 { + t.Fatal("Domain block has no closing brace") + } + domainBlock = domainBlock[:closingIdx] + + if strings.Contains(domainBlock, "bind ") { + t.Error("Authoritative domain block must not have a bind directive — it must listen on all interfaces") + } +} + +func TestGenerateCorefile_ContainsDomainZone(t *testing.T) { + ci := newTestCoreDNSInstaller() + + tests := []struct { + domain string + }{ + {"dbrs.space"}, + {"orama.network"}, + {"example.com"}, + } + + for _, tt := range tests { + t.Run(tt.domain, func(t *testing.T) { + corefile := ci.generateCorefile(tt.domain, "http://localhost:5001") + + if !strings.Contains(corefile, tt.domain+" {") { + t.Errorf("Corefile must contain server block for domain %q", tt.domain) + } + + if !strings.Contains(corefile, "rqlite {") { + t.Error("Corefile must contain rqlite plugin block") + } + }) + } +} + +func TestGenerateCorefile_ContainsRQLiteDSN(t *testing.T) { + ci := newTestCoreDNSInstaller() + dsn := "http://10.0.0.1:5001" + corefile := ci.generateCorefile("dbrs.space", dsn) + + if !strings.Contains(corefile, "dsn "+dsn) { + t.Errorf("Corefile must contain RQLite DSN %q", dsn) + } +} + +func TestGenerateCorefile_NoAuthBlockWithoutCredentials(t *testing.T) { + ci := newTestCoreDNSInstaller() + corefile := ci.generateCorefile("dbrs.space", "http://localhost:5001") + + if strings.Contains(corefile, "username") || strings.Contains(corefile, "password") { + t.Error("Corefile must not contain auth credentials when secrets file is absent") + } +} + +func TestGeneratePluginConfig_ContainsBindPlugin(t *testing.T) { + ci := newTestCoreDNSInstaller() + cfg := ci.generatePluginConfig() + + if !strings.Contains(cfg, "bind:bind") { + t.Error("Plugin config must include the bind plugin (required for localhost-only forwarding)") + } +} + +func TestGeneratePluginConfig_ContainsACLPlugin(t *testing.T) { + ci := newTestCoreDNSInstaller() + cfg := ci.generatePluginConfig() + + if !strings.Contains(cfg, "acl:acl") { + t.Error("Plugin config must include the acl plugin") + } +} + +func TestGeneratePluginConfig_ContainsRQLitePlugin(t *testing.T) { + ci := newTestCoreDNSInstaller() + cfg := ci.generatePluginConfig() + + if !strings.Contains(cfg, "rqlite:rqlite") { + t.Error("Plugin config must include the rqlite plugin") + } +} diff --git a/pkg/environments/production/services.go b/pkg/environments/production/services.go index 3eca7e0..4101e0b 100644 --- a/pkg/environments/production/services.go +++ b/pkg/environments/production/services.go @@ -419,7 +419,7 @@ Description=Caddy HTTP/2 Server Documentation=https://caddyserver.com/docs/ After=network-online.target orama-node.service coredns.service Wants=network-online.target -Wants=orama-node.service +Requires=orama-node.service [Service] Type=simple @@ -428,6 +428,9 @@ ReadWritePaths=%[2]s /var/lib/caddy /etc/caddy Environment=XDG_DATA_HOME=/var/lib/caddy AmbientCapabilities=CAP_NET_BIND_SERVICE CapabilityBoundingSet=CAP_NET_BIND_SERVICE +ExecStartPre=/bin/sh -c 'for i in $$(seq 1 30); do curl -so /dev/null http://localhost:6001/health 2>/dev/null && exit 0; sleep 2; done; echo "Gateway not ready after 60s"; exit 1' +ExecStartPre=/bin/sh -c 'DOMAIN=$$(grep -oP "^\*\\.\K[^ {]+" /etc/caddy/Caddyfile | tail -1); [ -z "$$DOMAIN" ] && exit 0; for i in $$(seq 1 30); do dig +short +timeout=2 "$$DOMAIN" SOA 2>/dev/null | grep -q . && exit 0; sleep 2; done; echo "DNS not resolving $$DOMAIN after 60s (ACME may fail)"; exit 0' +TimeoutStartSec=180 ExecStart=/usr/bin/caddy run --environ --config /etc/caddy/Caddyfile ExecReload=/usr/bin/caddy reload --config /etc/caddy/Caddyfile TimeoutStopSec=5s diff --git a/pkg/environments/production/services_test.go b/pkg/environments/production/services_test.go index db38b11..271ad4a 100644 --- a/pkg/environments/production/services_test.go +++ b/pkg/environments/production/services_test.go @@ -78,6 +78,39 @@ func TestGenerateRQLiteService(t *testing.T) { } } +// TestGenerateCaddyService_GatewayReadinessCheck verifies Caddy waits for gateway before starting +func TestGenerateCaddyService_GatewayReadinessCheck(t *testing.T) { + ssg := &SystemdServiceGenerator{ + oramaHome: "/opt/orama", + oramaDir: "/opt/orama/.orama", + } + + unit := ssg.GenerateCaddyService() + + // Must have ExecStartPre that polls gateway health + if !strings.Contains(unit, "ExecStartPre=") { + t.Error("missing ExecStartPre directive for gateway readiness check") + } + if !strings.Contains(unit, "localhost:6001/health") { + t.Error("ExecStartPre should poll localhost:6001/health") + } + + // Must use Requires= (hard dependency), not Wants= (soft dependency) + if !strings.Contains(unit, "Requires=orama-node.service") { + t.Error("missing Requires=orama-node.service (hard dependency)") + } + if strings.Contains(unit, "Wants=orama-node.service") { + t.Error("should use Requires= not Wants= for orama-node.service dependency") + } + + // ExecStartPre must appear before ExecStart + preIdx := strings.Index(unit, "ExecStartPre=") + startIdx := strings.Index(unit, "ExecStart=/usr/bin/caddy") + if preIdx < 0 || startIdx < 0 || preIdx >= startIdx { + t.Error("ExecStartPre must appear before ExecStart") + } +} + // TestGenerateRQLiteServiceArgs verifies the ExecStart command arguments func TestGenerateRQLiteServiceArgs(t *testing.T) { ssg := &SystemdServiceGenerator{ diff --git a/pkg/gateway/status_handlers.go b/pkg/gateway/status_handlers.go index 19d1862..d08058f 100644 --- a/pkg/gateway/status_handlers.go +++ b/pkg/gateway/status_handlers.go @@ -141,11 +141,15 @@ func (g *Gateway) healthHandler(w http.ResponseWriter, r *http.Request) { ch <- nr }() - // Vault Guardian (TCP connect to localhost:7500) + // Vault Guardian (TCP connect on WireGuard IP:7500) go func() { nr := namedResult{name: "vault"} start := time.Now() - conn, err := net.DialTimeout("tcp", "localhost:7500", 2*time.Second) + vaultAddr := "localhost:7500" + if g.localWireGuardIP != "" { + vaultAddr = g.localWireGuardIP + ":7500" + } + conn, err := net.DialTimeout("tcp", vaultAddr, 2*time.Second) if err != nil { nr.result = checkResult{Status: "error", Latency: time.Since(start).String(), Error: fmt.Sprintf("vault-guardian unreachable on port 7500: %v", err)} } else { diff --git a/pkg/ipfs/cluster_peer.go b/pkg/ipfs/cluster_peer.go index 9f28ac1..284a47b 100644 --- a/pkg/ipfs/cluster_peer.go +++ b/pkg/ipfs/cluster_peer.go @@ -109,9 +109,10 @@ func (cm *ClusterConfigManager) DiscoverClusterPeersFromLibP2P(h host.Host) erro info := h.Peerstore().PeerInfo(p) for _, addr := range info.Addrs { - // Extract IP from multiaddr + // Extract IP from multiaddr — only use WireGuard IPs (10.0.0.x) + // for inter-node queries since port 6001 is blocked on public interfaces by UFW ip := extractIPFromMultiaddr(addr) - if ip != "" && !strings.HasPrefix(ip, "127.") && !strings.HasPrefix(ip, "::1") { + if ip != "" && strings.HasPrefix(ip, "10.0.0.") { peerIPs[ip] = true } } diff --git a/pkg/ipfs/cluster_peer_test.go b/pkg/ipfs/cluster_peer_test.go new file mode 100644 index 0000000..b7ba590 --- /dev/null +++ b/pkg/ipfs/cluster_peer_test.go @@ -0,0 +1,95 @@ +package ipfs + +import ( + "testing" + + "github.com/multiformats/go-multiaddr" +) + +func TestExtractIPFromMultiaddr(t *testing.T) { + tests := []struct { + name string + addr string + expected string + }{ + { + name: "ipv4 tcp address", + addr: "/ip4/10.0.0.1/tcp/4001", + expected: "10.0.0.1", + }, + { + name: "ipv4 public address", + addr: "/ip4/203.0.113.5/tcp/4001", + expected: "203.0.113.5", + }, + { + name: "ipv4 loopback", + addr: "/ip4/127.0.0.1/tcp/4001", + expected: "127.0.0.1", + }, + { + name: "ipv6 address", + addr: "/ip6/::1/tcp/4001", + expected: "[::1]", + }, + { + name: "wireguard ip with udp", + addr: "/ip4/10.0.0.3/udp/4001/quic", + expected: "10.0.0.3", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ma, err := multiaddr.NewMultiaddr(tt.addr) + if err != nil { + t.Fatalf("failed to parse multiaddr %q: %v", tt.addr, err) + } + got := extractIPFromMultiaddr(ma) + if got != tt.expected { + t.Errorf("extractIPFromMultiaddr(%q) = %q, want %q", tt.addr, got, tt.expected) + } + }) + } +} + +func TestExtractIPFromMultiaddr_Nil(t *testing.T) { + got := extractIPFromMultiaddr(nil) + if got != "" { + t.Errorf("extractIPFromMultiaddr(nil) = %q, want empty string", got) + } +} + +// TestWireGuardIPFiltering verifies that only 10.0.0.x IPs would be selected +// for peer discovery queries. This tests the filtering logic used in +// DiscoverClusterPeersFromLibP2P. +func TestWireGuardIPFiltering(t *testing.T) { + tests := []struct { + name string + addr string + accepted bool + }{ + {"wireguard ip", "/ip4/10.0.0.1/tcp/4001", true}, + {"wireguard ip high", "/ip4/10.0.0.254/tcp/4001", true}, + {"public ip", "/ip4/203.0.113.5/tcp/4001", false}, + {"private 192.168", "/ip4/192.168.1.1/tcp/4001", false}, + {"private 172.16", "/ip4/172.16.0.1/tcp/4001", false}, + {"loopback", "/ip4/127.0.0.1/tcp/4001", false}, + {"different 10.x subnet", "/ip4/10.1.0.1/tcp/4001", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ma, err := multiaddr.NewMultiaddr(tt.addr) + if err != nil { + t.Fatalf("failed to parse multiaddr: %v", err) + } + ip := extractIPFromMultiaddr(ma) + // Replicate the filtering logic from DiscoverClusterPeersFromLibP2P + accepted := ip != "" && len(ip) >= 7 && ip[:7] == "10.0.0." + if accepted != tt.accepted { + t.Errorf("IP %q: accepted=%v, want %v", ip, accepted, tt.accepted) + } + }) + } +} diff --git a/pkg/rwagent/client.go b/pkg/rwagent/client.go new file mode 100644 index 0000000..64e7c3d --- /dev/null +++ b/pkg/rwagent/client.go @@ -0,0 +1,222 @@ +package rwagent + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/url" + "os" + "path/filepath" + "strconv" + "strings" + "time" +) + +const ( + // DefaultSocketName is the socket file relative to ~/.rootwallet/. + DefaultSocketName = "agent.sock" + + // DefaultTimeout for HTTP requests to the agent. + // Set high enough to allow pending approval flow (2 min approval timeout). + DefaultTimeout = 150 * time.Second +) + +// Client communicates with the rootwallet agent daemon over a Unix socket. +type Client struct { + httpClient *http.Client + socketPath string +} + +// New creates a client that connects to the agent's Unix socket. +// If socketPath is empty, defaults to ~/.rootwallet/agent.sock. +func New(socketPath string) *Client { + if socketPath == "" { + home, _ := os.UserHomeDir() + socketPath = filepath.Join(home, ".rootwallet", DefaultSocketName) + } + + return &Client{ + socketPath: socketPath, + httpClient: &http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, "unix", socketPath) + }, + }, + Timeout: DefaultTimeout, + }, + } +} + +// Status returns the agent's current status. +func (c *Client) Status(ctx context.Context) (*StatusResponse, error) { + var resp apiResponse[StatusResponse] + if err := c.doJSON(ctx, "GET", "/v1/status", nil, &resp); err != nil { + return nil, err + } + if !resp.OK { + return nil, c.apiError(resp.Error, resp.Code, 0) + } + return &resp.Data, nil +} + +// IsRunning returns true if the agent is reachable. +func (c *Client) IsRunning(ctx context.Context) bool { + _, err := c.Status(ctx) + return err == nil +} + +// GetSSHKey retrieves an SSH key from the vault. +// format: "priv", "pub", or "both". +func (c *Client) GetSSHKey(ctx context.Context, host, username, format string) (*VaultSSHData, error) { + path := fmt.Sprintf("/v1/vault/ssh/%s/%s?format=%s", + url.PathEscape(host), + url.PathEscape(username), + url.QueryEscape(format), + ) + + var resp apiResponse[VaultSSHData] + if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil { + return nil, err + } + if !resp.OK { + return nil, c.apiError(resp.Error, resp.Code, 0) + } + return &resp.Data, nil +} + +// CreateSSHEntry creates a new SSH key entry in the vault. +func (c *Client) CreateSSHEntry(ctx context.Context, host, username string) (*VaultSSHData, error) { + body := map[string]string{"host": host, "username": username} + + var resp apiResponse[VaultSSHData] + if err := c.doJSON(ctx, "POST", "/v1/vault/ssh", body, &resp); err != nil { + return nil, err + } + if !resp.OK { + return nil, c.apiError(resp.Error, resp.Code, 0) + } + return &resp.Data, nil +} + +// GetPassword retrieves a stored password from the vault. +func (c *Client) GetPassword(ctx context.Context, domain, username string) (*VaultPasswordData, error) { + path := fmt.Sprintf("/v1/vault/password/%s/%s", + url.PathEscape(domain), + url.PathEscape(username), + ) + + var resp apiResponse[VaultPasswordData] + if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil { + return nil, err + } + if !resp.OK { + return nil, c.apiError(resp.Error, resp.Code, 0) + } + return &resp.Data, nil +} + +// GetAddress returns the active wallet address. +func (c *Client) GetAddress(ctx context.Context, chain string) (*WalletAddressData, error) { + path := fmt.Sprintf("/v1/wallet/address?chain=%s", url.QueryEscape(chain)) + + var resp apiResponse[WalletAddressData] + if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil { + return nil, err + } + if !resp.OK { + return nil, c.apiError(resp.Error, resp.Code, 0) + } + return &resp.Data, nil +} + +// Unlock sends the password to unlock the agent. +func (c *Client) Unlock(ctx context.Context, password string, ttlMinutes int) error { + body := map[string]any{"password": password, "ttlMinutes": ttlMinutes} + + var resp apiResponse[any] + if err := c.doJSON(ctx, "POST", "/v1/unlock", body, &resp); err != nil { + return err + } + if !resp.OK { + return c.apiError(resp.Error, resp.Code, 0) + } + return nil +} + +// Lock locks the agent, zeroing all key material. +func (c *Client) Lock(ctx context.Context) error { + var resp apiResponse[any] + if err := c.doJSON(ctx, "POST", "/v1/lock", nil, &resp); err != nil { + return err + } + if !resp.OK { + return c.apiError(resp.Error, resp.Code, 0) + } + return nil +} + +// doJSON performs an HTTP request and decodes the JSON response. +func (c *Client) doJSON(ctx context.Context, method, path string, body any, result any) error { + var bodyReader io.Reader + if body != nil { + data, err := json.Marshal(body) + if err != nil { + return fmt.Errorf("marshal request body: %w", err) + } + bodyReader = strings.NewReader(string(data)) + } + + // URL host is ignored for Unix sockets, but required by http.NewRequest + req, err := http.NewRequestWithContext(ctx, method, "http://localhost"+path, bodyReader) + if err != nil { + return fmt.Errorf("create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-RW-PID", strconv.Itoa(os.Getpid())) + + resp, err := c.httpClient.Do(req) + if err != nil { + // Connection refused or socket not found = agent not running + if isConnectionError(err) { + return ErrAgentNotRunning + } + return fmt.Errorf("agent request failed: %w", err) + } + defer resp.Body.Close() + + data, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("read response: %w", err) + } + + if err := json.Unmarshal(data, result); err != nil { + return fmt.Errorf("decode response: %w", err) + } + + return nil +} + +func (c *Client) apiError(message, code string, statusCode int) *AgentError { + return &AgentError{ + Code: code, + Message: message, + StatusCode: statusCode, + } +} + +// isConnectionError checks if the error is a connection-level failure. +func isConnectionError(err error) bool { + if err == nil { + return false + } + msg := err.Error() + return strings.Contains(msg, "connection refused") || + strings.Contains(msg, "no such file or directory") || + strings.Contains(msg, "connect: no such file") +} + diff --git a/pkg/rwagent/client_test.go b/pkg/rwagent/client_test.go new file mode 100644 index 0000000..a97be54 --- /dev/null +++ b/pkg/rwagent/client_test.go @@ -0,0 +1,257 @@ +package rwagent + +import ( + "context" + "encoding/json" + "net" + "net/http" + "os" + "path/filepath" + "testing" +) + +// startMockAgent creates a mock agent server on a Unix socket for testing. +func startMockAgent(t *testing.T, handler http.Handler) (socketPath string, cleanup func()) { + t.Helper() + + tmpDir := t.TempDir() + socketPath = filepath.Join(tmpDir, "test-agent.sock") + + listener, err := net.Listen("unix", socketPath) + if err != nil { + t.Fatalf("listen on unix socket: %v", err) + } + + server := &http.Server{Handler: handler} + go func() { _ = server.Serve(listener) }() + + cleanup = func() { + _ = server.Close() + _ = os.Remove(socketPath) + } + return socketPath, cleanup +} + +// jsonHandler returns an http.HandlerFunc that responds with the given JSON. +func jsonHandler(statusCode int, body any) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(statusCode) + data, _ := json.Marshal(body) + _, _ = w.Write(data) + } +} + +func TestStatus(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/v1/status", jsonHandler(200, apiResponse[StatusResponse]{ + OK: true, + Data: StatusResponse{ + Version: "1.0.0", + Locked: false, + Uptime: 120, + PID: 12345, + }, + })) + + sock, cleanup := startMockAgent(t, mux) + defer cleanup() + + client := New(sock) + status, err := client.Status(context.Background()) + if err != nil { + t.Fatalf("Status() error: %v", err) + } + + if status.Version != "1.0.0" { + t.Errorf("Version = %q, want %q", status.Version, "1.0.0") + } + if status.Locked { + t.Error("Locked = true, want false") + } + if status.Uptime != 120 { + t.Errorf("Uptime = %d, want 120", status.Uptime) + } +} + +func TestIsRunning_true(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/v1/status", jsonHandler(200, apiResponse[StatusResponse]{ + OK: true, + Data: StatusResponse{Version: "1.0.0"}, + })) + + sock, cleanup := startMockAgent(t, mux) + defer cleanup() + + client := New(sock) + if !client.IsRunning(context.Background()) { + t.Error("IsRunning() = false, want true") + } +} + +func TestIsRunning_false(t *testing.T) { + client := New("/tmp/nonexistent-socket-test.sock") + if client.IsRunning(context.Background()) { + t.Error("IsRunning() = true, want false") + } +} + +func TestGetSSHKey(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/v1/vault/ssh/myhost/root", jsonHandler(200, apiResponse[VaultSSHData]{ + OK: true, + Data: VaultSSHData{ + PrivateKey: "-----BEGIN OPENSSH PRIVATE KEY-----\nfake\n-----END OPENSSH PRIVATE KEY-----", + PublicKey: "ssh-ed25519 AAAA... myhost/root", + }, + })) + + sock, cleanup := startMockAgent(t, mux) + defer cleanup() + + client := New(sock) + data, err := client.GetSSHKey(context.Background(), "myhost", "root", "both") + if err != nil { + t.Fatalf("GetSSHKey() error: %v", err) + } + + if data.PrivateKey == "" { + t.Error("PrivateKey is empty") + } + if data.PublicKey == "" { + t.Error("PublicKey is empty") + } +} + +func TestGetSSHKey_locked(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/v1/vault/ssh/myhost/root", jsonHandler(423, apiResponse[any]{ + OK: false, + Error: "Agent is locked", + Code: "AGENT_LOCKED", + })) + + sock, cleanup := startMockAgent(t, mux) + defer cleanup() + + client := New(sock) + _, err := client.GetSSHKey(context.Background(), "myhost", "root", "priv") + if err == nil { + t.Fatal("GetSSHKey() expected error, got nil") + } + if !IsLocked(err) { + t.Errorf("IsLocked() = false for error: %v", err) + } +} + +func TestGetSSHKey_notFound(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/v1/vault/ssh/unknown/user", jsonHandler(404, apiResponse[any]{ + OK: false, + Error: "No SSH key found for unknown/user", + Code: "NOT_FOUND", + })) + + sock, cleanup := startMockAgent(t, mux) + defer cleanup() + + client := New(sock) + _, err := client.GetSSHKey(context.Background(), "unknown", "user", "priv") + if err == nil { + t.Fatal("GetSSHKey() expected error, got nil") + } + if !IsNotFound(err) { + t.Errorf("IsNotFound() = false for error: %v", err) + } +} + +func TestGetPassword(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/v1/vault/password/example.com/admin", jsonHandler(200, apiResponse[VaultPasswordData]{ + OK: true, + Data: VaultPasswordData{Password: "secret123"}, + })) + + sock, cleanup := startMockAgent(t, mux) + defer cleanup() + + client := New(sock) + data, err := client.GetPassword(context.Background(), "example.com", "admin") + if err != nil { + t.Fatalf("GetPassword() error: %v", err) + } + if data.Password != "secret123" { + t.Errorf("Password = %q, want %q", data.Password, "secret123") + } +} + +func TestCreateSSHEntry(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/v1/vault/ssh", jsonHandler(201, apiResponse[VaultSSHData]{ + OK: true, + Data: VaultSSHData{PublicKey: "ssh-ed25519 AAAA... new/entry"}, + })) + + sock, cleanup := startMockAgent(t, mux) + defer cleanup() + + client := New(sock) + data, err := client.CreateSSHEntry(context.Background(), "new", "entry") + if err != nil { + t.Fatalf("CreateSSHEntry() error: %v", err) + } + if data.PublicKey == "" { + t.Error("PublicKey is empty") + } +} + +func TestGetAddress(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/v1/wallet/address", jsonHandler(200, apiResponse[WalletAddressData]{ + OK: true, + Data: WalletAddressData{Address: "0x1234abcd", Chain: "evm"}, + })) + + sock, cleanup := startMockAgent(t, mux) + defer cleanup() + + client := New(sock) + data, err := client.GetAddress(context.Background(), "evm") + if err != nil { + t.Fatalf("GetAddress() error: %v", err) + } + if data.Address != "0x1234abcd" { + t.Errorf("Address = %q, want %q", data.Address, "0x1234abcd") + } +} + +func TestAgentNotRunning(t *testing.T) { + client := New("/tmp/nonexistent-socket-for-testing.sock") + _, err := client.Status(context.Background()) + if err == nil { + t.Fatal("expected error, got nil") + } + if !IsNotRunning(err) { + t.Errorf("IsNotRunning() = false for error: %v", err) + } +} + +func TestUnlockAndLock(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/v1/unlock", jsonHandler(200, apiResponse[any]{OK: true})) + mux.HandleFunc("/v1/lock", jsonHandler(200, apiResponse[any]{OK: true})) + + sock, cleanup := startMockAgent(t, mux) + defer cleanup() + + client := New(sock) + + if err := client.Unlock(context.Background(), "password", 30); err != nil { + t.Fatalf("Unlock() error: %v", err) + } + + if err := client.Lock(context.Background()); err != nil { + t.Fatalf("Lock() error: %v", err) + } +} diff --git a/pkg/rwagent/errors.go b/pkg/rwagent/errors.go new file mode 100644 index 0000000..aeebb99 --- /dev/null +++ b/pkg/rwagent/errors.go @@ -0,0 +1,57 @@ +package rwagent + +import ( + "errors" + "fmt" +) + +// AgentError represents an error returned by the rootwallet agent API. +type AgentError struct { + Code string // e.g., "AGENT_LOCKED", "NOT_FOUND" + Message string + StatusCode int +} + +func (e *AgentError) Error() string { + return fmt.Sprintf("rootwallet agent: %s (%s)", e.Message, e.Code) +} + +// IsLocked returns true if the error indicates the agent is locked. +func IsLocked(err error) bool { + var ae *AgentError + if errors.As(err, &ae) { + return ae.Code == "AGENT_LOCKED" + } + return false +} + +// IsNotRunning returns true if the error indicates the agent is not reachable. +func IsNotRunning(err error) bool { + var ae *AgentError + if errors.As(err, &ae) { + return ae.Code == "AGENT_NOT_RUNNING" + } + // Also check for connection errors + return errors.Is(err, ErrAgentNotRunning) +} + +// IsNotFound returns true if the vault entry was not found. +func IsNotFound(err error) bool { + var ae *AgentError + if errors.As(err, &ae) { + return ae.Code == "NOT_FOUND" + } + return false +} + +// IsApprovalDenied returns true if the user denied the app's access request. +func IsApprovalDenied(err error) bool { + var ae *AgentError + if errors.As(err, &ae) { + return ae.Code == "APPROVAL_DENIED" || ae.Code == "PERMISSION_DENIED" + } + return false +} + +// ErrAgentNotRunning is returned when the agent socket is not reachable. +var ErrAgentNotRunning = fmt.Errorf("rootwallet agent is not running — start with: rw agent start && rw agent unlock") diff --git a/pkg/rwagent/types.go b/pkg/rwagent/types.go new file mode 100644 index 0000000..4d04f95 --- /dev/null +++ b/pkg/rwagent/types.go @@ -0,0 +1,56 @@ +// Package rwagent provides a Go client for the RootWallet agent daemon. +// +// The agent is a persistent daemon that holds vault keys in memory and serves +// operations to authorized apps over a Unix socket HTTP API. This SDK replaces +// all subprocess `rw` calls with direct HTTP communication. +package rwagent + +// StatusResponse from GET /v1/status. +type StatusResponse struct { + Version string `json:"version"` + Locked bool `json:"locked"` + Uptime int `json:"uptime"` + PID int `json:"pid"` + ConnectedApps int `json:"connectedApps"` +} + +// VaultSSHData from GET /v1/vault/ssh/:host/:user. +type VaultSSHData struct { + PrivateKey string `json:"privateKey,omitempty"` + PublicKey string `json:"publicKey,omitempty"` +} + +// VaultPasswordData from GET /v1/vault/password/:domain/:user. +type VaultPasswordData struct { + Password string `json:"password"` +} + +// WalletAddressData from GET /v1/wallet/address. +type WalletAddressData struct { + Address string `json:"address"` + Chain string `json:"chain"` +} + +// AppPermission represents an approved app in the permission database. +type AppPermission struct { + BinaryHash string `json:"binaryHash"` + BinaryPath string `json:"binaryPath"` + Name string `json:"name"` + FirstSeen string `json:"firstSeen"` + LastUsed string `json:"lastUsed"` + Capabilities []PermittedCapability `json:"capabilities"` +} + +// PermittedCapability is a specific capability granted to an app. +type PermittedCapability struct { + Capability string `json:"capability"` + GrantedAt string `json:"grantedAt"` +} + +// apiResponse is the generic API response envelope. +type apiResponse[T any] struct { + OK bool `json:"ok"` + Data T `json:"data,omitempty"` + Error string `json:"error,omitempty"` + Code string `json:"code,omitempty"` +}