mirror of
https://github.com/DeBrosOfficial/orama.git
synced 2026-03-27 12:24:12 +00:00
refactor(remotessh): use rwagent directly instead of rw CLI subprocesses
- replace `rw vault ssh` calls with `rwagent.Client` in PrepareNodeKeys, LoadAgentKeys, EnsureVaultEntry, ResolveVaultPublicKey - add vaultClient interface, newClient func, and wrapAgentError for testability and improved error messages - prefer pre-built systemd dir in installNamespaceTemplates
This commit is contained in:
parent
fa826f0d00
commit
0764ac287e
@ -597,7 +597,11 @@ func promptForBaseDomain() string {
|
|||||||
|
|
||||||
// installNamespaceTemplates installs systemd template unit files for namespace services
|
// installNamespaceTemplates installs systemd template unit files for namespace services
|
||||||
func (o *Orchestrator) installNamespaceTemplates() error {
|
func (o *Orchestrator) installNamespaceTemplates() error {
|
||||||
sourceDir := filepath.Join(o.oramaHome, "src", "systemd")
|
// Check pre-built archive path first, fall back to source path
|
||||||
|
sourceDir := production.OramaSystemdDir
|
||||||
|
if _, err := os.Stat(sourceDir); os.IsNotExist(err) {
|
||||||
|
sourceDir = filepath.Join(o.oramaHome, "src", "systemd")
|
||||||
|
}
|
||||||
systemdDir := "/etc/systemd/system"
|
systemdDir := "/etc/systemd/system"
|
||||||
|
|
||||||
templates := []string{
|
templates := []string{
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
package remotessh
|
package remotessh
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
@ -8,11 +9,40 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/DeBrosOfficial/network/pkg/inspector"
|
"github.com/DeBrosOfficial/network/pkg/inspector"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/rwagent"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// vaultClient is the interface used by wallet functions to talk to the agent.
|
||||||
|
// Defaults to the real rwagent.Client; tests replace it with a mock.
|
||||||
|
type vaultClient interface {
|
||||||
|
GetSSHKey(ctx context.Context, host, username, format string) (*rwagent.VaultSSHData, error)
|
||||||
|
CreateSSHEntry(ctx context.Context, host, username string) (*rwagent.VaultSSHData, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// newClient creates the default vaultClient. Package-level var for test injection.
|
||||||
|
var newClient func() vaultClient = func() vaultClient {
|
||||||
|
return rwagent.New(os.Getenv("RW_AGENT_SOCK"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// wrapAgentError wraps rwagent errors with user-friendly messages.
|
||||||
|
// When the agent is locked, it also triggers the RootWallet desktop app
|
||||||
|
// to show the unlock dialog via deep link (best-effort, fire-and-forget).
|
||||||
|
func wrapAgentError(err error, action string) error {
|
||||||
|
if rwagent.IsNotRunning(err) {
|
||||||
|
return fmt.Errorf("%s: rootwallet agent is not running — start with: rw agent start && rw agent unlock", action)
|
||||||
|
}
|
||||||
|
if rwagent.IsLocked(err) {
|
||||||
|
return fmt.Errorf("%s: rootwallet agent is locked — unlock timed out after waiting. Unlock via the RootWallet app or run: rw agent unlock", action)
|
||||||
|
}
|
||||||
|
if rwagent.IsApprovalDenied(err) {
|
||||||
|
return fmt.Errorf("%s: rootwallet access denied — approve this app in the RootWallet desktop app", action)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("%s: %w", action, err)
|
||||||
|
}
|
||||||
|
|
||||||
// PrepareNodeKeys resolves wallet-derived SSH keys for all nodes.
|
// PrepareNodeKeys resolves wallet-derived SSH keys for all nodes.
|
||||||
// Calls `rw vault ssh get <host>/<user> --priv` for each unique host/user,
|
// Retrieves private keys from the rootwallet agent daemon, writes PEMs to
|
||||||
// writes PEMs to temp files, and sets node.SSHKey for each node.
|
// temp files, and sets node.SSHKey for each node.
|
||||||
//
|
//
|
||||||
// The nodes slice is modified in place — each node.SSHKey is set to
|
// The nodes slice is modified in place — each node.SSHKey is set to
|
||||||
// the path of the temporary key file.
|
// the path of the temporary key file.
|
||||||
@ -20,10 +50,8 @@ import (
|
|||||||
// Returns a cleanup function that zero-overwrites and removes all temp files.
|
// Returns a cleanup function that zero-overwrites and removes all temp files.
|
||||||
// Caller must defer cleanup().
|
// Caller must defer cleanup().
|
||||||
func PrepareNodeKeys(nodes []inspector.Node) (cleanup func(), err error) {
|
func PrepareNodeKeys(nodes []inspector.Node) (cleanup func(), err error) {
|
||||||
rw, err := rwBinary()
|
client := newClient()
|
||||||
if err != nil {
|
ctx := context.Background()
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create temp dir for all keys
|
// Create temp dir for all keys
|
||||||
tmpDir, err := os.MkdirTemp("", "orama-ssh-")
|
tmpDir, err := os.MkdirTemp("", "orama-ssh-")
|
||||||
@ -31,12 +59,11 @@ func PrepareNodeKeys(nodes []inspector.Node) (cleanup func(), err error) {
|
|||||||
return nil, fmt.Errorf("create temp dir: %w", err)
|
return nil, fmt.Errorf("create temp dir: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Track resolved keys by host/user to avoid duplicate rw calls
|
// Track resolved keys by host/user to avoid duplicate agent calls
|
||||||
keyPaths := make(map[string]string) // "host/user" → temp file path
|
keyPaths := make(map[string]string) // "host/user" → temp file path
|
||||||
var allKeyPaths []string
|
var allKeyPaths []string
|
||||||
|
|
||||||
for i := range nodes {
|
for i := range nodes {
|
||||||
// Use VaultTarget if set, otherwise default to Host/User
|
|
||||||
var key string
|
var key string
|
||||||
if nodes[i].VaultTarget != "" {
|
if nodes[i].VaultTarget != "" {
|
||||||
key = nodes[i].VaultTarget
|
key = nodes[i].VaultTarget
|
||||||
@ -48,18 +75,21 @@ func PrepareNodeKeys(nodes []inspector.Node) (cleanup func(), err error) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Call rw to get the private key PEM
|
|
||||||
host, user := parseVaultTarget(key)
|
host, user := parseVaultTarget(key)
|
||||||
pem, err := resolveWalletKey(rw, host, user)
|
data, err := client.GetSSHKey(ctx, host, user, "priv")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Cleanup any keys already written before returning error
|
|
||||||
cleanupKeys(tmpDir, allKeyPaths)
|
cleanupKeys(tmpDir, allKeyPaths)
|
||||||
return nil, fmt.Errorf("resolve key for %s: %w", nodes[i].Name(), err)
|
return nil, wrapAgentError(err, fmt.Sprintf("resolve key for %s", nodes[i].Name()))
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(data.PrivateKey, "BEGIN OPENSSH PRIVATE KEY") {
|
||||||
|
cleanupKeys(tmpDir, allKeyPaths)
|
||||||
|
return nil, fmt.Errorf("agent returned invalid key for %s", nodes[i].Name())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write PEM to temp file with restrictive perms
|
// Write PEM to temp file with restrictive perms
|
||||||
keyFile := filepath.Join(tmpDir, fmt.Sprintf("id_%d", i))
|
keyFile := filepath.Join(tmpDir, fmt.Sprintf("id_%d", i))
|
||||||
if err := os.WriteFile(keyFile, []byte(pem), 0600); err != nil {
|
if err := os.WriteFile(keyFile, []byte(data.PrivateKey), 0600); err != nil {
|
||||||
cleanupKeys(tmpDir, allKeyPaths)
|
cleanupKeys(tmpDir, allKeyPaths)
|
||||||
return nil, fmt.Errorf("write key for %s: %w", nodes[i].Name(), err)
|
return nil, fmt.Errorf("write key for %s: %w", nodes[i].Name(), err)
|
||||||
}
|
}
|
||||||
@ -77,12 +107,10 @@ func PrepareNodeKeys(nodes []inspector.Node) (cleanup func(), err error) {
|
|||||||
|
|
||||||
// LoadAgentKeys loads SSH keys for the given nodes into the system ssh-agent.
|
// LoadAgentKeys loads SSH keys for the given nodes into the system ssh-agent.
|
||||||
// Used by push fanout to enable agent forwarding.
|
// Used by push fanout to enable agent forwarding.
|
||||||
// Calls `rw vault ssh agent-load <host1/user1> <host2/user2> ...`
|
// Retrieves private keys from the rootwallet agent and pipes them to ssh-add.
|
||||||
func LoadAgentKeys(nodes []inspector.Node) error {
|
func LoadAgentKeys(nodes []inspector.Node) error {
|
||||||
rw, err := rwBinary()
|
client := newClient()
|
||||||
if err != nil {
|
ctx := context.Background()
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Deduplicate host/user pairs
|
// Deduplicate host/user pairs
|
||||||
seen := make(map[string]bool)
|
seen := make(map[string]bool)
|
||||||
@ -105,76 +133,65 @@ func LoadAgentKeys(nodes []inspector.Node) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
args := append([]string{"vault", "ssh", "agent-load"}, targets...)
|
for _, target := range targets {
|
||||||
cmd := exec.Command(rw, args...)
|
host, user := parseVaultTarget(target)
|
||||||
cmd.Stderr = os.Stderr
|
data, err := client.GetSSHKey(ctx, host, user, "priv")
|
||||||
cmd.Stdout = os.Stderr // info messages go to stderr
|
if err != nil {
|
||||||
|
return wrapAgentError(err, fmt.Sprintf("get key for %s", target))
|
||||||
|
}
|
||||||
|
|
||||||
if err := cmd.Run(); err != nil {
|
// Pipe private key to ssh-add via stdin
|
||||||
return fmt.Errorf("rw vault ssh agent-load failed: %w", err)
|
cmd := exec.Command("ssh-add", "-")
|
||||||
|
cmd.Stdin = strings.NewReader(data.PrivateKey)
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
return fmt.Errorf("ssh-add failed for %s: %w", target, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// EnsureVaultEntry creates a wallet SSH entry if it doesn't already exist.
|
// EnsureVaultEntry creates a wallet SSH entry if it doesn't already exist.
|
||||||
// Checks existence via `rw vault ssh get <target> --pub`, and if missing,
|
// Checks the rootwallet agent for an existing entry, creates one if not found.
|
||||||
// runs `rw vault ssh add <target>` to create it.
|
|
||||||
func EnsureVaultEntry(vaultTarget string) error {
|
func EnsureVaultEntry(vaultTarget string) error {
|
||||||
rw, err := rwBinary()
|
client := newClient()
|
||||||
if err != nil {
|
ctx := context.Background()
|
||||||
return err
|
|
||||||
|
host, user := parseVaultTarget(vaultTarget)
|
||||||
|
|
||||||
|
// Check if entry already exists
|
||||||
|
_, err := client.GetSSHKey(ctx, host, user, "pub")
|
||||||
|
if err == nil {
|
||||||
|
return nil // entry exists
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if entry exists by trying to get the public key
|
// If not found, create it
|
||||||
cmd := exec.Command(rw, "vault", "ssh", "get", vaultTarget, "--pub")
|
if rwagent.IsNotFound(err) {
|
||||||
if err := cmd.Run(); err == nil {
|
_, createErr := client.CreateSSHEntry(ctx, host, user)
|
||||||
return nil // entry already exists
|
if createErr != nil {
|
||||||
}
|
return wrapAgentError(createErr, fmt.Sprintf("create vault entry %s", vaultTarget))
|
||||||
|
|
||||||
// Entry doesn't exist — try to create it
|
|
||||||
addCmd := exec.Command(rw, "vault", "ssh", "add", vaultTarget)
|
|
||||||
addCmd.Stdin = os.Stdin
|
|
||||||
addCmd.Stdout = os.Stderr
|
|
||||||
addCmd.Stderr = os.Stderr
|
|
||||||
if err := addCmd.Run(); err != nil {
|
|
||||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
|
||||||
stderr := strings.TrimSpace(string(exitErr.Stderr))
|
|
||||||
if strings.Contains(stderr, "not unlocked") || strings.Contains(stderr, "session") {
|
|
||||||
return fmt.Errorf("wallet is locked — run: rw unlock")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return fmt.Errorf("rw vault ssh add %s failed: %w", vaultTarget, err)
|
return nil
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
|
return wrapAgentError(err, fmt.Sprintf("check vault entry %s", vaultTarget))
|
||||||
}
|
}
|
||||||
|
|
||||||
// ResolveVaultPublicKey returns the OpenSSH public key string for a vault entry.
|
// ResolveVaultPublicKey returns the OpenSSH public key string for a vault entry.
|
||||||
// Calls `rw vault ssh get <target> --pub`.
|
|
||||||
func ResolveVaultPublicKey(vaultTarget string) (string, error) {
|
func ResolveVaultPublicKey(vaultTarget string) (string, error) {
|
||||||
rw, err := rwBinary()
|
client := newClient()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
host, user := parseVaultTarget(vaultTarget)
|
||||||
|
data, err := client.GetSSHKey(ctx, host, user, "pub")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", wrapAgentError(err, fmt.Sprintf("get public key for %s", vaultTarget))
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd := exec.Command(rw, "vault", "ssh", "get", vaultTarget, "--pub")
|
pubKey := strings.TrimSpace(data.PublicKey)
|
||||||
out, err := cmd.Output()
|
|
||||||
if err != nil {
|
|
||||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
|
||||||
stderr := strings.TrimSpace(string(exitErr.Stderr))
|
|
||||||
if strings.Contains(stderr, "No SSH entry") {
|
|
||||||
return "", fmt.Errorf("no vault SSH entry for %s — run: rw vault ssh add %s", vaultTarget, vaultTarget)
|
|
||||||
}
|
|
||||||
if strings.Contains(stderr, "not unlocked") || strings.Contains(stderr, "session") {
|
|
||||||
return "", fmt.Errorf("wallet is locked — run: rw unlock")
|
|
||||||
}
|
|
||||||
return "", fmt.Errorf("%s", stderr)
|
|
||||||
}
|
|
||||||
return "", fmt.Errorf("rw command failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
pubKey := strings.TrimSpace(string(out))
|
|
||||||
if !strings.HasPrefix(pubKey, "ssh-") {
|
if !strings.HasPrefix(pubKey, "ssh-") {
|
||||||
return "", fmt.Errorf("rw returned invalid public key for %s", vaultTarget)
|
return "", fmt.Errorf("agent returned invalid public key for %s", vaultTarget)
|
||||||
}
|
}
|
||||||
return pubKey, nil
|
return pubKey, nil
|
||||||
}
|
}
|
||||||
@ -188,49 +205,6 @@ func parseVaultTarget(target string) (host, user string) {
|
|||||||
return target[:idx], target[idx+1:]
|
return target[:idx], target[idx+1:]
|
||||||
}
|
}
|
||||||
|
|
||||||
// resolveWalletKey calls `rw vault ssh get <host>/<user> --priv`
|
|
||||||
// and returns the PEM string. Requires an active rw session.
|
|
||||||
func resolveWalletKey(rw string, host, user string) (string, error) {
|
|
||||||
target := host + "/" + user
|
|
||||||
cmd := exec.Command(rw, "vault", "ssh", "get", target, "--priv")
|
|
||||||
out, err := cmd.Output()
|
|
||||||
if err != nil {
|
|
||||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
|
||||||
stderr := strings.TrimSpace(string(exitErr.Stderr))
|
|
||||||
if strings.Contains(stderr, "No SSH entry") {
|
|
||||||
return "", fmt.Errorf("no vault SSH entry for %s — run: rw vault ssh add %s", target, target)
|
|
||||||
}
|
|
||||||
if strings.Contains(stderr, "not unlocked") || strings.Contains(stderr, "session") {
|
|
||||||
return "", fmt.Errorf("wallet is locked — run: rw unlock")
|
|
||||||
}
|
|
||||||
return "", fmt.Errorf("%s", stderr)
|
|
||||||
}
|
|
||||||
return "", fmt.Errorf("rw command failed: %w", err)
|
|
||||||
}
|
|
||||||
pem := string(out)
|
|
||||||
if !strings.Contains(pem, "BEGIN OPENSSH PRIVATE KEY") {
|
|
||||||
return "", fmt.Errorf("rw returned invalid key for %s", target)
|
|
||||||
}
|
|
||||||
return pem, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// rwBinary returns the path to the `rw` binary.
|
|
||||||
// Checks RW_PATH env var first, then PATH.
|
|
||||||
func rwBinary() (string, error) {
|
|
||||||
if p := os.Getenv("RW_PATH"); p != "" {
|
|
||||||
if _, err := os.Stat(p); err == nil {
|
|
||||||
return p, nil
|
|
||||||
}
|
|
||||||
return "", fmt.Errorf("RW_PATH=%q not found", p)
|
|
||||||
}
|
|
||||||
|
|
||||||
p, err := exec.LookPath("rw")
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("rw not found in PATH — install rootwallet CLI: https://github.com/DeBrosOfficial/rootwallet")
|
|
||||||
}
|
|
||||||
return p, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// cleanupKeys zero-overwrites and removes all key files, then removes the temp dir.
|
// cleanupKeys zero-overwrites and removes all key files, then removes the temp dir.
|
||||||
func cleanupKeys(tmpDir string, keyPaths []string) {
|
func cleanupKeys(tmpDir string, keyPaths []string) {
|
||||||
zeros := make([]byte, 512)
|
zeros := make([]byte, 512)
|
||||||
|
|||||||
@ -1,6 +1,39 @@
|
|||||||
package remotessh
|
package remotessh
|
||||||
|
|
||||||
import "testing"
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/inspector"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/rwagent"
|
||||||
|
)
|
||||||
|
|
||||||
|
const testPrivateKey = "-----BEGIN OPENSSH PRIVATE KEY-----\nfake-key-data\n-----END OPENSSH PRIVATE KEY-----"
|
||||||
|
|
||||||
|
// mockClient implements vaultClient for testing.
|
||||||
|
type mockClient struct {
|
||||||
|
getSSHKey func(ctx context.Context, host, username, format string) (*rwagent.VaultSSHData, error)
|
||||||
|
createSSHEntry func(ctx context.Context, host, username string) (*rwagent.VaultSSHData, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockClient) GetSSHKey(ctx context.Context, host, username, format string) (*rwagent.VaultSSHData, error) {
|
||||||
|
return m.getSSHKey(ctx, host, username, format)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockClient) CreateSSHEntry(ctx context.Context, host, username string) (*rwagent.VaultSSHData, error) {
|
||||||
|
return m.createSSHEntry(ctx, host, username)
|
||||||
|
}
|
||||||
|
|
||||||
|
// withMockClient replaces newClient for the duration of a test.
|
||||||
|
func withMockClient(t *testing.T, mock *mockClient) {
|
||||||
|
t.Helper()
|
||||||
|
orig := newClient
|
||||||
|
newClient = func() vaultClient { return mock }
|
||||||
|
t.Cleanup(func() { newClient = orig })
|
||||||
|
}
|
||||||
|
|
||||||
func TestParseVaultTarget(t *testing.T) {
|
func TestParseVaultTarget(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
@ -13,6 +46,7 @@ func TestParseVaultTarget(t *testing.T) {
|
|||||||
{"my-host/my-user", "my-host", "my-user"},
|
{"my-host/my-user", "my-host", "my-user"},
|
||||||
{"noslash", "noslash", ""},
|
{"noslash", "noslash", ""},
|
||||||
{"a/b/c", "a", "b/c"},
|
{"a/b/c", "a", "b/c"},
|
||||||
|
{"", "", ""},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@ -27,3 +61,316 @@ func TestParseVaultTarget(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWrapAgentError_notRunning(t *testing.T) {
|
||||||
|
err := wrapAgentError(rwagent.ErrAgentNotRunning, "test action")
|
||||||
|
if !strings.Contains(err.Error(), "not running") {
|
||||||
|
t.Errorf("expected 'not running' message, got: %s", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "rw agent start") {
|
||||||
|
t.Errorf("expected actionable hint, got: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWrapAgentError_locked(t *testing.T) {
|
||||||
|
agentErr := &rwagent.AgentError{Code: "AGENT_LOCKED", Message: "agent is locked"}
|
||||||
|
err := wrapAgentError(agentErr, "test action")
|
||||||
|
if !strings.Contains(err.Error(), "locked") {
|
||||||
|
t.Errorf("expected 'locked' message, got: %s", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "rw agent unlock") {
|
||||||
|
t.Errorf("expected actionable hint, got: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWrapAgentError_generic(t *testing.T) {
|
||||||
|
err := wrapAgentError(errors.New("some error"), "test action")
|
||||||
|
if !strings.Contains(err.Error(), "test action") {
|
||||||
|
t.Errorf("expected action context, got: %s", err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "some error") {
|
||||||
|
t.Errorf("expected wrapped error, got: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrepareNodeKeys_success(t *testing.T) {
|
||||||
|
mock := &mockClient{
|
||||||
|
getSSHKey: func(_ context.Context, host, username, format string) (*rwagent.VaultSSHData, error) {
|
||||||
|
return &rwagent.VaultSSHData{PrivateKey: testPrivateKey}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
withMockClient(t, mock)
|
||||||
|
|
||||||
|
nodes := []inspector.Node{
|
||||||
|
{Host: "10.0.0.1", User: "root"},
|
||||||
|
{Host: "10.0.0.2", User: "root"},
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanup, err := PrepareNodeKeys(nodes)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("PrepareNodeKeys() error = %v", err)
|
||||||
|
}
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
for i, n := range nodes {
|
||||||
|
if n.SSHKey == "" {
|
||||||
|
t.Errorf("node[%d].SSHKey is empty", i)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data, err := os.ReadFile(n.SSHKey)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("node[%d] key file unreadable: %v", i, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !strings.Contains(string(data), "BEGIN OPENSSH PRIVATE KEY") {
|
||||||
|
t.Errorf("node[%d] key file has wrong content", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrepareNodeKeys_deduplication(t *testing.T) {
|
||||||
|
callCount := 0
|
||||||
|
mock := &mockClient{
|
||||||
|
getSSHKey: func(_ context.Context, host, username, format string) (*rwagent.VaultSSHData, error) {
|
||||||
|
callCount++
|
||||||
|
return &rwagent.VaultSSHData{PrivateKey: testPrivateKey}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
withMockClient(t, mock)
|
||||||
|
|
||||||
|
nodes := []inspector.Node{
|
||||||
|
{Host: "10.0.0.1", User: "root"},
|
||||||
|
{Host: "10.0.0.1", User: "root"}, // same host/user
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanup, err := PrepareNodeKeys(nodes)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("PrepareNodeKeys() error = %v", err)
|
||||||
|
}
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
if callCount != 1 {
|
||||||
|
t.Errorf("expected 1 agent call (dedup), got %d", callCount)
|
||||||
|
}
|
||||||
|
if nodes[0].SSHKey != nodes[1].SSHKey {
|
||||||
|
t.Error("expected same key path for deduplicated nodes")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrepareNodeKeys_vaultTarget(t *testing.T) {
|
||||||
|
var capturedHost, capturedUser string
|
||||||
|
mock := &mockClient{
|
||||||
|
getSSHKey: func(_ context.Context, host, username, format string) (*rwagent.VaultSSHData, error) {
|
||||||
|
capturedHost = host
|
||||||
|
capturedUser = username
|
||||||
|
return &rwagent.VaultSSHData{PrivateKey: testPrivateKey}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
withMockClient(t, mock)
|
||||||
|
|
||||||
|
nodes := []inspector.Node{
|
||||||
|
{Host: "10.0.0.1", User: "root", VaultTarget: "sandbox/admin"},
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanup, err := PrepareNodeKeys(nodes)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("PrepareNodeKeys() error = %v", err)
|
||||||
|
}
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
if capturedHost != "sandbox" || capturedUser != "admin" {
|
||||||
|
t.Errorf("expected host=sandbox user=admin, got host=%s user=%s", capturedHost, capturedUser)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrepareNodeKeys_agentNotRunning(t *testing.T) {
|
||||||
|
mock := &mockClient{
|
||||||
|
getSSHKey: func(_ context.Context, _, _, _ string) (*rwagent.VaultSSHData, error) {
|
||||||
|
return nil, rwagent.ErrAgentNotRunning
|
||||||
|
},
|
||||||
|
}
|
||||||
|
withMockClient(t, mock)
|
||||||
|
|
||||||
|
nodes := []inspector.Node{{Host: "10.0.0.1", User: "root"}}
|
||||||
|
_, err := PrepareNodeKeys(nodes)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "not running") {
|
||||||
|
t.Errorf("expected 'not running' error, got: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrepareNodeKeys_invalidKey(t *testing.T) {
|
||||||
|
mock := &mockClient{
|
||||||
|
getSSHKey: func(_ context.Context, _, _, _ string) (*rwagent.VaultSSHData, error) {
|
||||||
|
return &rwagent.VaultSSHData{PrivateKey: "garbage"}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
withMockClient(t, mock)
|
||||||
|
|
||||||
|
nodes := []inspector.Node{{Host: "10.0.0.1", User: "root"}}
|
||||||
|
_, err := PrepareNodeKeys(nodes)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for invalid key")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "invalid key") {
|
||||||
|
t.Errorf("expected 'invalid key' error, got: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrepareNodeKeys_cleanupOnError(t *testing.T) {
|
||||||
|
callNum := 0
|
||||||
|
mock := &mockClient{
|
||||||
|
getSSHKey: func(_ context.Context, _, _, _ string) (*rwagent.VaultSSHData, error) {
|
||||||
|
callNum++
|
||||||
|
if callNum == 2 {
|
||||||
|
return nil, &rwagent.AgentError{Code: "AGENT_LOCKED", Message: "locked"}
|
||||||
|
}
|
||||||
|
return &rwagent.VaultSSHData{PrivateKey: testPrivateKey}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
withMockClient(t, mock)
|
||||||
|
|
||||||
|
nodes := []inspector.Node{
|
||||||
|
{Host: "10.0.0.1", User: "root"},
|
||||||
|
{Host: "10.0.0.2", User: "root"},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := PrepareNodeKeys(nodes)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error")
|
||||||
|
}
|
||||||
|
|
||||||
|
// First node's temp file should have been cleaned up
|
||||||
|
if nodes[0].SSHKey != "" {
|
||||||
|
if _, statErr := os.Stat(nodes[0].SSHKey); statErr == nil {
|
||||||
|
t.Error("expected temp key file to be cleaned up on error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrepareNodeKeys_emptyNodes(t *testing.T) {
|
||||||
|
mock := &mockClient{}
|
||||||
|
withMockClient(t, mock)
|
||||||
|
|
||||||
|
cleanup, err := PrepareNodeKeys(nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected no error for empty nodes, got: %v", err)
|
||||||
|
}
|
||||||
|
cleanup() // should not panic
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnsureVaultEntry_exists(t *testing.T) {
|
||||||
|
mock := &mockClient{
|
||||||
|
getSSHKey: func(_ context.Context, _, _, _ string) (*rwagent.VaultSSHData, error) {
|
||||||
|
return &rwagent.VaultSSHData{PublicKey: "ssh-ed25519 AAAA..."}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
withMockClient(t, mock)
|
||||||
|
|
||||||
|
if err := EnsureVaultEntry("sandbox/root"); err != nil {
|
||||||
|
t.Fatalf("EnsureVaultEntry() error = %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnsureVaultEntry_creates(t *testing.T) {
|
||||||
|
created := false
|
||||||
|
mock := &mockClient{
|
||||||
|
getSSHKey: func(_ context.Context, _, _, _ string) (*rwagent.VaultSSHData, error) {
|
||||||
|
return nil, &rwagent.AgentError{Code: "NOT_FOUND", Message: "not found"}
|
||||||
|
},
|
||||||
|
createSSHEntry: func(_ context.Context, host, username string) (*rwagent.VaultSSHData, error) {
|
||||||
|
created = true
|
||||||
|
if host != "sandbox" || username != "root" {
|
||||||
|
t.Errorf("unexpected create args: %s/%s", host, username)
|
||||||
|
}
|
||||||
|
return &rwagent.VaultSSHData{PublicKey: "ssh-ed25519 AAAA..."}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
withMockClient(t, mock)
|
||||||
|
|
||||||
|
if err := EnsureVaultEntry("sandbox/root"); err != nil {
|
||||||
|
t.Fatalf("EnsureVaultEntry() error = %v", err)
|
||||||
|
}
|
||||||
|
if !created {
|
||||||
|
t.Error("expected CreateSSHEntry to be called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnsureVaultEntry_locked(t *testing.T) {
|
||||||
|
mock := &mockClient{
|
||||||
|
getSSHKey: func(_ context.Context, _, _, _ string) (*rwagent.VaultSSHData, error) {
|
||||||
|
return nil, &rwagent.AgentError{Code: "AGENT_LOCKED", Message: "locked"}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
withMockClient(t, mock)
|
||||||
|
|
||||||
|
err := EnsureVaultEntry("sandbox/root")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "locked") {
|
||||||
|
t.Errorf("expected locked error, got: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveVaultPublicKey_success(t *testing.T) {
|
||||||
|
mock := &mockClient{
|
||||||
|
getSSHKey: func(_ context.Context, _, _, format string) (*rwagent.VaultSSHData, error) {
|
||||||
|
if format != "pub" {
|
||||||
|
t.Errorf("expected format=pub, got %s", format)
|
||||||
|
}
|
||||||
|
return &rwagent.VaultSSHData{PublicKey: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAA..."}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
withMockClient(t, mock)
|
||||||
|
|
||||||
|
key, err := ResolveVaultPublicKey("sandbox/root")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ResolveVaultPublicKey() error = %v", err)
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(key, "ssh-") {
|
||||||
|
t.Errorf("expected ssh- prefix, got: %s", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveVaultPublicKey_invalidFormat(t *testing.T) {
|
||||||
|
mock := &mockClient{
|
||||||
|
getSSHKey: func(_ context.Context, _, _, _ string) (*rwagent.VaultSSHData, error) {
|
||||||
|
return &rwagent.VaultSSHData{PublicKey: "not-a-valid-key"}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
withMockClient(t, mock)
|
||||||
|
|
||||||
|
_, err := ResolveVaultPublicKey("sandbox/root")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for invalid public key")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "invalid public key") {
|
||||||
|
t.Errorf("expected 'invalid public key' error, got: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveVaultPublicKey_notFound(t *testing.T) {
|
||||||
|
mock := &mockClient{
|
||||||
|
getSSHKey: func(_ context.Context, _, _, _ string) (*rwagent.VaultSSHData, error) {
|
||||||
|
return nil, &rwagent.AgentError{Code: "NOT_FOUND", Message: "not found"}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
withMockClient(t, mock)
|
||||||
|
|
||||||
|
_, err := ResolveVaultPublicKey("sandbox/root")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadAgentKeys_emptyNodes(t *testing.T) {
|
||||||
|
mock := &mockClient{}
|
||||||
|
withMockClient(t, mock)
|
||||||
|
|
||||||
|
if err := LoadAgentKeys(nil); err != nil {
|
||||||
|
t.Fatalf("expected no error for empty nodes, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -75,8 +75,7 @@ func Create(name string) error {
|
|||||||
|
|
||||||
// Phase 4: Install genesis node
|
// Phase 4: Install genesis node
|
||||||
fmt.Println("\nPhase 4: Installing genesis node...")
|
fmt.Println("\nPhase 4: Installing genesis node...")
|
||||||
tokens, err := phase4InstallGenesis(cfg, state, sshKeyPath)
|
if err := phase4InstallGenesis(cfg, state, sshKeyPath); err != nil {
|
||||||
if err != nil {
|
|
||||||
state.Status = StatusError
|
state.Status = StatusError
|
||||||
SaveState(state)
|
SaveState(state)
|
||||||
return fmt.Errorf("install genesis: %w", err)
|
return fmt.Errorf("install genesis: %w", err)
|
||||||
@ -84,7 +83,7 @@ func Create(name string) error {
|
|||||||
|
|
||||||
// Phase 5: Join remaining nodes
|
// Phase 5: Join remaining nodes
|
||||||
fmt.Println("\nPhase 5: Joining remaining nodes...")
|
fmt.Println("\nPhase 5: Joining remaining nodes...")
|
||||||
if err := phase5JoinNodes(cfg, state, tokens, sshKeyPath); err != nil {
|
if err := phase5JoinNodes(cfg, state, sshKeyPath); err != nil {
|
||||||
state.Status = StatusError
|
state.Status = StatusError
|
||||||
SaveState(state)
|
SaveState(state)
|
||||||
return fmt.Errorf("join nodes: %w", err)
|
return fmt.Errorf("join nodes: %w", err)
|
||||||
@ -280,60 +279,52 @@ func phase3UploadArchive(state *SandboxState, sshKeyPath string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// phase4InstallGenesis installs the genesis node and generates invite tokens.
|
// phase4InstallGenesis installs the genesis node.
|
||||||
func phase4InstallGenesis(cfg *Config, state *SandboxState, sshKeyPath string) ([]string, error) {
|
func phase4InstallGenesis(cfg *Config, state *SandboxState, sshKeyPath string) error {
|
||||||
genesis := state.GenesisServer()
|
genesis := state.GenesisServer()
|
||||||
node := inspector.Node{User: "root", Host: genesis.IP, SSHKey: sshKeyPath}
|
node := inspector.Node{User: "root", Host: genesis.IP, SSHKey: sshKeyPath}
|
||||||
|
|
||||||
// Install genesis
|
// Install genesis
|
||||||
installCmd := fmt.Sprintf("/opt/orama/bin/orama node install --vps-ip %s --domain %s --base-domain %s --nameserver --skip-checks",
|
installCmd := fmt.Sprintf("/opt/orama/bin/orama node install --vps-ip %s --domain %s --base-domain %s --nameserver --anyone-client --skip-checks",
|
||||||
genesis.IP, cfg.Domain, cfg.Domain)
|
genesis.IP, cfg.Domain, cfg.Domain)
|
||||||
fmt.Printf(" Installing on %s (%s)...\n", genesis.Name, genesis.IP)
|
fmt.Printf(" Installing on %s (%s)...\n", genesis.Name, genesis.IP)
|
||||||
if err := remotessh.RunSSHStreaming(node, installCmd, remotessh.WithNoHostKeyCheck()); err != nil {
|
if err := remotessh.RunSSHStreaming(node, installCmd, remotessh.WithNoHostKeyCheck()); err != nil {
|
||||||
return nil, fmt.Errorf("install genesis: %w", err)
|
return fmt.Errorf("install genesis: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait for RQLite leader
|
// Wait for RQLite leader
|
||||||
fmt.Print(" Waiting for RQLite leader...")
|
fmt.Print(" Waiting for RQLite leader...")
|
||||||
if err := waitForRQLiteHealth(node, 3*time.Minute); err != nil {
|
if err := waitForRQLiteHealth(node, 3*time.Minute); err != nil {
|
||||||
return nil, fmt.Errorf("genesis health: %w", err)
|
return fmt.Errorf("genesis health: %w", err)
|
||||||
}
|
}
|
||||||
fmt.Println(" OK")
|
fmt.Println(" OK")
|
||||||
|
|
||||||
// Generate invite tokens (one per remaining node)
|
return nil
|
||||||
fmt.Print(" Generating invite tokens...")
|
|
||||||
remaining := len(state.Servers) - 1
|
|
||||||
tokens := make([]string, remaining)
|
|
||||||
|
|
||||||
for i := 0; i < remaining; i++ {
|
|
||||||
token, err := generateInviteToken(node)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("generate invite token %d: %w", i+1, err)
|
|
||||||
}
|
|
||||||
tokens[i] = token
|
|
||||||
fmt.Print(".")
|
|
||||||
}
|
|
||||||
fmt.Println(" OK")
|
|
||||||
|
|
||||||
return tokens, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// phase5JoinNodes joins the remaining 4 nodes to the cluster (serial).
|
// phase5JoinNodes joins the remaining 4 nodes to the cluster (serial).
|
||||||
func phase5JoinNodes(cfg *Config, state *SandboxState, tokens []string, sshKeyPath string) error {
|
// Generates invite tokens just-in-time to avoid expiry during long installs.
|
||||||
genesisIP := state.GenesisServer().IP
|
func phase5JoinNodes(cfg *Config, state *SandboxState, sshKeyPath string) error {
|
||||||
|
genesis := state.GenesisServer()
|
||||||
|
genesisNode := inspector.Node{User: "root", Host: genesis.IP, SSHKey: sshKeyPath}
|
||||||
|
|
||||||
for i := 1; i < len(state.Servers); i++ {
|
for i := 1; i < len(state.Servers); i++ {
|
||||||
srv := state.Servers[i]
|
srv := state.Servers[i]
|
||||||
node := inspector.Node{User: "root", Host: srv.IP, SSHKey: sshKeyPath}
|
node := inspector.Node{User: "root", Host: srv.IP, SSHKey: sshKeyPath}
|
||||||
token := tokens[i-1]
|
|
||||||
|
// Generate token just before use to avoid expiry
|
||||||
|
token, err := generateInviteToken(genesisNode)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("generate invite token for %s: %w", srv.Name, err)
|
||||||
|
}
|
||||||
|
|
||||||
var installCmd string
|
var installCmd string
|
||||||
if srv.Role == "nameserver" {
|
if srv.Role == "nameserver" {
|
||||||
installCmd = fmt.Sprintf("/opt/orama/bin/orama node install --join http://%s --token %s --vps-ip %s --domain %s --base-domain %s --nameserver --skip-checks",
|
installCmd = fmt.Sprintf("/opt/orama/bin/orama node install --join http://%s --token %s --vps-ip %s --domain %s --base-domain %s --nameserver --anyone-client --skip-checks",
|
||||||
genesisIP, token, srv.IP, cfg.Domain, cfg.Domain)
|
genesis.IP, token, srv.IP, cfg.Domain, cfg.Domain)
|
||||||
} else {
|
} else {
|
||||||
installCmd = fmt.Sprintf("/opt/orama/bin/orama node install --join http://%s --token %s --vps-ip %s --base-domain %s --skip-checks",
|
installCmd = fmt.Sprintf("/opt/orama/bin/orama node install --join http://%s --token %s --vps-ip %s --base-domain %s --anyone-client --skip-checks",
|
||||||
genesisIP, token, srv.IP, cfg.Domain)
|
genesis.IP, token, srv.IP, cfg.Domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf(" [%d/%d] Joining %s (%s, %s)...\n", i, len(state.Servers)-1, srv.Name, srv.IP, srv.Role)
|
fmt.Printf(" [%d/%d] Joining %s (%s, %s)...\n", i, len(state.Servers)-1, srv.Name, srv.IP, srv.Role)
|
||||||
|
|||||||
@ -346,7 +346,11 @@ func (c *HetznerClient) UploadSSHKey(name, publicKey string) (*HetznerSSHKey, er
|
|||||||
|
|
||||||
// ListSSHKeysByFingerprint finds SSH keys matching a fingerprint.
|
// ListSSHKeysByFingerprint finds SSH keys matching a fingerprint.
|
||||||
func (c *HetznerClient) ListSSHKeysByFingerprint(fingerprint string) ([]HetznerSSHKey, error) {
|
func (c *HetznerClient) ListSSHKeysByFingerprint(fingerprint string) ([]HetznerSSHKey, error) {
|
||||||
body, err := c.get("/ssh_keys?fingerprint=" + fingerprint)
|
path := "/ssh_keys"
|
||||||
|
if fingerprint != "" {
|
||||||
|
path += "?fingerprint=" + fingerprint
|
||||||
|
}
|
||||||
|
body, err := c.get(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("list SSH keys: %w", err)
|
return nil, fmt.Errorf("list SSH keys: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -408,17 +408,38 @@ func setupSSHKey(client *HetznerClient) (SSHKeyConfig, error) {
|
|||||||
fmt.Print(" Uploading to Hetzner... ")
|
fmt.Print(" Uploading to Hetzner... ")
|
||||||
key, err := client.UploadSSHKey("orama-sandbox", pubStr)
|
key, err := client.UploadSSHKey("orama-sandbox", pubStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Key may already exist on Hetzner — try to find by fingerprint
|
// Key may already exist on Hetzner — check if it matches the current vault key
|
||||||
existing, listErr := client.ListSSHKeysByFingerprint("") // empty = list all
|
existing, listErr := client.ListSSHKeysByFingerprint("")
|
||||||
if listErr == nil {
|
if listErr == nil {
|
||||||
for _, k := range existing {
|
for _, k := range existing {
|
||||||
if strings.TrimSpace(k.PublicKey) == pubStr {
|
if sshKeyDataEqual(k.PublicKey, pubStr) {
|
||||||
|
// Key data matches — safe to reuse regardless of name
|
||||||
fmt.Printf("already exists (ID: %d)\n", k.ID)
|
fmt.Printf("already exists (ID: %d)\n", k.ID)
|
||||||
return SSHKeyConfig{
|
return SSHKeyConfig{
|
||||||
HetznerID: k.ID,
|
HetznerID: k.ID,
|
||||||
VaultTarget: vaultTarget,
|
VaultTarget: vaultTarget,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
if k.Name == "orama-sandbox" {
|
||||||
|
// Name matches but key data differs — vault key was rotated.
|
||||||
|
// Delete the stale Hetzner key so we can re-upload the current one.
|
||||||
|
fmt.Print("stale key detected, replacing... ")
|
||||||
|
if delErr := client.DeleteSSHKey(k.ID); delErr != nil {
|
||||||
|
fmt.Println("FAILED")
|
||||||
|
return SSHKeyConfig{}, fmt.Errorf("delete stale SSH key (ID %d): %w", k.ID, delErr)
|
||||||
|
}
|
||||||
|
// Re-upload with current vault key
|
||||||
|
newKey, uploadErr := client.UploadSSHKey("orama-sandbox", pubStr)
|
||||||
|
if uploadErr != nil {
|
||||||
|
fmt.Println("FAILED")
|
||||||
|
return SSHKeyConfig{}, fmt.Errorf("re-upload SSH key: %w", uploadErr)
|
||||||
|
}
|
||||||
|
fmt.Printf("OK (ID: %d)\n", newKey.ID)
|
||||||
|
return SSHKeyConfig{
|
||||||
|
HetznerID: newKey.ID,
|
||||||
|
VaultTarget: vaultTarget,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -433,6 +454,17 @@ func setupSSHKey(client *HetznerClient) (SSHKeyConfig, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sshKeyDataEqual compares two SSH public key strings by their key type and
|
||||||
|
// data, ignoring the optional comment field.
|
||||||
|
func sshKeyDataEqual(a, b string) bool {
|
||||||
|
partsA := strings.Fields(strings.TrimSpace(a))
|
||||||
|
partsB := strings.Fields(strings.TrimSpace(b))
|
||||||
|
if len(partsA) < 2 || len(partsB) < 2 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return partsA[0] == partsB[0] && partsA[1] == partsB[1]
|
||||||
|
}
|
||||||
|
|
||||||
// verifyDNS checks if glue records for the sandbox domain are configured.
|
// verifyDNS checks if glue records for the sandbox domain are configured.
|
||||||
//
|
//
|
||||||
// There's a chicken-and-egg problem: NS records can't fully resolve until
|
// There's a chicken-and-egg problem: NS records can't fully resolve until
|
||||||
|
|||||||
82
pkg/cli/sandbox/setup_test.go
Normal file
82
pkg/cli/sandbox/setup_test.go
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
package sandbox
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestSSHKeyDataEqual(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
a string
|
||||||
|
b string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "identical keys",
|
||||||
|
a: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBtest comment1",
|
||||||
|
b: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBtest comment1",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "same key different comments",
|
||||||
|
a: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBtest vault",
|
||||||
|
b: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBtest user@host",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "same key one without comment",
|
||||||
|
a: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBtest",
|
||||||
|
b: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBtest vault",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "different key data",
|
||||||
|
a: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBoldkey vault",
|
||||||
|
b: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBnewkey vault",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "different key types",
|
||||||
|
a: "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAAB vault",
|
||||||
|
b: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBtest vault",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty string a",
|
||||||
|
a: "",
|
||||||
|
b: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBtest vault",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty string b",
|
||||||
|
a: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBtest vault",
|
||||||
|
b: "",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "both empty",
|
||||||
|
a: "",
|
||||||
|
b: "",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single field only",
|
||||||
|
a: "ssh-ed25519",
|
||||||
|
b: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBtest",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "whitespace trimming",
|
||||||
|
a: " ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBtest vault ",
|
||||||
|
b: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBtest",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := sshKeyDataEqual(tt.a, tt.b)
|
||||||
|
if got != tt.expected {
|
||||||
|
t.Errorf("sshKeyDataEqual(%q, %q) = %v, want %v", tt.a, tt.b, got, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -338,6 +338,7 @@ func (ari *AnyoneRelayInstaller) ConfigureClient() error {
|
|||||||
config := `# Anyone Client Configuration (Managed by Orama Network)
|
config := `# Anyone Client Configuration (Managed by Orama Network)
|
||||||
# Client-only mode — no relay traffic, no ORPort
|
# Client-only mode — no relay traffic, no ORPort
|
||||||
|
|
||||||
|
AgreeToTerms 1
|
||||||
SocksPort 9050
|
SocksPort 9050
|
||||||
|
|
||||||
Log notice file /var/log/anon/notices.log
|
Log notice file /var/log/anon/notices.log
|
||||||
@ -360,6 +361,8 @@ func (ari *AnyoneRelayInstaller) generateAnonrc() string {
|
|||||||
sb.WriteString("# Anyone Relay Configuration (Managed by Orama Network)\n")
|
sb.WriteString("# Anyone Relay Configuration (Managed by Orama Network)\n")
|
||||||
sb.WriteString("# Generated automatically - manual edits may be overwritten\n\n")
|
sb.WriteString("# Generated automatically - manual edits may be overwritten\n\n")
|
||||||
|
|
||||||
|
sb.WriteString("AgreeToTerms 1\n\n")
|
||||||
|
|
||||||
// Nickname
|
// Nickname
|
||||||
sb.WriteString(fmt.Sprintf("Nickname %s\n", ari.config.Nickname))
|
sb.WriteString(fmt.Sprintf("Nickname %s\n", ari.config.Nickname))
|
||||||
|
|
||||||
|
|||||||
@ -367,15 +367,13 @@ require (
|
|||||||
// If baseDomain is provided and different from domain, Caddy also serves
|
// If baseDomain is provided and different from domain, Caddy also serves
|
||||||
// the base domain and its wildcard (e.g., *.dbrs.space alongside *.node1.dbrs.space).
|
// the base domain and its wildcard (e.g., *.dbrs.space alongside *.node1.dbrs.space).
|
||||||
func (ci *CaddyInstaller) generateCaddyfile(domain, email, acmeEndpoint, baseDomain string) string {
|
func (ci *CaddyInstaller) generateCaddyfile(domain, email, acmeEndpoint, baseDomain string) string {
|
||||||
// Primary: Let's Encrypt via ACME DNS-01 challenge
|
// Let's Encrypt via ACME DNS-01 challenge (no fallback to self-signed)
|
||||||
// Fallback: Caddy's internal CA (self-signed, trust root on clients)
|
|
||||||
tlsBlock := fmt.Sprintf(` tls {
|
tlsBlock := fmt.Sprintf(` tls {
|
||||||
issuer acme {
|
issuer acme {
|
||||||
dns orama {
|
dns orama {
|
||||||
endpoint %s
|
endpoint %s
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
issuer internal
|
|
||||||
}`, acmeEndpoint)
|
}`, acmeEndpoint)
|
||||||
|
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
|
|||||||
@ -356,8 +356,12 @@ func (ci *CoreDNSInstaller) generateCorefile(domain, rqliteDSN string) string {
|
|||||||
# CoreDNS cache would cache NXDOMAIN and break ACME DNS-01 challenges.
|
# CoreDNS cache would cache NXDOMAIN and break ACME DNS-01 challenges.
|
||||||
}
|
}
|
||||||
|
|
||||||
# Forward all other queries to upstream DNS
|
# Forward non-authoritative queries to upstream DNS (localhost only).
|
||||||
|
# The bind directive restricts this block to 127.0.0.1 so the node itself
|
||||||
|
# can resolve external domains (apt, github, etc.) but external clients
|
||||||
|
# cannot use this server as an open recursive resolver (BSI/CERT-Bund).
|
||||||
. {
|
. {
|
||||||
|
bind 127.0.0.1
|
||||||
forward . 8.8.8.8 8.8.4.4 1.1.1.1
|
forward . 8.8.8.8 8.8.4.4 1.1.1.1
|
||||||
cache 300
|
cache 300
|
||||||
errors
|
errors
|
||||||
|
|||||||
151
pkg/environments/production/installers/coredns_test.go
Normal file
151
pkg/environments/production/installers/coredns_test.go
Normal file
@ -0,0 +1,151 @@
|
|||||||
|
package installers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// newTestCoreDNSInstaller creates a CoreDNSInstaller suitable for unit tests.
|
||||||
|
// It uses a non-existent oramaHome so generateCorefile won't find a password file
|
||||||
|
// and will produce output without auth credentials.
|
||||||
|
func newTestCoreDNSInstaller() *CoreDNSInstaller {
|
||||||
|
return &CoreDNSInstaller{
|
||||||
|
BaseInstaller: NewBaseInstaller("amd64", io.Discard),
|
||||||
|
version: "1.11.1",
|
||||||
|
oramaHome: "/nonexistent",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateCorefile_ContainsBindLocalhost(t *testing.T) {
|
||||||
|
ci := newTestCoreDNSInstaller()
|
||||||
|
corefile := ci.generateCorefile("dbrs.space", "http://localhost:5001")
|
||||||
|
|
||||||
|
if !strings.Contains(corefile, "bind 127.0.0.1") {
|
||||||
|
t.Fatal("Corefile forward block must contain 'bind 127.0.0.1' to prevent open resolver")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateCorefile_ForwardBlockIsLocalhostOnly(t *testing.T) {
|
||||||
|
ci := newTestCoreDNSInstaller()
|
||||||
|
corefile := ci.generateCorefile("dbrs.space", "http://localhost:5001")
|
||||||
|
|
||||||
|
// The bind directive must appear inside the catch-all (.) block,
|
||||||
|
// not inside the authoritative domain block.
|
||||||
|
// Find the ". {" block and verify bind is inside it.
|
||||||
|
dotBlockIdx := strings.Index(corefile, ". {")
|
||||||
|
if dotBlockIdx == -1 {
|
||||||
|
t.Fatal("Corefile must contain a catch-all '. {' server block")
|
||||||
|
}
|
||||||
|
|
||||||
|
dotBlock := corefile[dotBlockIdx:]
|
||||||
|
closingIdx := strings.Index(dotBlock, "}")
|
||||||
|
if closingIdx == -1 {
|
||||||
|
t.Fatal("Catch-all block has no closing brace")
|
||||||
|
}
|
||||||
|
dotBlock = dotBlock[:closingIdx]
|
||||||
|
|
||||||
|
if !strings.Contains(dotBlock, "bind 127.0.0.1") {
|
||||||
|
t.Error("bind 127.0.0.1 must be inside the catch-all (.) block, not the domain block")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(dotBlock, "forward .") {
|
||||||
|
t.Error("forward directive must be inside the catch-all (.) block")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateCorefile_AuthoritativeBlockNoBindRestriction(t *testing.T) {
|
||||||
|
ci := newTestCoreDNSInstaller()
|
||||||
|
corefile := ci.generateCorefile("dbrs.space", "http://localhost:5001")
|
||||||
|
|
||||||
|
// The authoritative domain block should NOT have a bind directive
|
||||||
|
// (it must listen on all interfaces to serve external DNS queries).
|
||||||
|
domainBlockStart := strings.Index(corefile, "dbrs.space {")
|
||||||
|
if domainBlockStart == -1 {
|
||||||
|
t.Fatal("Corefile must contain 'dbrs.space {' server block")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract the domain block (up to the first closing brace)
|
||||||
|
domainBlock := corefile[domainBlockStart:]
|
||||||
|
closingIdx := strings.Index(domainBlock, "}")
|
||||||
|
if closingIdx == -1 {
|
||||||
|
t.Fatal("Domain block has no closing brace")
|
||||||
|
}
|
||||||
|
domainBlock = domainBlock[:closingIdx]
|
||||||
|
|
||||||
|
if strings.Contains(domainBlock, "bind ") {
|
||||||
|
t.Error("Authoritative domain block must not have a bind directive — it must listen on all interfaces")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateCorefile_ContainsDomainZone(t *testing.T) {
|
||||||
|
ci := newTestCoreDNSInstaller()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
domain string
|
||||||
|
}{
|
||||||
|
{"dbrs.space"},
|
||||||
|
{"orama.network"},
|
||||||
|
{"example.com"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.domain, func(t *testing.T) {
|
||||||
|
corefile := ci.generateCorefile(tt.domain, "http://localhost:5001")
|
||||||
|
|
||||||
|
if !strings.Contains(corefile, tt.domain+" {") {
|
||||||
|
t.Errorf("Corefile must contain server block for domain %q", tt.domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(corefile, "rqlite {") {
|
||||||
|
t.Error("Corefile must contain rqlite plugin block")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateCorefile_ContainsRQLiteDSN(t *testing.T) {
|
||||||
|
ci := newTestCoreDNSInstaller()
|
||||||
|
dsn := "http://10.0.0.1:5001"
|
||||||
|
corefile := ci.generateCorefile("dbrs.space", dsn)
|
||||||
|
|
||||||
|
if !strings.Contains(corefile, "dsn "+dsn) {
|
||||||
|
t.Errorf("Corefile must contain RQLite DSN %q", dsn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateCorefile_NoAuthBlockWithoutCredentials(t *testing.T) {
|
||||||
|
ci := newTestCoreDNSInstaller()
|
||||||
|
corefile := ci.generateCorefile("dbrs.space", "http://localhost:5001")
|
||||||
|
|
||||||
|
if strings.Contains(corefile, "username") || strings.Contains(corefile, "password") {
|
||||||
|
t.Error("Corefile must not contain auth credentials when secrets file is absent")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeneratePluginConfig_ContainsBindPlugin(t *testing.T) {
|
||||||
|
ci := newTestCoreDNSInstaller()
|
||||||
|
cfg := ci.generatePluginConfig()
|
||||||
|
|
||||||
|
if !strings.Contains(cfg, "bind:bind") {
|
||||||
|
t.Error("Plugin config must include the bind plugin (required for localhost-only forwarding)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeneratePluginConfig_ContainsACLPlugin(t *testing.T) {
|
||||||
|
ci := newTestCoreDNSInstaller()
|
||||||
|
cfg := ci.generatePluginConfig()
|
||||||
|
|
||||||
|
if !strings.Contains(cfg, "acl:acl") {
|
||||||
|
t.Error("Plugin config must include the acl plugin")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeneratePluginConfig_ContainsRQLitePlugin(t *testing.T) {
|
||||||
|
ci := newTestCoreDNSInstaller()
|
||||||
|
cfg := ci.generatePluginConfig()
|
||||||
|
|
||||||
|
if !strings.Contains(cfg, "rqlite:rqlite") {
|
||||||
|
t.Error("Plugin config must include the rqlite plugin")
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -419,7 +419,7 @@ Description=Caddy HTTP/2 Server
|
|||||||
Documentation=https://caddyserver.com/docs/
|
Documentation=https://caddyserver.com/docs/
|
||||||
After=network-online.target orama-node.service coredns.service
|
After=network-online.target orama-node.service coredns.service
|
||||||
Wants=network-online.target
|
Wants=network-online.target
|
||||||
Wants=orama-node.service
|
Requires=orama-node.service
|
||||||
|
|
||||||
[Service]
|
[Service]
|
||||||
Type=simple
|
Type=simple
|
||||||
@ -428,6 +428,9 @@ ReadWritePaths=%[2]s /var/lib/caddy /etc/caddy
|
|||||||
Environment=XDG_DATA_HOME=/var/lib/caddy
|
Environment=XDG_DATA_HOME=/var/lib/caddy
|
||||||
AmbientCapabilities=CAP_NET_BIND_SERVICE
|
AmbientCapabilities=CAP_NET_BIND_SERVICE
|
||||||
CapabilityBoundingSet=CAP_NET_BIND_SERVICE
|
CapabilityBoundingSet=CAP_NET_BIND_SERVICE
|
||||||
|
ExecStartPre=/bin/sh -c 'for i in $$(seq 1 30); do curl -so /dev/null http://localhost:6001/health 2>/dev/null && exit 0; sleep 2; done; echo "Gateway not ready after 60s"; exit 1'
|
||||||
|
ExecStartPre=/bin/sh -c 'DOMAIN=$$(grep -oP "^\*\\.\K[^ {]+" /etc/caddy/Caddyfile | tail -1); [ -z "$$DOMAIN" ] && exit 0; for i in $$(seq 1 30); do dig +short +timeout=2 "$$DOMAIN" SOA 2>/dev/null | grep -q . && exit 0; sleep 2; done; echo "DNS not resolving $$DOMAIN after 60s (ACME may fail)"; exit 0'
|
||||||
|
TimeoutStartSec=180
|
||||||
ExecStart=/usr/bin/caddy run --environ --config /etc/caddy/Caddyfile
|
ExecStart=/usr/bin/caddy run --environ --config /etc/caddy/Caddyfile
|
||||||
ExecReload=/usr/bin/caddy reload --config /etc/caddy/Caddyfile
|
ExecReload=/usr/bin/caddy reload --config /etc/caddy/Caddyfile
|
||||||
TimeoutStopSec=5s
|
TimeoutStopSec=5s
|
||||||
|
|||||||
@ -78,6 +78,39 @@ func TestGenerateRQLiteService(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestGenerateCaddyService_GatewayReadinessCheck verifies Caddy waits for gateway before starting
|
||||||
|
func TestGenerateCaddyService_GatewayReadinessCheck(t *testing.T) {
|
||||||
|
ssg := &SystemdServiceGenerator{
|
||||||
|
oramaHome: "/opt/orama",
|
||||||
|
oramaDir: "/opt/orama/.orama",
|
||||||
|
}
|
||||||
|
|
||||||
|
unit := ssg.GenerateCaddyService()
|
||||||
|
|
||||||
|
// Must have ExecStartPre that polls gateway health
|
||||||
|
if !strings.Contains(unit, "ExecStartPre=") {
|
||||||
|
t.Error("missing ExecStartPre directive for gateway readiness check")
|
||||||
|
}
|
||||||
|
if !strings.Contains(unit, "localhost:6001/health") {
|
||||||
|
t.Error("ExecStartPre should poll localhost:6001/health")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Must use Requires= (hard dependency), not Wants= (soft dependency)
|
||||||
|
if !strings.Contains(unit, "Requires=orama-node.service") {
|
||||||
|
t.Error("missing Requires=orama-node.service (hard dependency)")
|
||||||
|
}
|
||||||
|
if strings.Contains(unit, "Wants=orama-node.service") {
|
||||||
|
t.Error("should use Requires= not Wants= for orama-node.service dependency")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecStartPre must appear before ExecStart
|
||||||
|
preIdx := strings.Index(unit, "ExecStartPre=")
|
||||||
|
startIdx := strings.Index(unit, "ExecStart=/usr/bin/caddy")
|
||||||
|
if preIdx < 0 || startIdx < 0 || preIdx >= startIdx {
|
||||||
|
t.Error("ExecStartPre must appear before ExecStart")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TestGenerateRQLiteServiceArgs verifies the ExecStart command arguments
|
// TestGenerateRQLiteServiceArgs verifies the ExecStart command arguments
|
||||||
func TestGenerateRQLiteServiceArgs(t *testing.T) {
|
func TestGenerateRQLiteServiceArgs(t *testing.T) {
|
||||||
ssg := &SystemdServiceGenerator{
|
ssg := &SystemdServiceGenerator{
|
||||||
|
|||||||
@ -141,11 +141,15 @@ func (g *Gateway) healthHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
ch <- nr
|
ch <- nr
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Vault Guardian (TCP connect to localhost:7500)
|
// Vault Guardian (TCP connect on WireGuard IP:7500)
|
||||||
go func() {
|
go func() {
|
||||||
nr := namedResult{name: "vault"}
|
nr := namedResult{name: "vault"}
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
conn, err := net.DialTimeout("tcp", "localhost:7500", 2*time.Second)
|
vaultAddr := "localhost:7500"
|
||||||
|
if g.localWireGuardIP != "" {
|
||||||
|
vaultAddr = g.localWireGuardIP + ":7500"
|
||||||
|
}
|
||||||
|
conn, err := net.DialTimeout("tcp", vaultAddr, 2*time.Second)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
nr.result = checkResult{Status: "error", Latency: time.Since(start).String(), Error: fmt.Sprintf("vault-guardian unreachable on port 7500: %v", err)}
|
nr.result = checkResult{Status: "error", Latency: time.Since(start).String(), Error: fmt.Sprintf("vault-guardian unreachable on port 7500: %v", err)}
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@ -109,9 +109,10 @@ func (cm *ClusterConfigManager) DiscoverClusterPeersFromLibP2P(h host.Host) erro
|
|||||||
|
|
||||||
info := h.Peerstore().PeerInfo(p)
|
info := h.Peerstore().PeerInfo(p)
|
||||||
for _, addr := range info.Addrs {
|
for _, addr := range info.Addrs {
|
||||||
// Extract IP from multiaddr
|
// Extract IP from multiaddr — only use WireGuard IPs (10.0.0.x)
|
||||||
|
// for inter-node queries since port 6001 is blocked on public interfaces by UFW
|
||||||
ip := extractIPFromMultiaddr(addr)
|
ip := extractIPFromMultiaddr(addr)
|
||||||
if ip != "" && !strings.HasPrefix(ip, "127.") && !strings.HasPrefix(ip, "::1") {
|
if ip != "" && strings.HasPrefix(ip, "10.0.0.") {
|
||||||
peerIPs[ip] = true
|
peerIPs[ip] = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
95
pkg/ipfs/cluster_peer_test.go
Normal file
95
pkg/ipfs/cluster_peer_test.go
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
package ipfs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/multiformats/go-multiaddr"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExtractIPFromMultiaddr(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
addr string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "ipv4 tcp address",
|
||||||
|
addr: "/ip4/10.0.0.1/tcp/4001",
|
||||||
|
expected: "10.0.0.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ipv4 public address",
|
||||||
|
addr: "/ip4/203.0.113.5/tcp/4001",
|
||||||
|
expected: "203.0.113.5",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ipv4 loopback",
|
||||||
|
addr: "/ip4/127.0.0.1/tcp/4001",
|
||||||
|
expected: "127.0.0.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ipv6 address",
|
||||||
|
addr: "/ip6/::1/tcp/4001",
|
||||||
|
expected: "[::1]",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wireguard ip with udp",
|
||||||
|
addr: "/ip4/10.0.0.3/udp/4001/quic",
|
||||||
|
expected: "10.0.0.3",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ma, err := multiaddr.NewMultiaddr(tt.addr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to parse multiaddr %q: %v", tt.addr, err)
|
||||||
|
}
|
||||||
|
got := extractIPFromMultiaddr(ma)
|
||||||
|
if got != tt.expected {
|
||||||
|
t.Errorf("extractIPFromMultiaddr(%q) = %q, want %q", tt.addr, got, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractIPFromMultiaddr_Nil(t *testing.T) {
|
||||||
|
got := extractIPFromMultiaddr(nil)
|
||||||
|
if got != "" {
|
||||||
|
t.Errorf("extractIPFromMultiaddr(nil) = %q, want empty string", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestWireGuardIPFiltering verifies that only 10.0.0.x IPs would be selected
|
||||||
|
// for peer discovery queries. This tests the filtering logic used in
|
||||||
|
// DiscoverClusterPeersFromLibP2P.
|
||||||
|
func TestWireGuardIPFiltering(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
addr string
|
||||||
|
accepted bool
|
||||||
|
}{
|
||||||
|
{"wireguard ip", "/ip4/10.0.0.1/tcp/4001", true},
|
||||||
|
{"wireguard ip high", "/ip4/10.0.0.254/tcp/4001", true},
|
||||||
|
{"public ip", "/ip4/203.0.113.5/tcp/4001", false},
|
||||||
|
{"private 192.168", "/ip4/192.168.1.1/tcp/4001", false},
|
||||||
|
{"private 172.16", "/ip4/172.16.0.1/tcp/4001", false},
|
||||||
|
{"loopback", "/ip4/127.0.0.1/tcp/4001", false},
|
||||||
|
{"different 10.x subnet", "/ip4/10.1.0.1/tcp/4001", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ma, err := multiaddr.NewMultiaddr(tt.addr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to parse multiaddr: %v", err)
|
||||||
|
}
|
||||||
|
ip := extractIPFromMultiaddr(ma)
|
||||||
|
// Replicate the filtering logic from DiscoverClusterPeersFromLibP2P
|
||||||
|
accepted := ip != "" && len(ip) >= 7 && ip[:7] == "10.0.0."
|
||||||
|
if accepted != tt.accepted {
|
||||||
|
t.Errorf("IP %q: accepted=%v, want %v", ip, accepted, tt.accepted)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
222
pkg/rwagent/client.go
Normal file
222
pkg/rwagent/client.go
Normal file
@ -0,0 +1,222 @@
|
|||||||
|
package rwagent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// DefaultSocketName is the socket file relative to ~/.rootwallet/.
|
||||||
|
DefaultSocketName = "agent.sock"
|
||||||
|
|
||||||
|
// DefaultTimeout for HTTP requests to the agent.
|
||||||
|
// Set high enough to allow pending approval flow (2 min approval timeout).
|
||||||
|
DefaultTimeout = 150 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// Client communicates with the rootwallet agent daemon over a Unix socket.
|
||||||
|
type Client struct {
|
||||||
|
httpClient *http.Client
|
||||||
|
socketPath string
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a client that connects to the agent's Unix socket.
|
||||||
|
// If socketPath is empty, defaults to ~/.rootwallet/agent.sock.
|
||||||
|
func New(socketPath string) *Client {
|
||||||
|
if socketPath == "" {
|
||||||
|
home, _ := os.UserHomeDir()
|
||||||
|
socketPath = filepath.Join(home, ".rootwallet", DefaultSocketName)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Client{
|
||||||
|
socketPath: socketPath,
|
||||||
|
httpClient: &http.Client{
|
||||||
|
Transport: &http.Transport{
|
||||||
|
DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||||
|
var d net.Dialer
|
||||||
|
return d.DialContext(ctx, "unix", socketPath)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Timeout: DefaultTimeout,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Status returns the agent's current status.
|
||||||
|
func (c *Client) Status(ctx context.Context) (*StatusResponse, error) {
|
||||||
|
var resp apiResponse[StatusResponse]
|
||||||
|
if err := c.doJSON(ctx, "GET", "/v1/status", nil, &resp); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if !resp.OK {
|
||||||
|
return nil, c.apiError(resp.Error, resp.Code, 0)
|
||||||
|
}
|
||||||
|
return &resp.Data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsRunning returns true if the agent is reachable.
|
||||||
|
func (c *Client) IsRunning(ctx context.Context) bool {
|
||||||
|
_, err := c.Status(ctx)
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSSHKey retrieves an SSH key from the vault.
|
||||||
|
// format: "priv", "pub", or "both".
|
||||||
|
func (c *Client) GetSSHKey(ctx context.Context, host, username, format string) (*VaultSSHData, error) {
|
||||||
|
path := fmt.Sprintf("/v1/vault/ssh/%s/%s?format=%s",
|
||||||
|
url.PathEscape(host),
|
||||||
|
url.PathEscape(username),
|
||||||
|
url.QueryEscape(format),
|
||||||
|
)
|
||||||
|
|
||||||
|
var resp apiResponse[VaultSSHData]
|
||||||
|
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if !resp.OK {
|
||||||
|
return nil, c.apiError(resp.Error, resp.Code, 0)
|
||||||
|
}
|
||||||
|
return &resp.Data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateSSHEntry creates a new SSH key entry in the vault.
|
||||||
|
func (c *Client) CreateSSHEntry(ctx context.Context, host, username string) (*VaultSSHData, error) {
|
||||||
|
body := map[string]string{"host": host, "username": username}
|
||||||
|
|
||||||
|
var resp apiResponse[VaultSSHData]
|
||||||
|
if err := c.doJSON(ctx, "POST", "/v1/vault/ssh", body, &resp); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if !resp.OK {
|
||||||
|
return nil, c.apiError(resp.Error, resp.Code, 0)
|
||||||
|
}
|
||||||
|
return &resp.Data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPassword retrieves a stored password from the vault.
|
||||||
|
func (c *Client) GetPassword(ctx context.Context, domain, username string) (*VaultPasswordData, error) {
|
||||||
|
path := fmt.Sprintf("/v1/vault/password/%s/%s",
|
||||||
|
url.PathEscape(domain),
|
||||||
|
url.PathEscape(username),
|
||||||
|
)
|
||||||
|
|
||||||
|
var resp apiResponse[VaultPasswordData]
|
||||||
|
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if !resp.OK {
|
||||||
|
return nil, c.apiError(resp.Error, resp.Code, 0)
|
||||||
|
}
|
||||||
|
return &resp.Data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAddress returns the active wallet address.
|
||||||
|
func (c *Client) GetAddress(ctx context.Context, chain string) (*WalletAddressData, error) {
|
||||||
|
path := fmt.Sprintf("/v1/wallet/address?chain=%s", url.QueryEscape(chain))
|
||||||
|
|
||||||
|
var resp apiResponse[WalletAddressData]
|
||||||
|
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if !resp.OK {
|
||||||
|
return nil, c.apiError(resp.Error, resp.Code, 0)
|
||||||
|
}
|
||||||
|
return &resp.Data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unlock sends the password to unlock the agent.
|
||||||
|
func (c *Client) Unlock(ctx context.Context, password string, ttlMinutes int) error {
|
||||||
|
body := map[string]any{"password": password, "ttlMinutes": ttlMinutes}
|
||||||
|
|
||||||
|
var resp apiResponse[any]
|
||||||
|
if err := c.doJSON(ctx, "POST", "/v1/unlock", body, &resp); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !resp.OK {
|
||||||
|
return c.apiError(resp.Error, resp.Code, 0)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lock locks the agent, zeroing all key material.
|
||||||
|
func (c *Client) Lock(ctx context.Context) error {
|
||||||
|
var resp apiResponse[any]
|
||||||
|
if err := c.doJSON(ctx, "POST", "/v1/lock", nil, &resp); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !resp.OK {
|
||||||
|
return c.apiError(resp.Error, resp.Code, 0)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// doJSON performs an HTTP request and decodes the JSON response.
|
||||||
|
func (c *Client) doJSON(ctx context.Context, method, path string, body any, result any) error {
|
||||||
|
var bodyReader io.Reader
|
||||||
|
if body != nil {
|
||||||
|
data, err := json.Marshal(body)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("marshal request body: %w", err)
|
||||||
|
}
|
||||||
|
bodyReader = strings.NewReader(string(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
// URL host is ignored for Unix sockets, but required by http.NewRequest
|
||||||
|
req, err := http.NewRequestWithContext(ctx, method, "http://localhost"+path, bodyReader)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("X-RW-PID", strconv.Itoa(os.Getpid()))
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
// Connection refused or socket not found = agent not running
|
||||||
|
if isConnectionError(err) {
|
||||||
|
return ErrAgentNotRunning
|
||||||
|
}
|
||||||
|
return fmt.Errorf("agent request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
data, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(data, result); err != nil {
|
||||||
|
return fmt.Errorf("decode response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) apiError(message, code string, statusCode int) *AgentError {
|
||||||
|
return &AgentError{
|
||||||
|
Code: code,
|
||||||
|
Message: message,
|
||||||
|
StatusCode: statusCode,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isConnectionError checks if the error is a connection-level failure.
|
||||||
|
func isConnectionError(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
msg := err.Error()
|
||||||
|
return strings.Contains(msg, "connection refused") ||
|
||||||
|
strings.Contains(msg, "no such file or directory") ||
|
||||||
|
strings.Contains(msg, "connect: no such file")
|
||||||
|
}
|
||||||
|
|
||||||
257
pkg/rwagent/client_test.go
Normal file
257
pkg/rwagent/client_test.go
Normal file
@ -0,0 +1,257 @@
|
|||||||
|
package rwagent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// startMockAgent creates a mock agent server on a Unix socket for testing.
|
||||||
|
func startMockAgent(t *testing.T, handler http.Handler) (socketPath string, cleanup func()) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
socketPath = filepath.Join(tmpDir, "test-agent.sock")
|
||||||
|
|
||||||
|
listener, err := net.Listen("unix", socketPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("listen on unix socket: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
server := &http.Server{Handler: handler}
|
||||||
|
go func() { _ = server.Serve(listener) }()
|
||||||
|
|
||||||
|
cleanup = func() {
|
||||||
|
_ = server.Close()
|
||||||
|
_ = os.Remove(socketPath)
|
||||||
|
}
|
||||||
|
return socketPath, cleanup
|
||||||
|
}
|
||||||
|
|
||||||
|
// jsonHandler returns an http.HandlerFunc that responds with the given JSON.
|
||||||
|
func jsonHandler(statusCode int, body any) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(statusCode)
|
||||||
|
data, _ := json.Marshal(body)
|
||||||
|
_, _ = w.Write(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStatus(t *testing.T) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/v1/status", jsonHandler(200, apiResponse[StatusResponse]{
|
||||||
|
OK: true,
|
||||||
|
Data: StatusResponse{
|
||||||
|
Version: "1.0.0",
|
||||||
|
Locked: false,
|
||||||
|
Uptime: 120,
|
||||||
|
PID: 12345,
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
sock, cleanup := startMockAgent(t, mux)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
client := New(sock)
|
||||||
|
status, err := client.Status(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Status() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if status.Version != "1.0.0" {
|
||||||
|
t.Errorf("Version = %q, want %q", status.Version, "1.0.0")
|
||||||
|
}
|
||||||
|
if status.Locked {
|
||||||
|
t.Error("Locked = true, want false")
|
||||||
|
}
|
||||||
|
if status.Uptime != 120 {
|
||||||
|
t.Errorf("Uptime = %d, want 120", status.Uptime)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsRunning_true(t *testing.T) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/v1/status", jsonHandler(200, apiResponse[StatusResponse]{
|
||||||
|
OK: true,
|
||||||
|
Data: StatusResponse{Version: "1.0.0"},
|
||||||
|
}))
|
||||||
|
|
||||||
|
sock, cleanup := startMockAgent(t, mux)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
client := New(sock)
|
||||||
|
if !client.IsRunning(context.Background()) {
|
||||||
|
t.Error("IsRunning() = false, want true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsRunning_false(t *testing.T) {
|
||||||
|
client := New("/tmp/nonexistent-socket-test.sock")
|
||||||
|
if client.IsRunning(context.Background()) {
|
||||||
|
t.Error("IsRunning() = true, want false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetSSHKey(t *testing.T) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/v1/vault/ssh/myhost/root", jsonHandler(200, apiResponse[VaultSSHData]{
|
||||||
|
OK: true,
|
||||||
|
Data: VaultSSHData{
|
||||||
|
PrivateKey: "-----BEGIN OPENSSH PRIVATE KEY-----\nfake\n-----END OPENSSH PRIVATE KEY-----",
|
||||||
|
PublicKey: "ssh-ed25519 AAAA... myhost/root",
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
sock, cleanup := startMockAgent(t, mux)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
client := New(sock)
|
||||||
|
data, err := client.GetSSHKey(context.Background(), "myhost", "root", "both")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetSSHKey() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if data.PrivateKey == "" {
|
||||||
|
t.Error("PrivateKey is empty")
|
||||||
|
}
|
||||||
|
if data.PublicKey == "" {
|
||||||
|
t.Error("PublicKey is empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetSSHKey_locked(t *testing.T) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/v1/vault/ssh/myhost/root", jsonHandler(423, apiResponse[any]{
|
||||||
|
OK: false,
|
||||||
|
Error: "Agent is locked",
|
||||||
|
Code: "AGENT_LOCKED",
|
||||||
|
}))
|
||||||
|
|
||||||
|
sock, cleanup := startMockAgent(t, mux)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
client := New(sock)
|
||||||
|
_, err := client.GetSSHKey(context.Background(), "myhost", "root", "priv")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("GetSSHKey() expected error, got nil")
|
||||||
|
}
|
||||||
|
if !IsLocked(err) {
|
||||||
|
t.Errorf("IsLocked() = false for error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetSSHKey_notFound(t *testing.T) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/v1/vault/ssh/unknown/user", jsonHandler(404, apiResponse[any]{
|
||||||
|
OK: false,
|
||||||
|
Error: "No SSH key found for unknown/user",
|
||||||
|
Code: "NOT_FOUND",
|
||||||
|
}))
|
||||||
|
|
||||||
|
sock, cleanup := startMockAgent(t, mux)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
client := New(sock)
|
||||||
|
_, err := client.GetSSHKey(context.Background(), "unknown", "user", "priv")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("GetSSHKey() expected error, got nil")
|
||||||
|
}
|
||||||
|
if !IsNotFound(err) {
|
||||||
|
t.Errorf("IsNotFound() = false for error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetPassword(t *testing.T) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/v1/vault/password/example.com/admin", jsonHandler(200, apiResponse[VaultPasswordData]{
|
||||||
|
OK: true,
|
||||||
|
Data: VaultPasswordData{Password: "secret123"},
|
||||||
|
}))
|
||||||
|
|
||||||
|
sock, cleanup := startMockAgent(t, mux)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
client := New(sock)
|
||||||
|
data, err := client.GetPassword(context.Background(), "example.com", "admin")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetPassword() error: %v", err)
|
||||||
|
}
|
||||||
|
if data.Password != "secret123" {
|
||||||
|
t.Errorf("Password = %q, want %q", data.Password, "secret123")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateSSHEntry(t *testing.T) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/v1/vault/ssh", jsonHandler(201, apiResponse[VaultSSHData]{
|
||||||
|
OK: true,
|
||||||
|
Data: VaultSSHData{PublicKey: "ssh-ed25519 AAAA... new/entry"},
|
||||||
|
}))
|
||||||
|
|
||||||
|
sock, cleanup := startMockAgent(t, mux)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
client := New(sock)
|
||||||
|
data, err := client.CreateSSHEntry(context.Background(), "new", "entry")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateSSHEntry() error: %v", err)
|
||||||
|
}
|
||||||
|
if data.PublicKey == "" {
|
||||||
|
t.Error("PublicKey is empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetAddress(t *testing.T) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/v1/wallet/address", jsonHandler(200, apiResponse[WalletAddressData]{
|
||||||
|
OK: true,
|
||||||
|
Data: WalletAddressData{Address: "0x1234abcd", Chain: "evm"},
|
||||||
|
}))
|
||||||
|
|
||||||
|
sock, cleanup := startMockAgent(t, mux)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
client := New(sock)
|
||||||
|
data, err := client.GetAddress(context.Background(), "evm")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetAddress() error: %v", err)
|
||||||
|
}
|
||||||
|
if data.Address != "0x1234abcd" {
|
||||||
|
t.Errorf("Address = %q, want %q", data.Address, "0x1234abcd")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAgentNotRunning(t *testing.T) {
|
||||||
|
client := New("/tmp/nonexistent-socket-for-testing.sock")
|
||||||
|
_, err := client.Status(context.Background())
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
if !IsNotRunning(err) {
|
||||||
|
t.Errorf("IsNotRunning() = false for error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnlockAndLock(t *testing.T) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/v1/unlock", jsonHandler(200, apiResponse[any]{OK: true}))
|
||||||
|
mux.HandleFunc("/v1/lock", jsonHandler(200, apiResponse[any]{OK: true}))
|
||||||
|
|
||||||
|
sock, cleanup := startMockAgent(t, mux)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
client := New(sock)
|
||||||
|
|
||||||
|
if err := client.Unlock(context.Background(), "password", 30); err != nil {
|
||||||
|
t.Fatalf("Unlock() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := client.Lock(context.Background()); err != nil {
|
||||||
|
t.Fatalf("Lock() error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
57
pkg/rwagent/errors.go
Normal file
57
pkg/rwagent/errors.go
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
package rwagent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AgentError represents an error returned by the rootwallet agent API.
|
||||||
|
type AgentError struct {
|
||||||
|
Code string // e.g., "AGENT_LOCKED", "NOT_FOUND"
|
||||||
|
Message string
|
||||||
|
StatusCode int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *AgentError) Error() string {
|
||||||
|
return fmt.Sprintf("rootwallet agent: %s (%s)", e.Message, e.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsLocked returns true if the error indicates the agent is locked.
|
||||||
|
func IsLocked(err error) bool {
|
||||||
|
var ae *AgentError
|
||||||
|
if errors.As(err, &ae) {
|
||||||
|
return ae.Code == "AGENT_LOCKED"
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsNotRunning returns true if the error indicates the agent is not reachable.
|
||||||
|
func IsNotRunning(err error) bool {
|
||||||
|
var ae *AgentError
|
||||||
|
if errors.As(err, &ae) {
|
||||||
|
return ae.Code == "AGENT_NOT_RUNNING"
|
||||||
|
}
|
||||||
|
// Also check for connection errors
|
||||||
|
return errors.Is(err, ErrAgentNotRunning)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsNotFound returns true if the vault entry was not found.
|
||||||
|
func IsNotFound(err error) bool {
|
||||||
|
var ae *AgentError
|
||||||
|
if errors.As(err, &ae) {
|
||||||
|
return ae.Code == "NOT_FOUND"
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsApprovalDenied returns true if the user denied the app's access request.
|
||||||
|
func IsApprovalDenied(err error) bool {
|
||||||
|
var ae *AgentError
|
||||||
|
if errors.As(err, &ae) {
|
||||||
|
return ae.Code == "APPROVAL_DENIED" || ae.Code == "PERMISSION_DENIED"
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrAgentNotRunning is returned when the agent socket is not reachable.
|
||||||
|
var ErrAgentNotRunning = fmt.Errorf("rootwallet agent is not running — start with: rw agent start && rw agent unlock")
|
||||||
56
pkg/rwagent/types.go
Normal file
56
pkg/rwagent/types.go
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
// Package rwagent provides a Go client for the RootWallet agent daemon.
|
||||||
|
//
|
||||||
|
// The agent is a persistent daemon that holds vault keys in memory and serves
|
||||||
|
// operations to authorized apps over a Unix socket HTTP API. This SDK replaces
|
||||||
|
// all subprocess `rw` calls with direct HTTP communication.
|
||||||
|
package rwagent
|
||||||
|
|
||||||
|
// StatusResponse from GET /v1/status.
|
||||||
|
type StatusResponse struct {
|
||||||
|
Version string `json:"version"`
|
||||||
|
Locked bool `json:"locked"`
|
||||||
|
Uptime int `json:"uptime"`
|
||||||
|
PID int `json:"pid"`
|
||||||
|
ConnectedApps int `json:"connectedApps"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// VaultSSHData from GET /v1/vault/ssh/:host/:user.
|
||||||
|
type VaultSSHData struct {
|
||||||
|
PrivateKey string `json:"privateKey,omitempty"`
|
||||||
|
PublicKey string `json:"publicKey,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// VaultPasswordData from GET /v1/vault/password/:domain/:user.
|
||||||
|
type VaultPasswordData struct {
|
||||||
|
Password string `json:"password"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// WalletAddressData from GET /v1/wallet/address.
|
||||||
|
type WalletAddressData struct {
|
||||||
|
Address string `json:"address"`
|
||||||
|
Chain string `json:"chain"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AppPermission represents an approved app in the permission database.
|
||||||
|
type AppPermission struct {
|
||||||
|
BinaryHash string `json:"binaryHash"`
|
||||||
|
BinaryPath string `json:"binaryPath"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
FirstSeen string `json:"firstSeen"`
|
||||||
|
LastUsed string `json:"lastUsed"`
|
||||||
|
Capabilities []PermittedCapability `json:"capabilities"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PermittedCapability is a specific capability granted to an app.
|
||||||
|
type PermittedCapability struct {
|
||||||
|
Capability string `json:"capability"`
|
||||||
|
GrantedAt string `json:"grantedAt"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// apiResponse is the generic API response envelope.
|
||||||
|
type apiResponse[T any] struct {
|
||||||
|
OK bool `json:"ok"`
|
||||||
|
Data T `json:"data,omitempty"`
|
||||||
|
Error string `json:"error,omitempty"`
|
||||||
|
Code string `json:"code,omitempty"`
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user