Merge pull request #85 from DeBrosDAO/0.115.0

0.115.0
This commit is contained in:
anonpenguin 2026-03-20 07:25:50 +02:00 committed by GitHub
commit 8ea4499052
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 1476 additions and 158 deletions

View File

@ -597,7 +597,11 @@ func promptForBaseDomain() string {
// installNamespaceTemplates installs systemd template unit files for namespace services // installNamespaceTemplates installs systemd template unit files for namespace services
func (o *Orchestrator) installNamespaceTemplates() error { func (o *Orchestrator) installNamespaceTemplates() error {
sourceDir := filepath.Join(o.oramaHome, "src", "systemd") // Check pre-built archive path first, fall back to source path
sourceDir := production.OramaSystemdDir
if _, err := os.Stat(sourceDir); os.IsNotExist(err) {
sourceDir = filepath.Join(o.oramaHome, "src", "systemd")
}
systemdDir := "/etc/systemd/system" systemdDir := "/etc/systemd/system"
templates := []string{ templates := []string{

View File

@ -1,6 +1,7 @@
package remotessh package remotessh
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"os/exec" "os/exec"
@ -8,11 +9,40 @@ import (
"strings" "strings"
"github.com/DeBrosOfficial/network/pkg/inspector" "github.com/DeBrosOfficial/network/pkg/inspector"
"github.com/DeBrosOfficial/network/pkg/rwagent"
) )
// vaultClient is the interface used by wallet functions to talk to the agent.
// Defaults to the real rwagent.Client; tests replace it with a mock.
type vaultClient interface {
GetSSHKey(ctx context.Context, host, username, format string) (*rwagent.VaultSSHData, error)
CreateSSHEntry(ctx context.Context, host, username string) (*rwagent.VaultSSHData, error)
}
// newClient creates the default vaultClient. Package-level var for test injection.
var newClient func() vaultClient = func() vaultClient {
return rwagent.New(os.Getenv("RW_AGENT_SOCK"))
}
// wrapAgentError wraps rwagent errors with user-friendly messages.
// When the agent is locked, it also triggers the RootWallet desktop app
// to show the unlock dialog via deep link (best-effort, fire-and-forget).
func wrapAgentError(err error, action string) error {
if rwagent.IsNotRunning(err) {
return fmt.Errorf("%s: rootwallet agent is not running — start with: rw agent start && rw agent unlock", action)
}
if rwagent.IsLocked(err) {
return fmt.Errorf("%s: rootwallet agent is locked — unlock timed out after waiting. Unlock via the RootWallet app or run: rw agent unlock", action)
}
if rwagent.IsApprovalDenied(err) {
return fmt.Errorf("%s: rootwallet access denied — approve this app in the RootWallet desktop app", action)
}
return fmt.Errorf("%s: %w", action, err)
}
// PrepareNodeKeys resolves wallet-derived SSH keys for all nodes. // PrepareNodeKeys resolves wallet-derived SSH keys for all nodes.
// Calls `rw vault ssh get <host>/<user> --priv` for each unique host/user, // Retrieves private keys from the rootwallet agent daemon, writes PEMs to
// writes PEMs to temp files, and sets node.SSHKey for each node. // temp files, and sets node.SSHKey for each node.
// //
// The nodes slice is modified in place — each node.SSHKey is set to // The nodes slice is modified in place — each node.SSHKey is set to
// the path of the temporary key file. // the path of the temporary key file.
@ -20,10 +50,8 @@ import (
// Returns a cleanup function that zero-overwrites and removes all temp files. // Returns a cleanup function that zero-overwrites and removes all temp files.
// Caller must defer cleanup(). // Caller must defer cleanup().
func PrepareNodeKeys(nodes []inspector.Node) (cleanup func(), err error) { func PrepareNodeKeys(nodes []inspector.Node) (cleanup func(), err error) {
rw, err := rwBinary() client := newClient()
if err != nil { ctx := context.Background()
return nil, err
}
// Create temp dir for all keys // Create temp dir for all keys
tmpDir, err := os.MkdirTemp("", "orama-ssh-") tmpDir, err := os.MkdirTemp("", "orama-ssh-")
@ -31,12 +59,11 @@ func PrepareNodeKeys(nodes []inspector.Node) (cleanup func(), err error) {
return nil, fmt.Errorf("create temp dir: %w", err) return nil, fmt.Errorf("create temp dir: %w", err)
} }
// Track resolved keys by host/user to avoid duplicate rw calls // Track resolved keys by host/user to avoid duplicate agent calls
keyPaths := make(map[string]string) // "host/user" → temp file path keyPaths := make(map[string]string) // "host/user" → temp file path
var allKeyPaths []string var allKeyPaths []string
for i := range nodes { for i := range nodes {
// Use VaultTarget if set, otherwise default to Host/User
var key string var key string
if nodes[i].VaultTarget != "" { if nodes[i].VaultTarget != "" {
key = nodes[i].VaultTarget key = nodes[i].VaultTarget
@ -48,18 +75,21 @@ func PrepareNodeKeys(nodes []inspector.Node) (cleanup func(), err error) {
continue continue
} }
// Call rw to get the private key PEM
host, user := parseVaultTarget(key) host, user := parseVaultTarget(key)
pem, err := resolveWalletKey(rw, host, user) data, err := client.GetSSHKey(ctx, host, user, "priv")
if err != nil { if err != nil {
// Cleanup any keys already written before returning error
cleanupKeys(tmpDir, allKeyPaths) cleanupKeys(tmpDir, allKeyPaths)
return nil, fmt.Errorf("resolve key for %s: %w", nodes[i].Name(), err) return nil, wrapAgentError(err, fmt.Sprintf("resolve key for %s", nodes[i].Name()))
}
if !strings.Contains(data.PrivateKey, "BEGIN OPENSSH PRIVATE KEY") {
cleanupKeys(tmpDir, allKeyPaths)
return nil, fmt.Errorf("agent returned invalid key for %s", nodes[i].Name())
} }
// Write PEM to temp file with restrictive perms // Write PEM to temp file with restrictive perms
keyFile := filepath.Join(tmpDir, fmt.Sprintf("id_%d", i)) keyFile := filepath.Join(tmpDir, fmt.Sprintf("id_%d", i))
if err := os.WriteFile(keyFile, []byte(pem), 0600); err != nil { if err := os.WriteFile(keyFile, []byte(data.PrivateKey), 0600); err != nil {
cleanupKeys(tmpDir, allKeyPaths) cleanupKeys(tmpDir, allKeyPaths)
return nil, fmt.Errorf("write key for %s: %w", nodes[i].Name(), err) return nil, fmt.Errorf("write key for %s: %w", nodes[i].Name(), err)
} }
@ -77,12 +107,10 @@ func PrepareNodeKeys(nodes []inspector.Node) (cleanup func(), err error) {
// LoadAgentKeys loads SSH keys for the given nodes into the system ssh-agent. // LoadAgentKeys loads SSH keys for the given nodes into the system ssh-agent.
// Used by push fanout to enable agent forwarding. // Used by push fanout to enable agent forwarding.
// Calls `rw vault ssh agent-load <host1/user1> <host2/user2> ...` // Retrieves private keys from the rootwallet agent and pipes them to ssh-add.
func LoadAgentKeys(nodes []inspector.Node) error { func LoadAgentKeys(nodes []inspector.Node) error {
rw, err := rwBinary() client := newClient()
if err != nil { ctx := context.Background()
return err
}
// Deduplicate host/user pairs // Deduplicate host/user pairs
seen := make(map[string]bool) seen := make(map[string]bool)
@ -105,76 +133,65 @@ func LoadAgentKeys(nodes []inspector.Node) error {
return nil return nil
} }
args := append([]string{"vault", "ssh", "agent-load"}, targets...) for _, target := range targets {
cmd := exec.Command(rw, args...) host, user := parseVaultTarget(target)
cmd.Stderr = os.Stderr data, err := client.GetSSHKey(ctx, host, user, "priv")
cmd.Stdout = os.Stderr // info messages go to stderr if err != nil {
return wrapAgentError(err, fmt.Sprintf("get key for %s", target))
}
if err := cmd.Run(); err != nil { // Pipe private key to ssh-add via stdin
return fmt.Errorf("rw vault ssh agent-load failed: %w", err) cmd := exec.Command("ssh-add", "-")
cmd.Stdin = strings.NewReader(data.PrivateKey)
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return fmt.Errorf("ssh-add failed for %s: %w", target, err)
}
} }
return nil return nil
} }
// EnsureVaultEntry creates a wallet SSH entry if it doesn't already exist. // EnsureVaultEntry creates a wallet SSH entry if it doesn't already exist.
// Checks existence via `rw vault ssh get <target> --pub`, and if missing, // Checks the rootwallet agent for an existing entry, creates one if not found.
// runs `rw vault ssh add <target>` to create it.
func EnsureVaultEntry(vaultTarget string) error { func EnsureVaultEntry(vaultTarget string) error {
rw, err := rwBinary() client := newClient()
if err != nil { ctx := context.Background()
return err
host, user := parseVaultTarget(vaultTarget)
// Check if entry already exists
_, err := client.GetSSHKey(ctx, host, user, "pub")
if err == nil {
return nil // entry exists
} }
// Check if entry exists by trying to get the public key // If not found, create it
cmd := exec.Command(rw, "vault", "ssh", "get", vaultTarget, "--pub") if rwagent.IsNotFound(err) {
if err := cmd.Run(); err == nil { _, createErr := client.CreateSSHEntry(ctx, host, user)
return nil // entry already exists if createErr != nil {
} return wrapAgentError(createErr, fmt.Sprintf("create vault entry %s", vaultTarget))
// Entry doesn't exist — try to create it
addCmd := exec.Command(rw, "vault", "ssh", "add", vaultTarget)
addCmd.Stdin = os.Stdin
addCmd.Stdout = os.Stderr
addCmd.Stderr = os.Stderr
if err := addCmd.Run(); err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
stderr := strings.TrimSpace(string(exitErr.Stderr))
if strings.Contains(stderr, "not unlocked") || strings.Contains(stderr, "session") {
return fmt.Errorf("wallet is locked — run: rw unlock")
}
} }
return fmt.Errorf("rw vault ssh add %s failed: %w", vaultTarget, err) return nil
} }
return nil
return wrapAgentError(err, fmt.Sprintf("check vault entry %s", vaultTarget))
} }
// ResolveVaultPublicKey returns the OpenSSH public key string for a vault entry. // ResolveVaultPublicKey returns the OpenSSH public key string for a vault entry.
// Calls `rw vault ssh get <target> --pub`.
func ResolveVaultPublicKey(vaultTarget string) (string, error) { func ResolveVaultPublicKey(vaultTarget string) (string, error) {
rw, err := rwBinary() client := newClient()
ctx := context.Background()
host, user := parseVaultTarget(vaultTarget)
data, err := client.GetSSHKey(ctx, host, user, "pub")
if err != nil { if err != nil {
return "", err return "", wrapAgentError(err, fmt.Sprintf("get public key for %s", vaultTarget))
} }
cmd := exec.Command(rw, "vault", "ssh", "get", vaultTarget, "--pub") pubKey := strings.TrimSpace(data.PublicKey)
out, err := cmd.Output()
if err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
stderr := strings.TrimSpace(string(exitErr.Stderr))
if strings.Contains(stderr, "No SSH entry") {
return "", fmt.Errorf("no vault SSH entry for %s — run: rw vault ssh add %s", vaultTarget, vaultTarget)
}
if strings.Contains(stderr, "not unlocked") || strings.Contains(stderr, "session") {
return "", fmt.Errorf("wallet is locked — run: rw unlock")
}
return "", fmt.Errorf("%s", stderr)
}
return "", fmt.Errorf("rw command failed: %w", err)
}
pubKey := strings.TrimSpace(string(out))
if !strings.HasPrefix(pubKey, "ssh-") { if !strings.HasPrefix(pubKey, "ssh-") {
return "", fmt.Errorf("rw returned invalid public key for %s", vaultTarget) return "", fmt.Errorf("agent returned invalid public key for %s", vaultTarget)
} }
return pubKey, nil return pubKey, nil
} }
@ -188,49 +205,6 @@ func parseVaultTarget(target string) (host, user string) {
return target[:idx], target[idx+1:] return target[:idx], target[idx+1:]
} }
// resolveWalletKey calls `rw vault ssh get <host>/<user> --priv`
// and returns the PEM string. Requires an active rw session.
func resolveWalletKey(rw string, host, user string) (string, error) {
target := host + "/" + user
cmd := exec.Command(rw, "vault", "ssh", "get", target, "--priv")
out, err := cmd.Output()
if err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
stderr := strings.TrimSpace(string(exitErr.Stderr))
if strings.Contains(stderr, "No SSH entry") {
return "", fmt.Errorf("no vault SSH entry for %s — run: rw vault ssh add %s", target, target)
}
if strings.Contains(stderr, "not unlocked") || strings.Contains(stderr, "session") {
return "", fmt.Errorf("wallet is locked — run: rw unlock")
}
return "", fmt.Errorf("%s", stderr)
}
return "", fmt.Errorf("rw command failed: %w", err)
}
pem := string(out)
if !strings.Contains(pem, "BEGIN OPENSSH PRIVATE KEY") {
return "", fmt.Errorf("rw returned invalid key for %s", target)
}
return pem, nil
}
// rwBinary returns the path to the `rw` binary.
// Checks RW_PATH env var first, then PATH.
func rwBinary() (string, error) {
if p := os.Getenv("RW_PATH"); p != "" {
if _, err := os.Stat(p); err == nil {
return p, nil
}
return "", fmt.Errorf("RW_PATH=%q not found", p)
}
p, err := exec.LookPath("rw")
if err != nil {
return "", fmt.Errorf("rw not found in PATH — install rootwallet CLI: https://github.com/DeBrosOfficial/rootwallet")
}
return p, nil
}
// cleanupKeys zero-overwrites and removes all key files, then removes the temp dir. // cleanupKeys zero-overwrites and removes all key files, then removes the temp dir.
func cleanupKeys(tmpDir string, keyPaths []string) { func cleanupKeys(tmpDir string, keyPaths []string) {
zeros := make([]byte, 512) zeros := make([]byte, 512)

View File

@ -1,6 +1,39 @@
package remotessh package remotessh
import "testing" import (
"context"
"errors"
"os"
"strings"
"testing"
"github.com/DeBrosOfficial/network/pkg/inspector"
"github.com/DeBrosOfficial/network/pkg/rwagent"
)
const testPrivateKey = "-----BEGIN OPENSSH PRIVATE KEY-----\nfake-key-data\n-----END OPENSSH PRIVATE KEY-----"
// mockClient implements vaultClient for testing.
type mockClient struct {
getSSHKey func(ctx context.Context, host, username, format string) (*rwagent.VaultSSHData, error)
createSSHEntry func(ctx context.Context, host, username string) (*rwagent.VaultSSHData, error)
}
func (m *mockClient) GetSSHKey(ctx context.Context, host, username, format string) (*rwagent.VaultSSHData, error) {
return m.getSSHKey(ctx, host, username, format)
}
func (m *mockClient) CreateSSHEntry(ctx context.Context, host, username string) (*rwagent.VaultSSHData, error) {
return m.createSSHEntry(ctx, host, username)
}
// withMockClient replaces newClient for the duration of a test.
func withMockClient(t *testing.T, mock *mockClient) {
t.Helper()
orig := newClient
newClient = func() vaultClient { return mock }
t.Cleanup(func() { newClient = orig })
}
func TestParseVaultTarget(t *testing.T) { func TestParseVaultTarget(t *testing.T) {
tests := []struct { tests := []struct {
@ -13,6 +46,7 @@ func TestParseVaultTarget(t *testing.T) {
{"my-host/my-user", "my-host", "my-user"}, {"my-host/my-user", "my-host", "my-user"},
{"noslash", "noslash", ""}, {"noslash", "noslash", ""},
{"a/b/c", "a", "b/c"}, {"a/b/c", "a", "b/c"},
{"", "", ""},
} }
for _, tt := range tests { for _, tt := range tests {
@ -27,3 +61,316 @@ func TestParseVaultTarget(t *testing.T) {
}) })
} }
} }
func TestWrapAgentError_notRunning(t *testing.T) {
err := wrapAgentError(rwagent.ErrAgentNotRunning, "test action")
if !strings.Contains(err.Error(), "not running") {
t.Errorf("expected 'not running' message, got: %s", err)
}
if !strings.Contains(err.Error(), "rw agent start") {
t.Errorf("expected actionable hint, got: %s", err)
}
}
func TestWrapAgentError_locked(t *testing.T) {
agentErr := &rwagent.AgentError{Code: "AGENT_LOCKED", Message: "agent is locked"}
err := wrapAgentError(agentErr, "test action")
if !strings.Contains(err.Error(), "locked") {
t.Errorf("expected 'locked' message, got: %s", err)
}
if !strings.Contains(err.Error(), "rw agent unlock") {
t.Errorf("expected actionable hint, got: %s", err)
}
}
func TestWrapAgentError_generic(t *testing.T) {
err := wrapAgentError(errors.New("some error"), "test action")
if !strings.Contains(err.Error(), "test action") {
t.Errorf("expected action context, got: %s", err)
}
if !strings.Contains(err.Error(), "some error") {
t.Errorf("expected wrapped error, got: %s", err)
}
}
func TestPrepareNodeKeys_success(t *testing.T) {
mock := &mockClient{
getSSHKey: func(_ context.Context, host, username, format string) (*rwagent.VaultSSHData, error) {
return &rwagent.VaultSSHData{PrivateKey: testPrivateKey}, nil
},
}
withMockClient(t, mock)
nodes := []inspector.Node{
{Host: "10.0.0.1", User: "root"},
{Host: "10.0.0.2", User: "root"},
}
cleanup, err := PrepareNodeKeys(nodes)
if err != nil {
t.Fatalf("PrepareNodeKeys() error = %v", err)
}
defer cleanup()
for i, n := range nodes {
if n.SSHKey == "" {
t.Errorf("node[%d].SSHKey is empty", i)
continue
}
data, err := os.ReadFile(n.SSHKey)
if err != nil {
t.Errorf("node[%d] key file unreadable: %v", i, err)
continue
}
if !strings.Contains(string(data), "BEGIN OPENSSH PRIVATE KEY") {
t.Errorf("node[%d] key file has wrong content", i)
}
}
}
func TestPrepareNodeKeys_deduplication(t *testing.T) {
callCount := 0
mock := &mockClient{
getSSHKey: func(_ context.Context, host, username, format string) (*rwagent.VaultSSHData, error) {
callCount++
return &rwagent.VaultSSHData{PrivateKey: testPrivateKey}, nil
},
}
withMockClient(t, mock)
nodes := []inspector.Node{
{Host: "10.0.0.1", User: "root"},
{Host: "10.0.0.1", User: "root"}, // same host/user
}
cleanup, err := PrepareNodeKeys(nodes)
if err != nil {
t.Fatalf("PrepareNodeKeys() error = %v", err)
}
defer cleanup()
if callCount != 1 {
t.Errorf("expected 1 agent call (dedup), got %d", callCount)
}
if nodes[0].SSHKey != nodes[1].SSHKey {
t.Error("expected same key path for deduplicated nodes")
}
}
func TestPrepareNodeKeys_vaultTarget(t *testing.T) {
var capturedHost, capturedUser string
mock := &mockClient{
getSSHKey: func(_ context.Context, host, username, format string) (*rwagent.VaultSSHData, error) {
capturedHost = host
capturedUser = username
return &rwagent.VaultSSHData{PrivateKey: testPrivateKey}, nil
},
}
withMockClient(t, mock)
nodes := []inspector.Node{
{Host: "10.0.0.1", User: "root", VaultTarget: "sandbox/admin"},
}
cleanup, err := PrepareNodeKeys(nodes)
if err != nil {
t.Fatalf("PrepareNodeKeys() error = %v", err)
}
defer cleanup()
if capturedHost != "sandbox" || capturedUser != "admin" {
t.Errorf("expected host=sandbox user=admin, got host=%s user=%s", capturedHost, capturedUser)
}
}
func TestPrepareNodeKeys_agentNotRunning(t *testing.T) {
mock := &mockClient{
getSSHKey: func(_ context.Context, _, _, _ string) (*rwagent.VaultSSHData, error) {
return nil, rwagent.ErrAgentNotRunning
},
}
withMockClient(t, mock)
nodes := []inspector.Node{{Host: "10.0.0.1", User: "root"}}
_, err := PrepareNodeKeys(nodes)
if err == nil {
t.Fatal("expected error")
}
if !strings.Contains(err.Error(), "not running") {
t.Errorf("expected 'not running' error, got: %s", err)
}
}
func TestPrepareNodeKeys_invalidKey(t *testing.T) {
mock := &mockClient{
getSSHKey: func(_ context.Context, _, _, _ string) (*rwagent.VaultSSHData, error) {
return &rwagent.VaultSSHData{PrivateKey: "garbage"}, nil
},
}
withMockClient(t, mock)
nodes := []inspector.Node{{Host: "10.0.0.1", User: "root"}}
_, err := PrepareNodeKeys(nodes)
if err == nil {
t.Fatal("expected error for invalid key")
}
if !strings.Contains(err.Error(), "invalid key") {
t.Errorf("expected 'invalid key' error, got: %s", err)
}
}
func TestPrepareNodeKeys_cleanupOnError(t *testing.T) {
callNum := 0
mock := &mockClient{
getSSHKey: func(_ context.Context, _, _, _ string) (*rwagent.VaultSSHData, error) {
callNum++
if callNum == 2 {
return nil, &rwagent.AgentError{Code: "AGENT_LOCKED", Message: "locked"}
}
return &rwagent.VaultSSHData{PrivateKey: testPrivateKey}, nil
},
}
withMockClient(t, mock)
nodes := []inspector.Node{
{Host: "10.0.0.1", User: "root"},
{Host: "10.0.0.2", User: "root"},
}
_, err := PrepareNodeKeys(nodes)
if err == nil {
t.Fatal("expected error")
}
// First node's temp file should have been cleaned up
if nodes[0].SSHKey != "" {
if _, statErr := os.Stat(nodes[0].SSHKey); statErr == nil {
t.Error("expected temp key file to be cleaned up on error")
}
}
}
func TestPrepareNodeKeys_emptyNodes(t *testing.T) {
mock := &mockClient{}
withMockClient(t, mock)
cleanup, err := PrepareNodeKeys(nil)
if err != nil {
t.Fatalf("expected no error for empty nodes, got: %v", err)
}
cleanup() // should not panic
}
func TestEnsureVaultEntry_exists(t *testing.T) {
mock := &mockClient{
getSSHKey: func(_ context.Context, _, _, _ string) (*rwagent.VaultSSHData, error) {
return &rwagent.VaultSSHData{PublicKey: "ssh-ed25519 AAAA..."}, nil
},
}
withMockClient(t, mock)
if err := EnsureVaultEntry("sandbox/root"); err != nil {
t.Fatalf("EnsureVaultEntry() error = %v", err)
}
}
func TestEnsureVaultEntry_creates(t *testing.T) {
created := false
mock := &mockClient{
getSSHKey: func(_ context.Context, _, _, _ string) (*rwagent.VaultSSHData, error) {
return nil, &rwagent.AgentError{Code: "NOT_FOUND", Message: "not found"}
},
createSSHEntry: func(_ context.Context, host, username string) (*rwagent.VaultSSHData, error) {
created = true
if host != "sandbox" || username != "root" {
t.Errorf("unexpected create args: %s/%s", host, username)
}
return &rwagent.VaultSSHData{PublicKey: "ssh-ed25519 AAAA..."}, nil
},
}
withMockClient(t, mock)
if err := EnsureVaultEntry("sandbox/root"); err != nil {
t.Fatalf("EnsureVaultEntry() error = %v", err)
}
if !created {
t.Error("expected CreateSSHEntry to be called")
}
}
func TestEnsureVaultEntry_locked(t *testing.T) {
mock := &mockClient{
getSSHKey: func(_ context.Context, _, _, _ string) (*rwagent.VaultSSHData, error) {
return nil, &rwagent.AgentError{Code: "AGENT_LOCKED", Message: "locked"}
},
}
withMockClient(t, mock)
err := EnsureVaultEntry("sandbox/root")
if err == nil {
t.Fatal("expected error")
}
if !strings.Contains(err.Error(), "locked") {
t.Errorf("expected locked error, got: %s", err)
}
}
func TestResolveVaultPublicKey_success(t *testing.T) {
mock := &mockClient{
getSSHKey: func(_ context.Context, _, _, format string) (*rwagent.VaultSSHData, error) {
if format != "pub" {
t.Errorf("expected format=pub, got %s", format)
}
return &rwagent.VaultSSHData{PublicKey: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAA..."}, nil
},
}
withMockClient(t, mock)
key, err := ResolveVaultPublicKey("sandbox/root")
if err != nil {
t.Fatalf("ResolveVaultPublicKey() error = %v", err)
}
if !strings.HasPrefix(key, "ssh-") {
t.Errorf("expected ssh- prefix, got: %s", key)
}
}
func TestResolveVaultPublicKey_invalidFormat(t *testing.T) {
mock := &mockClient{
getSSHKey: func(_ context.Context, _, _, _ string) (*rwagent.VaultSSHData, error) {
return &rwagent.VaultSSHData{PublicKey: "not-a-valid-key"}, nil
},
}
withMockClient(t, mock)
_, err := ResolveVaultPublicKey("sandbox/root")
if err == nil {
t.Fatal("expected error for invalid public key")
}
if !strings.Contains(err.Error(), "invalid public key") {
t.Errorf("expected 'invalid public key' error, got: %s", err)
}
}
func TestResolveVaultPublicKey_notFound(t *testing.T) {
mock := &mockClient{
getSSHKey: func(_ context.Context, _, _, _ string) (*rwagent.VaultSSHData, error) {
return nil, &rwagent.AgentError{Code: "NOT_FOUND", Message: "not found"}
},
}
withMockClient(t, mock)
_, err := ResolveVaultPublicKey("sandbox/root")
if err == nil {
t.Fatal("expected error")
}
}
func TestLoadAgentKeys_emptyNodes(t *testing.T) {
mock := &mockClient{}
withMockClient(t, mock)
if err := LoadAgentKeys(nil); err != nil {
t.Fatalf("expected no error for empty nodes, got: %v", err)
}
}

View File

@ -75,8 +75,7 @@ func Create(name string) error {
// Phase 4: Install genesis node // Phase 4: Install genesis node
fmt.Println("\nPhase 4: Installing genesis node...") fmt.Println("\nPhase 4: Installing genesis node...")
tokens, err := phase4InstallGenesis(cfg, state, sshKeyPath) if err := phase4InstallGenesis(cfg, state, sshKeyPath); err != nil {
if err != nil {
state.Status = StatusError state.Status = StatusError
SaveState(state) SaveState(state)
return fmt.Errorf("install genesis: %w", err) return fmt.Errorf("install genesis: %w", err)
@ -84,7 +83,7 @@ func Create(name string) error {
// Phase 5: Join remaining nodes // Phase 5: Join remaining nodes
fmt.Println("\nPhase 5: Joining remaining nodes...") fmt.Println("\nPhase 5: Joining remaining nodes...")
if err := phase5JoinNodes(cfg, state, tokens, sshKeyPath); err != nil { if err := phase5JoinNodes(cfg, state, sshKeyPath); err != nil {
state.Status = StatusError state.Status = StatusError
SaveState(state) SaveState(state)
return fmt.Errorf("join nodes: %w", err) return fmt.Errorf("join nodes: %w", err)
@ -280,60 +279,52 @@ func phase3UploadArchive(state *SandboxState, sshKeyPath string) error {
return nil return nil
} }
// phase4InstallGenesis installs the genesis node and generates invite tokens. // phase4InstallGenesis installs the genesis node.
func phase4InstallGenesis(cfg *Config, state *SandboxState, sshKeyPath string) ([]string, error) { func phase4InstallGenesis(cfg *Config, state *SandboxState, sshKeyPath string) error {
genesis := state.GenesisServer() genesis := state.GenesisServer()
node := inspector.Node{User: "root", Host: genesis.IP, SSHKey: sshKeyPath} node := inspector.Node{User: "root", Host: genesis.IP, SSHKey: sshKeyPath}
// Install genesis // Install genesis
installCmd := fmt.Sprintf("/opt/orama/bin/orama node install --vps-ip %s --domain %s --base-domain %s --nameserver --skip-checks", installCmd := fmt.Sprintf("/opt/orama/bin/orama node install --vps-ip %s --domain %s --base-domain %s --nameserver --anyone-client --skip-checks",
genesis.IP, cfg.Domain, cfg.Domain) genesis.IP, cfg.Domain, cfg.Domain)
fmt.Printf(" Installing on %s (%s)...\n", genesis.Name, genesis.IP) fmt.Printf(" Installing on %s (%s)...\n", genesis.Name, genesis.IP)
if err := remotessh.RunSSHStreaming(node, installCmd, remotessh.WithNoHostKeyCheck()); err != nil { if err := remotessh.RunSSHStreaming(node, installCmd, remotessh.WithNoHostKeyCheck()); err != nil {
return nil, fmt.Errorf("install genesis: %w", err) return fmt.Errorf("install genesis: %w", err)
} }
// Wait for RQLite leader // Wait for RQLite leader
fmt.Print(" Waiting for RQLite leader...") fmt.Print(" Waiting for RQLite leader...")
if err := waitForRQLiteHealth(node, 3*time.Minute); err != nil { if err := waitForRQLiteHealth(node, 3*time.Minute); err != nil {
return nil, fmt.Errorf("genesis health: %w", err) return fmt.Errorf("genesis health: %w", err)
} }
fmt.Println(" OK") fmt.Println(" OK")
// Generate invite tokens (one per remaining node) return nil
fmt.Print(" Generating invite tokens...")
remaining := len(state.Servers) - 1
tokens := make([]string, remaining)
for i := 0; i < remaining; i++ {
token, err := generateInviteToken(node)
if err != nil {
return nil, fmt.Errorf("generate invite token %d: %w", i+1, err)
}
tokens[i] = token
fmt.Print(".")
}
fmt.Println(" OK")
return tokens, nil
} }
// phase5JoinNodes joins the remaining 4 nodes to the cluster (serial). // phase5JoinNodes joins the remaining 4 nodes to the cluster (serial).
func phase5JoinNodes(cfg *Config, state *SandboxState, tokens []string, sshKeyPath string) error { // Generates invite tokens just-in-time to avoid expiry during long installs.
genesisIP := state.GenesisServer().IP func phase5JoinNodes(cfg *Config, state *SandboxState, sshKeyPath string) error {
genesis := state.GenesisServer()
genesisNode := inspector.Node{User: "root", Host: genesis.IP, SSHKey: sshKeyPath}
for i := 1; i < len(state.Servers); i++ { for i := 1; i < len(state.Servers); i++ {
srv := state.Servers[i] srv := state.Servers[i]
node := inspector.Node{User: "root", Host: srv.IP, SSHKey: sshKeyPath} node := inspector.Node{User: "root", Host: srv.IP, SSHKey: sshKeyPath}
token := tokens[i-1]
// Generate token just before use to avoid expiry
token, err := generateInviteToken(genesisNode)
if err != nil {
return fmt.Errorf("generate invite token for %s: %w", srv.Name, err)
}
var installCmd string var installCmd string
if srv.Role == "nameserver" { if srv.Role == "nameserver" {
installCmd = fmt.Sprintf("/opt/orama/bin/orama node install --join http://%s --token %s --vps-ip %s --domain %s --base-domain %s --nameserver --skip-checks", installCmd = fmt.Sprintf("/opt/orama/bin/orama node install --join http://%s --token %s --vps-ip %s --domain %s --base-domain %s --nameserver --anyone-client --skip-checks",
genesisIP, token, srv.IP, cfg.Domain, cfg.Domain) genesis.IP, token, srv.IP, cfg.Domain, cfg.Domain)
} else { } else {
installCmd = fmt.Sprintf("/opt/orama/bin/orama node install --join http://%s --token %s --vps-ip %s --base-domain %s --skip-checks", installCmd = fmt.Sprintf("/opt/orama/bin/orama node install --join http://%s --token %s --vps-ip %s --base-domain %s --anyone-client --skip-checks",
genesisIP, token, srv.IP, cfg.Domain) genesis.IP, token, srv.IP, cfg.Domain)
} }
fmt.Printf(" [%d/%d] Joining %s (%s, %s)...\n", i, len(state.Servers)-1, srv.Name, srv.IP, srv.Role) fmt.Printf(" [%d/%d] Joining %s (%s, %s)...\n", i, len(state.Servers)-1, srv.Name, srv.IP, srv.Role)

View File

@ -346,7 +346,11 @@ func (c *HetznerClient) UploadSSHKey(name, publicKey string) (*HetznerSSHKey, er
// ListSSHKeysByFingerprint finds SSH keys matching a fingerprint. // ListSSHKeysByFingerprint finds SSH keys matching a fingerprint.
func (c *HetznerClient) ListSSHKeysByFingerprint(fingerprint string) ([]HetznerSSHKey, error) { func (c *HetznerClient) ListSSHKeysByFingerprint(fingerprint string) ([]HetznerSSHKey, error) {
body, err := c.get("/ssh_keys?fingerprint=" + fingerprint) path := "/ssh_keys"
if fingerprint != "" {
path += "?fingerprint=" + fingerprint
}
body, err := c.get(path)
if err != nil { if err != nil {
return nil, fmt.Errorf("list SSH keys: %w", err) return nil, fmt.Errorf("list SSH keys: %w", err)
} }

View File

@ -408,17 +408,38 @@ func setupSSHKey(client *HetznerClient) (SSHKeyConfig, error) {
fmt.Print(" Uploading to Hetzner... ") fmt.Print(" Uploading to Hetzner... ")
key, err := client.UploadSSHKey("orama-sandbox", pubStr) key, err := client.UploadSSHKey("orama-sandbox", pubStr)
if err != nil { if err != nil {
// Key may already exist on Hetzner — try to find by fingerprint // Key may already exist on Hetzner — check if it matches the current vault key
existing, listErr := client.ListSSHKeysByFingerprint("") // empty = list all existing, listErr := client.ListSSHKeysByFingerprint("")
if listErr == nil { if listErr == nil {
for _, k := range existing { for _, k := range existing {
if strings.TrimSpace(k.PublicKey) == pubStr { if sshKeyDataEqual(k.PublicKey, pubStr) {
// Key data matches — safe to reuse regardless of name
fmt.Printf("already exists (ID: %d)\n", k.ID) fmt.Printf("already exists (ID: %d)\n", k.ID)
return SSHKeyConfig{ return SSHKeyConfig{
HetznerID: k.ID, HetznerID: k.ID,
VaultTarget: vaultTarget, VaultTarget: vaultTarget,
}, nil }, nil
} }
if k.Name == "orama-sandbox" {
// Name matches but key data differs — vault key was rotated.
// Delete the stale Hetzner key so we can re-upload the current one.
fmt.Print("stale key detected, replacing... ")
if delErr := client.DeleteSSHKey(k.ID); delErr != nil {
fmt.Println("FAILED")
return SSHKeyConfig{}, fmt.Errorf("delete stale SSH key (ID %d): %w", k.ID, delErr)
}
// Re-upload with current vault key
newKey, uploadErr := client.UploadSSHKey("orama-sandbox", pubStr)
if uploadErr != nil {
fmt.Println("FAILED")
return SSHKeyConfig{}, fmt.Errorf("re-upload SSH key: %w", uploadErr)
}
fmt.Printf("OK (ID: %d)\n", newKey.ID)
return SSHKeyConfig{
HetznerID: newKey.ID,
VaultTarget: vaultTarget,
}, nil
}
} }
} }
@ -433,6 +454,17 @@ func setupSSHKey(client *HetznerClient) (SSHKeyConfig, error) {
}, nil }, nil
} }
// sshKeyDataEqual compares two SSH public key strings by their key type and
// data, ignoring the optional comment field.
func sshKeyDataEqual(a, b string) bool {
partsA := strings.Fields(strings.TrimSpace(a))
partsB := strings.Fields(strings.TrimSpace(b))
if len(partsA) < 2 || len(partsB) < 2 {
return false
}
return partsA[0] == partsB[0] && partsA[1] == partsB[1]
}
// verifyDNS checks if glue records for the sandbox domain are configured. // verifyDNS checks if glue records for the sandbox domain are configured.
// //
// There's a chicken-and-egg problem: NS records can't fully resolve until // There's a chicken-and-egg problem: NS records can't fully resolve until

View File

@ -0,0 +1,82 @@
package sandbox
import "testing"
func TestSSHKeyDataEqual(t *testing.T) {
tests := []struct {
name string
a string
b string
expected bool
}{
{
name: "identical keys",
a: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBtest comment1",
b: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBtest comment1",
expected: true,
},
{
name: "same key different comments",
a: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBtest vault",
b: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBtest user@host",
expected: true,
},
{
name: "same key one without comment",
a: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBtest",
b: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBtest vault",
expected: true,
},
{
name: "different key data",
a: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBoldkey vault",
b: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBnewkey vault",
expected: false,
},
{
name: "different key types",
a: "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAAB vault",
b: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBtest vault",
expected: false,
},
{
name: "empty string a",
a: "",
b: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBtest vault",
expected: false,
},
{
name: "empty string b",
a: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBtest vault",
b: "",
expected: false,
},
{
name: "both empty",
a: "",
b: "",
expected: false,
},
{
name: "single field only",
a: "ssh-ed25519",
b: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBtest",
expected: false,
},
{
name: "whitespace trimming",
a: " ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBtest vault ",
b: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBtest",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := sshKeyDataEqual(tt.a, tt.b)
if got != tt.expected {
t.Errorf("sshKeyDataEqual(%q, %q) = %v, want %v", tt.a, tt.b, got, tt.expected)
}
})
}
}

View File

@ -338,6 +338,7 @@ func (ari *AnyoneRelayInstaller) ConfigureClient() error {
config := `# Anyone Client Configuration (Managed by Orama Network) config := `# Anyone Client Configuration (Managed by Orama Network)
# Client-only mode no relay traffic, no ORPort # Client-only mode no relay traffic, no ORPort
AgreeToTerms 1
SocksPort 9050 SocksPort 9050
Log notice file /var/log/anon/notices.log Log notice file /var/log/anon/notices.log
@ -360,6 +361,8 @@ func (ari *AnyoneRelayInstaller) generateAnonrc() string {
sb.WriteString("# Anyone Relay Configuration (Managed by Orama Network)\n") sb.WriteString("# Anyone Relay Configuration (Managed by Orama Network)\n")
sb.WriteString("# Generated automatically - manual edits may be overwritten\n\n") sb.WriteString("# Generated automatically - manual edits may be overwritten\n\n")
sb.WriteString("AgreeToTerms 1\n\n")
// Nickname // Nickname
sb.WriteString(fmt.Sprintf("Nickname %s\n", ari.config.Nickname)) sb.WriteString(fmt.Sprintf("Nickname %s\n", ari.config.Nickname))

View File

@ -367,15 +367,13 @@ require (
// If baseDomain is provided and different from domain, Caddy also serves // If baseDomain is provided and different from domain, Caddy also serves
// the base domain and its wildcard (e.g., *.dbrs.space alongside *.node1.dbrs.space). // the base domain and its wildcard (e.g., *.dbrs.space alongside *.node1.dbrs.space).
func (ci *CaddyInstaller) generateCaddyfile(domain, email, acmeEndpoint, baseDomain string) string { func (ci *CaddyInstaller) generateCaddyfile(domain, email, acmeEndpoint, baseDomain string) string {
// Primary: Let's Encrypt via ACME DNS-01 challenge // Let's Encrypt via ACME DNS-01 challenge (no fallback to self-signed)
// Fallback: Caddy's internal CA (self-signed, trust root on clients)
tlsBlock := fmt.Sprintf(` tls { tlsBlock := fmt.Sprintf(` tls {
issuer acme { issuer acme {
dns orama { dns orama {
endpoint %s endpoint %s
} }
} }
issuer internal
}`, acmeEndpoint) }`, acmeEndpoint)
var sb strings.Builder var sb strings.Builder

View File

@ -356,8 +356,12 @@ func (ci *CoreDNSInstaller) generateCorefile(domain, rqliteDSN string) string {
# CoreDNS cache would cache NXDOMAIN and break ACME DNS-01 challenges. # CoreDNS cache would cache NXDOMAIN and break ACME DNS-01 challenges.
} }
# Forward all other queries to upstream DNS # Forward non-authoritative queries to upstream DNS (localhost only).
# The bind directive restricts this block to 127.0.0.1 so the node itself
# can resolve external domains (apt, github, etc.) but external clients
# cannot use this server as an open recursive resolver (BSI/CERT-Bund).
. { . {
bind 127.0.0.1
forward . 8.8.8.8 8.8.4.4 1.1.1.1 forward . 8.8.8.8 8.8.4.4 1.1.1.1
cache 300 cache 300
errors errors

View File

@ -0,0 +1,151 @@
package installers
import (
"io"
"strings"
"testing"
)
// newTestCoreDNSInstaller creates a CoreDNSInstaller suitable for unit tests.
// It uses a non-existent oramaHome so generateCorefile won't find a password file
// and will produce output without auth credentials.
func newTestCoreDNSInstaller() *CoreDNSInstaller {
return &CoreDNSInstaller{
BaseInstaller: NewBaseInstaller("amd64", io.Discard),
version: "1.11.1",
oramaHome: "/nonexistent",
}
}
func TestGenerateCorefile_ContainsBindLocalhost(t *testing.T) {
ci := newTestCoreDNSInstaller()
corefile := ci.generateCorefile("dbrs.space", "http://localhost:5001")
if !strings.Contains(corefile, "bind 127.0.0.1") {
t.Fatal("Corefile forward block must contain 'bind 127.0.0.1' to prevent open resolver")
}
}
func TestGenerateCorefile_ForwardBlockIsLocalhostOnly(t *testing.T) {
ci := newTestCoreDNSInstaller()
corefile := ci.generateCorefile("dbrs.space", "http://localhost:5001")
// The bind directive must appear inside the catch-all (.) block,
// not inside the authoritative domain block.
// Find the ". {" block and verify bind is inside it.
dotBlockIdx := strings.Index(corefile, ". {")
if dotBlockIdx == -1 {
t.Fatal("Corefile must contain a catch-all '. {' server block")
}
dotBlock := corefile[dotBlockIdx:]
closingIdx := strings.Index(dotBlock, "}")
if closingIdx == -1 {
t.Fatal("Catch-all block has no closing brace")
}
dotBlock = dotBlock[:closingIdx]
if !strings.Contains(dotBlock, "bind 127.0.0.1") {
t.Error("bind 127.0.0.1 must be inside the catch-all (.) block, not the domain block")
}
if !strings.Contains(dotBlock, "forward .") {
t.Error("forward directive must be inside the catch-all (.) block")
}
}
func TestGenerateCorefile_AuthoritativeBlockNoBindRestriction(t *testing.T) {
ci := newTestCoreDNSInstaller()
corefile := ci.generateCorefile("dbrs.space", "http://localhost:5001")
// The authoritative domain block should NOT have a bind directive
// (it must listen on all interfaces to serve external DNS queries).
domainBlockStart := strings.Index(corefile, "dbrs.space {")
if domainBlockStart == -1 {
t.Fatal("Corefile must contain 'dbrs.space {' server block")
}
// Extract the domain block (up to the first closing brace)
domainBlock := corefile[domainBlockStart:]
closingIdx := strings.Index(domainBlock, "}")
if closingIdx == -1 {
t.Fatal("Domain block has no closing brace")
}
domainBlock = domainBlock[:closingIdx]
if strings.Contains(domainBlock, "bind ") {
t.Error("Authoritative domain block must not have a bind directive — it must listen on all interfaces")
}
}
func TestGenerateCorefile_ContainsDomainZone(t *testing.T) {
ci := newTestCoreDNSInstaller()
tests := []struct {
domain string
}{
{"dbrs.space"},
{"orama.network"},
{"example.com"},
}
for _, tt := range tests {
t.Run(tt.domain, func(t *testing.T) {
corefile := ci.generateCorefile(tt.domain, "http://localhost:5001")
if !strings.Contains(corefile, tt.domain+" {") {
t.Errorf("Corefile must contain server block for domain %q", tt.domain)
}
if !strings.Contains(corefile, "rqlite {") {
t.Error("Corefile must contain rqlite plugin block")
}
})
}
}
func TestGenerateCorefile_ContainsRQLiteDSN(t *testing.T) {
ci := newTestCoreDNSInstaller()
dsn := "http://10.0.0.1:5001"
corefile := ci.generateCorefile("dbrs.space", dsn)
if !strings.Contains(corefile, "dsn "+dsn) {
t.Errorf("Corefile must contain RQLite DSN %q", dsn)
}
}
func TestGenerateCorefile_NoAuthBlockWithoutCredentials(t *testing.T) {
ci := newTestCoreDNSInstaller()
corefile := ci.generateCorefile("dbrs.space", "http://localhost:5001")
if strings.Contains(corefile, "username") || strings.Contains(corefile, "password") {
t.Error("Corefile must not contain auth credentials when secrets file is absent")
}
}
func TestGeneratePluginConfig_ContainsBindPlugin(t *testing.T) {
ci := newTestCoreDNSInstaller()
cfg := ci.generatePluginConfig()
if !strings.Contains(cfg, "bind:bind") {
t.Error("Plugin config must include the bind plugin (required for localhost-only forwarding)")
}
}
func TestGeneratePluginConfig_ContainsACLPlugin(t *testing.T) {
ci := newTestCoreDNSInstaller()
cfg := ci.generatePluginConfig()
if !strings.Contains(cfg, "acl:acl") {
t.Error("Plugin config must include the acl plugin")
}
}
func TestGeneratePluginConfig_ContainsRQLitePlugin(t *testing.T) {
ci := newTestCoreDNSInstaller()
cfg := ci.generatePluginConfig()
if !strings.Contains(cfg, "rqlite:rqlite") {
t.Error("Plugin config must include the rqlite plugin")
}
}

View File

@ -419,7 +419,7 @@ Description=Caddy HTTP/2 Server
Documentation=https://caddyserver.com/docs/ Documentation=https://caddyserver.com/docs/
After=network-online.target orama-node.service coredns.service After=network-online.target orama-node.service coredns.service
Wants=network-online.target Wants=network-online.target
Wants=orama-node.service Requires=orama-node.service
[Service] [Service]
Type=simple Type=simple
@ -428,6 +428,9 @@ ReadWritePaths=%[2]s /var/lib/caddy /etc/caddy
Environment=XDG_DATA_HOME=/var/lib/caddy Environment=XDG_DATA_HOME=/var/lib/caddy
AmbientCapabilities=CAP_NET_BIND_SERVICE AmbientCapabilities=CAP_NET_BIND_SERVICE
CapabilityBoundingSet=CAP_NET_BIND_SERVICE CapabilityBoundingSet=CAP_NET_BIND_SERVICE
ExecStartPre=/bin/sh -c 'for i in $$(seq 1 30); do curl -so /dev/null http://localhost:6001/health 2>/dev/null && exit 0; sleep 2; done; echo "Gateway not ready after 60s"; exit 1'
ExecStartPre=/bin/sh -c 'DOMAIN=$$(grep -oP "^\*\\.\K[^ {]+" /etc/caddy/Caddyfile | tail -1); [ -z "$$DOMAIN" ] && exit 0; for i in $$(seq 1 30); do dig +short +timeout=2 "$$DOMAIN" SOA 2>/dev/null | grep -q . && exit 0; sleep 2; done; echo "DNS not resolving $$DOMAIN after 60s (ACME may fail)"; exit 0'
TimeoutStartSec=180
ExecStart=/usr/bin/caddy run --environ --config /etc/caddy/Caddyfile ExecStart=/usr/bin/caddy run --environ --config /etc/caddy/Caddyfile
ExecReload=/usr/bin/caddy reload --config /etc/caddy/Caddyfile ExecReload=/usr/bin/caddy reload --config /etc/caddy/Caddyfile
TimeoutStopSec=5s TimeoutStopSec=5s

View File

@ -78,6 +78,39 @@ func TestGenerateRQLiteService(t *testing.T) {
} }
} }
// TestGenerateCaddyService_GatewayReadinessCheck verifies Caddy waits for gateway before starting
func TestGenerateCaddyService_GatewayReadinessCheck(t *testing.T) {
ssg := &SystemdServiceGenerator{
oramaHome: "/opt/orama",
oramaDir: "/opt/orama/.orama",
}
unit := ssg.GenerateCaddyService()
// Must have ExecStartPre that polls gateway health
if !strings.Contains(unit, "ExecStartPre=") {
t.Error("missing ExecStartPre directive for gateway readiness check")
}
if !strings.Contains(unit, "localhost:6001/health") {
t.Error("ExecStartPre should poll localhost:6001/health")
}
// Must use Requires= (hard dependency), not Wants= (soft dependency)
if !strings.Contains(unit, "Requires=orama-node.service") {
t.Error("missing Requires=orama-node.service (hard dependency)")
}
if strings.Contains(unit, "Wants=orama-node.service") {
t.Error("should use Requires= not Wants= for orama-node.service dependency")
}
// ExecStartPre must appear before ExecStart
preIdx := strings.Index(unit, "ExecStartPre=")
startIdx := strings.Index(unit, "ExecStart=/usr/bin/caddy")
if preIdx < 0 || startIdx < 0 || preIdx >= startIdx {
t.Error("ExecStartPre must appear before ExecStart")
}
}
// TestGenerateRQLiteServiceArgs verifies the ExecStart command arguments // TestGenerateRQLiteServiceArgs verifies the ExecStart command arguments
func TestGenerateRQLiteServiceArgs(t *testing.T) { func TestGenerateRQLiteServiceArgs(t *testing.T) {
ssg := &SystemdServiceGenerator{ ssg := &SystemdServiceGenerator{

View File

@ -141,11 +141,15 @@ func (g *Gateway) healthHandler(w http.ResponseWriter, r *http.Request) {
ch <- nr ch <- nr
}() }()
// Vault Guardian (TCP connect to localhost:7500) // Vault Guardian (TCP connect on WireGuard IP:7500)
go func() { go func() {
nr := namedResult{name: "vault"} nr := namedResult{name: "vault"}
start := time.Now() start := time.Now()
conn, err := net.DialTimeout("tcp", "localhost:7500", 2*time.Second) vaultAddr := "localhost:7500"
if g.localWireGuardIP != "" {
vaultAddr = g.localWireGuardIP + ":7500"
}
conn, err := net.DialTimeout("tcp", vaultAddr, 2*time.Second)
if err != nil { if err != nil {
nr.result = checkResult{Status: "error", Latency: time.Since(start).String(), Error: fmt.Sprintf("vault-guardian unreachable on port 7500: %v", err)} nr.result = checkResult{Status: "error", Latency: time.Since(start).String(), Error: fmt.Sprintf("vault-guardian unreachable on port 7500: %v", err)}
} else { } else {

View File

@ -109,9 +109,10 @@ func (cm *ClusterConfigManager) DiscoverClusterPeersFromLibP2P(h host.Host) erro
info := h.Peerstore().PeerInfo(p) info := h.Peerstore().PeerInfo(p)
for _, addr := range info.Addrs { for _, addr := range info.Addrs {
// Extract IP from multiaddr // Extract IP from multiaddr — only use WireGuard IPs (10.0.0.x)
// for inter-node queries since port 6001 is blocked on public interfaces by UFW
ip := extractIPFromMultiaddr(addr) ip := extractIPFromMultiaddr(addr)
if ip != "" && !strings.HasPrefix(ip, "127.") && !strings.HasPrefix(ip, "::1") { if ip != "" && strings.HasPrefix(ip, "10.0.0.") {
peerIPs[ip] = true peerIPs[ip] = true
} }
} }

View File

@ -0,0 +1,95 @@
package ipfs
import (
"testing"
"github.com/multiformats/go-multiaddr"
)
func TestExtractIPFromMultiaddr(t *testing.T) {
tests := []struct {
name string
addr string
expected string
}{
{
name: "ipv4 tcp address",
addr: "/ip4/10.0.0.1/tcp/4001",
expected: "10.0.0.1",
},
{
name: "ipv4 public address",
addr: "/ip4/203.0.113.5/tcp/4001",
expected: "203.0.113.5",
},
{
name: "ipv4 loopback",
addr: "/ip4/127.0.0.1/tcp/4001",
expected: "127.0.0.1",
},
{
name: "ipv6 address",
addr: "/ip6/::1/tcp/4001",
expected: "[::1]",
},
{
name: "wireguard ip with udp",
addr: "/ip4/10.0.0.3/udp/4001/quic",
expected: "10.0.0.3",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ma, err := multiaddr.NewMultiaddr(tt.addr)
if err != nil {
t.Fatalf("failed to parse multiaddr %q: %v", tt.addr, err)
}
got := extractIPFromMultiaddr(ma)
if got != tt.expected {
t.Errorf("extractIPFromMultiaddr(%q) = %q, want %q", tt.addr, got, tt.expected)
}
})
}
}
func TestExtractIPFromMultiaddr_Nil(t *testing.T) {
got := extractIPFromMultiaddr(nil)
if got != "" {
t.Errorf("extractIPFromMultiaddr(nil) = %q, want empty string", got)
}
}
// TestWireGuardIPFiltering verifies that only 10.0.0.x IPs would be selected
// for peer discovery queries. This tests the filtering logic used in
// DiscoverClusterPeersFromLibP2P.
func TestWireGuardIPFiltering(t *testing.T) {
tests := []struct {
name string
addr string
accepted bool
}{
{"wireguard ip", "/ip4/10.0.0.1/tcp/4001", true},
{"wireguard ip high", "/ip4/10.0.0.254/tcp/4001", true},
{"public ip", "/ip4/203.0.113.5/tcp/4001", false},
{"private 192.168", "/ip4/192.168.1.1/tcp/4001", false},
{"private 172.16", "/ip4/172.16.0.1/tcp/4001", false},
{"loopback", "/ip4/127.0.0.1/tcp/4001", false},
{"different 10.x subnet", "/ip4/10.1.0.1/tcp/4001", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ma, err := multiaddr.NewMultiaddr(tt.addr)
if err != nil {
t.Fatalf("failed to parse multiaddr: %v", err)
}
ip := extractIPFromMultiaddr(ma)
// Replicate the filtering logic from DiscoverClusterPeersFromLibP2P
accepted := ip != "" && len(ip) >= 7 && ip[:7] == "10.0.0."
if accepted != tt.accepted {
t.Errorf("IP %q: accepted=%v, want %v", ip, accepted, tt.accepted)
}
})
}
}

222
pkg/rwagent/client.go Normal file
View File

@ -0,0 +1,222 @@
package rwagent
import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"time"
)
const (
// DefaultSocketName is the socket file relative to ~/.rootwallet/.
DefaultSocketName = "agent.sock"
// DefaultTimeout for HTTP requests to the agent.
// Set high enough to allow pending approval flow (2 min approval timeout).
DefaultTimeout = 150 * time.Second
)
// Client communicates with the rootwallet agent daemon over a Unix socket.
type Client struct {
httpClient *http.Client
socketPath string
}
// New creates a client that connects to the agent's Unix socket.
// If socketPath is empty, defaults to ~/.rootwallet/agent.sock.
func New(socketPath string) *Client {
if socketPath == "" {
home, _ := os.UserHomeDir()
socketPath = filepath.Join(home, ".rootwallet", DefaultSocketName)
}
return &Client{
socketPath: socketPath,
httpClient: &http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
var d net.Dialer
return d.DialContext(ctx, "unix", socketPath)
},
},
Timeout: DefaultTimeout,
},
}
}
// Status returns the agent's current status.
func (c *Client) Status(ctx context.Context) (*StatusResponse, error) {
var resp apiResponse[StatusResponse]
if err := c.doJSON(ctx, "GET", "/v1/status", nil, &resp); err != nil {
return nil, err
}
if !resp.OK {
return nil, c.apiError(resp.Error, resp.Code, 0)
}
return &resp.Data, nil
}
// IsRunning returns true if the agent is reachable.
func (c *Client) IsRunning(ctx context.Context) bool {
_, err := c.Status(ctx)
return err == nil
}
// GetSSHKey retrieves an SSH key from the vault.
// format: "priv", "pub", or "both".
func (c *Client) GetSSHKey(ctx context.Context, host, username, format string) (*VaultSSHData, error) {
path := fmt.Sprintf("/v1/vault/ssh/%s/%s?format=%s",
url.PathEscape(host),
url.PathEscape(username),
url.QueryEscape(format),
)
var resp apiResponse[VaultSSHData]
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
if !resp.OK {
return nil, c.apiError(resp.Error, resp.Code, 0)
}
return &resp.Data, nil
}
// CreateSSHEntry creates a new SSH key entry in the vault.
func (c *Client) CreateSSHEntry(ctx context.Context, host, username string) (*VaultSSHData, error) {
body := map[string]string{"host": host, "username": username}
var resp apiResponse[VaultSSHData]
if err := c.doJSON(ctx, "POST", "/v1/vault/ssh", body, &resp); err != nil {
return nil, err
}
if !resp.OK {
return nil, c.apiError(resp.Error, resp.Code, 0)
}
return &resp.Data, nil
}
// GetPassword retrieves a stored password from the vault.
func (c *Client) GetPassword(ctx context.Context, domain, username string) (*VaultPasswordData, error) {
path := fmt.Sprintf("/v1/vault/password/%s/%s",
url.PathEscape(domain),
url.PathEscape(username),
)
var resp apiResponse[VaultPasswordData]
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
if !resp.OK {
return nil, c.apiError(resp.Error, resp.Code, 0)
}
return &resp.Data, nil
}
// GetAddress returns the active wallet address.
func (c *Client) GetAddress(ctx context.Context, chain string) (*WalletAddressData, error) {
path := fmt.Sprintf("/v1/wallet/address?chain=%s", url.QueryEscape(chain))
var resp apiResponse[WalletAddressData]
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
if !resp.OK {
return nil, c.apiError(resp.Error, resp.Code, 0)
}
return &resp.Data, nil
}
// Unlock sends the password to unlock the agent.
func (c *Client) Unlock(ctx context.Context, password string, ttlMinutes int) error {
body := map[string]any{"password": password, "ttlMinutes": ttlMinutes}
var resp apiResponse[any]
if err := c.doJSON(ctx, "POST", "/v1/unlock", body, &resp); err != nil {
return err
}
if !resp.OK {
return c.apiError(resp.Error, resp.Code, 0)
}
return nil
}
// Lock locks the agent, zeroing all key material.
func (c *Client) Lock(ctx context.Context) error {
var resp apiResponse[any]
if err := c.doJSON(ctx, "POST", "/v1/lock", nil, &resp); err != nil {
return err
}
if !resp.OK {
return c.apiError(resp.Error, resp.Code, 0)
}
return nil
}
// doJSON performs an HTTP request and decodes the JSON response.
func (c *Client) doJSON(ctx context.Context, method, path string, body any, result any) error {
var bodyReader io.Reader
if body != nil {
data, err := json.Marshal(body)
if err != nil {
return fmt.Errorf("marshal request body: %w", err)
}
bodyReader = strings.NewReader(string(data))
}
// URL host is ignored for Unix sockets, but required by http.NewRequest
req, err := http.NewRequestWithContext(ctx, method, "http://localhost"+path, bodyReader)
if err != nil {
return fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-RW-PID", strconv.Itoa(os.Getpid()))
resp, err := c.httpClient.Do(req)
if err != nil {
// Connection refused or socket not found = agent not running
if isConnectionError(err) {
return ErrAgentNotRunning
}
return fmt.Errorf("agent request failed: %w", err)
}
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("read response: %w", err)
}
if err := json.Unmarshal(data, result); err != nil {
return fmt.Errorf("decode response: %w", err)
}
return nil
}
func (c *Client) apiError(message, code string, statusCode int) *AgentError {
return &AgentError{
Code: code,
Message: message,
StatusCode: statusCode,
}
}
// isConnectionError checks if the error is a connection-level failure.
func isConnectionError(err error) bool {
if err == nil {
return false
}
msg := err.Error()
return strings.Contains(msg, "connection refused") ||
strings.Contains(msg, "no such file or directory") ||
strings.Contains(msg, "connect: no such file")
}

257
pkg/rwagent/client_test.go Normal file
View File

@ -0,0 +1,257 @@
package rwagent
import (
"context"
"encoding/json"
"net"
"net/http"
"os"
"path/filepath"
"testing"
)
// startMockAgent creates a mock agent server on a Unix socket for testing.
func startMockAgent(t *testing.T, handler http.Handler) (socketPath string, cleanup func()) {
t.Helper()
tmpDir := t.TempDir()
socketPath = filepath.Join(tmpDir, "test-agent.sock")
listener, err := net.Listen("unix", socketPath)
if err != nil {
t.Fatalf("listen on unix socket: %v", err)
}
server := &http.Server{Handler: handler}
go func() { _ = server.Serve(listener) }()
cleanup = func() {
_ = server.Close()
_ = os.Remove(socketPath)
}
return socketPath, cleanup
}
// jsonHandler returns an http.HandlerFunc that responds with the given JSON.
func jsonHandler(statusCode int, body any) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(statusCode)
data, _ := json.Marshal(body)
_, _ = w.Write(data)
}
}
func TestStatus(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/v1/status", jsonHandler(200, apiResponse[StatusResponse]{
OK: true,
Data: StatusResponse{
Version: "1.0.0",
Locked: false,
Uptime: 120,
PID: 12345,
},
}))
sock, cleanup := startMockAgent(t, mux)
defer cleanup()
client := New(sock)
status, err := client.Status(context.Background())
if err != nil {
t.Fatalf("Status() error: %v", err)
}
if status.Version != "1.0.0" {
t.Errorf("Version = %q, want %q", status.Version, "1.0.0")
}
if status.Locked {
t.Error("Locked = true, want false")
}
if status.Uptime != 120 {
t.Errorf("Uptime = %d, want 120", status.Uptime)
}
}
func TestIsRunning_true(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/v1/status", jsonHandler(200, apiResponse[StatusResponse]{
OK: true,
Data: StatusResponse{Version: "1.0.0"},
}))
sock, cleanup := startMockAgent(t, mux)
defer cleanup()
client := New(sock)
if !client.IsRunning(context.Background()) {
t.Error("IsRunning() = false, want true")
}
}
func TestIsRunning_false(t *testing.T) {
client := New("/tmp/nonexistent-socket-test.sock")
if client.IsRunning(context.Background()) {
t.Error("IsRunning() = true, want false")
}
}
func TestGetSSHKey(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/v1/vault/ssh/myhost/root", jsonHandler(200, apiResponse[VaultSSHData]{
OK: true,
Data: VaultSSHData{
PrivateKey: "-----BEGIN OPENSSH PRIVATE KEY-----\nfake\n-----END OPENSSH PRIVATE KEY-----",
PublicKey: "ssh-ed25519 AAAA... myhost/root",
},
}))
sock, cleanup := startMockAgent(t, mux)
defer cleanup()
client := New(sock)
data, err := client.GetSSHKey(context.Background(), "myhost", "root", "both")
if err != nil {
t.Fatalf("GetSSHKey() error: %v", err)
}
if data.PrivateKey == "" {
t.Error("PrivateKey is empty")
}
if data.PublicKey == "" {
t.Error("PublicKey is empty")
}
}
func TestGetSSHKey_locked(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/v1/vault/ssh/myhost/root", jsonHandler(423, apiResponse[any]{
OK: false,
Error: "Agent is locked",
Code: "AGENT_LOCKED",
}))
sock, cleanup := startMockAgent(t, mux)
defer cleanup()
client := New(sock)
_, err := client.GetSSHKey(context.Background(), "myhost", "root", "priv")
if err == nil {
t.Fatal("GetSSHKey() expected error, got nil")
}
if !IsLocked(err) {
t.Errorf("IsLocked() = false for error: %v", err)
}
}
func TestGetSSHKey_notFound(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/v1/vault/ssh/unknown/user", jsonHandler(404, apiResponse[any]{
OK: false,
Error: "No SSH key found for unknown/user",
Code: "NOT_FOUND",
}))
sock, cleanup := startMockAgent(t, mux)
defer cleanup()
client := New(sock)
_, err := client.GetSSHKey(context.Background(), "unknown", "user", "priv")
if err == nil {
t.Fatal("GetSSHKey() expected error, got nil")
}
if !IsNotFound(err) {
t.Errorf("IsNotFound() = false for error: %v", err)
}
}
func TestGetPassword(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/v1/vault/password/example.com/admin", jsonHandler(200, apiResponse[VaultPasswordData]{
OK: true,
Data: VaultPasswordData{Password: "secret123"},
}))
sock, cleanup := startMockAgent(t, mux)
defer cleanup()
client := New(sock)
data, err := client.GetPassword(context.Background(), "example.com", "admin")
if err != nil {
t.Fatalf("GetPassword() error: %v", err)
}
if data.Password != "secret123" {
t.Errorf("Password = %q, want %q", data.Password, "secret123")
}
}
func TestCreateSSHEntry(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/v1/vault/ssh", jsonHandler(201, apiResponse[VaultSSHData]{
OK: true,
Data: VaultSSHData{PublicKey: "ssh-ed25519 AAAA... new/entry"},
}))
sock, cleanup := startMockAgent(t, mux)
defer cleanup()
client := New(sock)
data, err := client.CreateSSHEntry(context.Background(), "new", "entry")
if err != nil {
t.Fatalf("CreateSSHEntry() error: %v", err)
}
if data.PublicKey == "" {
t.Error("PublicKey is empty")
}
}
func TestGetAddress(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/v1/wallet/address", jsonHandler(200, apiResponse[WalletAddressData]{
OK: true,
Data: WalletAddressData{Address: "0x1234abcd", Chain: "evm"},
}))
sock, cleanup := startMockAgent(t, mux)
defer cleanup()
client := New(sock)
data, err := client.GetAddress(context.Background(), "evm")
if err != nil {
t.Fatalf("GetAddress() error: %v", err)
}
if data.Address != "0x1234abcd" {
t.Errorf("Address = %q, want %q", data.Address, "0x1234abcd")
}
}
func TestAgentNotRunning(t *testing.T) {
client := New("/tmp/nonexistent-socket-for-testing.sock")
_, err := client.Status(context.Background())
if err == nil {
t.Fatal("expected error, got nil")
}
if !IsNotRunning(err) {
t.Errorf("IsNotRunning() = false for error: %v", err)
}
}
func TestUnlockAndLock(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/v1/unlock", jsonHandler(200, apiResponse[any]{OK: true}))
mux.HandleFunc("/v1/lock", jsonHandler(200, apiResponse[any]{OK: true}))
sock, cleanup := startMockAgent(t, mux)
defer cleanup()
client := New(sock)
if err := client.Unlock(context.Background(), "password", 30); err != nil {
t.Fatalf("Unlock() error: %v", err)
}
if err := client.Lock(context.Background()); err != nil {
t.Fatalf("Lock() error: %v", err)
}
}

57
pkg/rwagent/errors.go Normal file
View File

@ -0,0 +1,57 @@
package rwagent
import (
"errors"
"fmt"
)
// AgentError represents an error returned by the rootwallet agent API.
type AgentError struct {
Code string // e.g., "AGENT_LOCKED", "NOT_FOUND"
Message string
StatusCode int
}
func (e *AgentError) Error() string {
return fmt.Sprintf("rootwallet agent: %s (%s)", e.Message, e.Code)
}
// IsLocked returns true if the error indicates the agent is locked.
func IsLocked(err error) bool {
var ae *AgentError
if errors.As(err, &ae) {
return ae.Code == "AGENT_LOCKED"
}
return false
}
// IsNotRunning returns true if the error indicates the agent is not reachable.
func IsNotRunning(err error) bool {
var ae *AgentError
if errors.As(err, &ae) {
return ae.Code == "AGENT_NOT_RUNNING"
}
// Also check for connection errors
return errors.Is(err, ErrAgentNotRunning)
}
// IsNotFound returns true if the vault entry was not found.
func IsNotFound(err error) bool {
var ae *AgentError
if errors.As(err, &ae) {
return ae.Code == "NOT_FOUND"
}
return false
}
// IsApprovalDenied returns true if the user denied the app's access request.
func IsApprovalDenied(err error) bool {
var ae *AgentError
if errors.As(err, &ae) {
return ae.Code == "APPROVAL_DENIED" || ae.Code == "PERMISSION_DENIED"
}
return false
}
// ErrAgentNotRunning is returned when the agent socket is not reachable.
var ErrAgentNotRunning = fmt.Errorf("rootwallet agent is not running — start with: rw agent start && rw agent unlock")

56
pkg/rwagent/types.go Normal file
View File

@ -0,0 +1,56 @@
// Package rwagent provides a Go client for the RootWallet agent daemon.
//
// The agent is a persistent daemon that holds vault keys in memory and serves
// operations to authorized apps over a Unix socket HTTP API. This SDK replaces
// all subprocess `rw` calls with direct HTTP communication.
package rwagent
// StatusResponse from GET /v1/status.
type StatusResponse struct {
Version string `json:"version"`
Locked bool `json:"locked"`
Uptime int `json:"uptime"`
PID int `json:"pid"`
ConnectedApps int `json:"connectedApps"`
}
// VaultSSHData from GET /v1/vault/ssh/:host/:user.
type VaultSSHData struct {
PrivateKey string `json:"privateKey,omitempty"`
PublicKey string `json:"publicKey,omitempty"`
}
// VaultPasswordData from GET /v1/vault/password/:domain/:user.
type VaultPasswordData struct {
Password string `json:"password"`
}
// WalletAddressData from GET /v1/wallet/address.
type WalletAddressData struct {
Address string `json:"address"`
Chain string `json:"chain"`
}
// AppPermission represents an approved app in the permission database.
type AppPermission struct {
BinaryHash string `json:"binaryHash"`
BinaryPath string `json:"binaryPath"`
Name string `json:"name"`
FirstSeen string `json:"firstSeen"`
LastUsed string `json:"lastUsed"`
Capabilities []PermittedCapability `json:"capabilities"`
}
// PermittedCapability is a specific capability granted to an app.
type PermittedCapability struct {
Capability string `json:"capability"`
GrantedAt string `json:"grantedAt"`
}
// apiResponse is the generic API response envelope.
type apiResponse[T any] struct {
OK bool `json:"ok"`
Data T `json:"data,omitempty"`
Error string `json:"error,omitempty"`
Code string `json:"code,omitempty"`
}