orama/pkg/cli/remotessh/wallet_test.go
anonpenguin23 0764ac287e refactor(remotessh): use rwagent directly instead of rw CLI subprocesses
- replace `rw vault ssh` calls with `rwagent.Client` in PrepareNodeKeys,
  LoadAgentKeys, EnsureVaultEntry, ResolveVaultPublicKey
- add vaultClient interface, newClient func, and wrapAgentError for
  testability and improved error messages
- prefer pre-built systemd dir in installNamespaceTemplates
2026-03-20 07:23:10 +02:00

377 lines
10 KiB
Go

package remotessh
import (
"context"
"errors"
"os"
"strings"
"testing"
"github.com/DeBrosOfficial/network/pkg/inspector"
"github.com/DeBrosOfficial/network/pkg/rwagent"
)
// testPrivateKey is placeholder key material wrapped in OpenSSH private-key
// BEGIN/END markers. The payload between the markers is not a real key; the
// mocks in this file return it as VaultSSHData.PrivateKey.
const testPrivateKey = "-----BEGIN OPENSSH PRIVATE KEY-----\nfake-key-data\n-----END OPENSSH PRIVATE KEY-----"
// mockClient implements vaultClient for testing.
//
// Each exported method forwards its arguments to the function field of the
// matching name, so a test only has to set the hooks it expects to be called.
// Invoking a method whose hook was left nil panics.
type mockClient struct {
// getSSHKey is the hook invoked by GetSSHKey.
getSSHKey func(ctx context.Context, host, username, format string) (*rwagent.VaultSSHData, error)
// createSSHEntry is the hook invoked by CreateSSHEntry.
createSSHEntry func(ctx context.Context, host, username string) (*rwagent.VaultSSHData, error)
}
// GetSSHKey forwards the call to the getSSHKey hook and returns its result
// unchanged.
//
// It panics with an explicit message when the hook has not been set, so that
// an unexpected call from the code under test is reported by name instead of
// surfacing as an anonymous nil-function dereference.
func (m *mockClient) GetSSHKey(ctx context.Context, host, username, format string) (*rwagent.VaultSSHData, error) {
	if m.getSSHKey == nil {
		panic("mockClient: GetSSHKey called but the getSSHKey hook is not set")
	}
	return m.getSSHKey(ctx, host, username, format)
}
// CreateSSHEntry forwards the call to the createSSHEntry hook and returns its
// result unchanged.
//
// It panics with an explicit message when the hook has not been set, so that
// an unexpected call from the code under test is reported by name instead of
// surfacing as an anonymous nil-function dereference.
func (m *mockClient) CreateSSHEntry(ctx context.Context, host, username string) (*rwagent.VaultSSHData, error) {
	if m.createSSHEntry == nil {
		panic("mockClient: CreateSSHEntry called but the createSSHEntry hook is not set")
	}
	return m.createSSHEntry(ctx, host, username)
}
// withMockClient points the package-level newClient factory at mock and
// registers a cleanup that puts the previous factory back once the test
// finishes.
func withMockClient(t *testing.T, mock *mockClient) {
	t.Helper()

	previous := newClient
	t.Cleanup(func() {
		newClient = previous
	})

	newClient = func() vaultClient {
		return mock
	}
}
// TestParseVaultTarget runs parseVaultTarget over a table of targets and
// compares the returned host and user with the expected pair. The table
// expects the split to happen at the first slash only, and a target without
// a slash to yield an empty user.
func TestParseVaultTarget(t *testing.T) {
	type testCase struct {
		target   string
		wantHost string
		wantUser string
	}

	cases := []testCase{
		{target: "sandbox/root", wantHost: "sandbox", wantUser: "root"},
		{target: "192.168.1.1/ubuntu", wantHost: "192.168.1.1", wantUser: "ubuntu"},
		{target: "my-host/my-user", wantHost: "my-host", wantUser: "my-user"},
		{target: "noslash", wantHost: "noslash", wantUser: ""},
		{target: "a/b/c", wantHost: "a", wantUser: "b/c"},
		{target: "", wantHost: "", wantUser: ""},
	}

	for _, tc := range cases {
		t.Run(tc.target, func(t *testing.T) {
			gotHost, gotUser := parseVaultTarget(tc.target)
			if gotHost != tc.wantHost {
				t.Errorf("parseVaultTarget(%q) host = %q, want %q", tc.target, gotHost, tc.wantHost)
			}
			if gotUser != tc.wantUser {
				t.Errorf("parseVaultTarget(%q) user = %q, want %q", tc.target, gotUser, tc.wantUser)
			}
		})
	}
}
// TestWrapAgentError_notRunning wraps rwagent.ErrAgentNotRunning and expects
// the resulting message to say the agent is not running and to carry the
// "rw agent start" hint.
func TestWrapAgentError_notRunning(t *testing.T) {
	err := wrapAgentError(rwagent.ErrAgentNotRunning, "test action")
	msg := err.Error()

	if !strings.Contains(msg, "not running") {
		t.Errorf("expected 'not running' message, got: %s", err)
	}
	if !strings.Contains(msg, "rw agent start") {
		t.Errorf("expected actionable hint, got: %s", err)
	}
}
// TestWrapAgentError_locked wraps an AgentError with code AGENT_LOCKED and
// expects the resulting message to mention the lock and to carry the
// "rw agent unlock" hint.
func TestWrapAgentError_locked(t *testing.T) {
	lockedErr := &rwagent.AgentError{Code: "AGENT_LOCKED", Message: "agent is locked"}

	err := wrapAgentError(lockedErr, "test action")
	msg := err.Error()

	if !strings.Contains(msg, "locked") {
		t.Errorf("expected 'locked' message, got: %s", err)
	}
	if !strings.Contains(msg, "rw agent unlock") {
		t.Errorf("expected actionable hint, got: %s", err)
	}
}
// TestWrapAgentError_generic wraps a plain error and expects the resulting
// message to contain both the action string and the original error text.
func TestWrapAgentError_generic(t *testing.T) {
	cause := errors.New("some error")

	err := wrapAgentError(cause, "test action")
	msg := err.Error()

	if !strings.Contains(msg, "test action") {
		t.Errorf("expected action context, got: %s", err)
	}
	if !strings.Contains(msg, "some error") {
		t.Errorf("expected wrapped error, got: %s", err)
	}
}
func TestPrepareNodeKeys_success(t *testing.T) {
mock := &mockClient{
getSSHKey: func(_ context.Context, host, username, format string) (*rwagent.VaultSSHData, error) {
return &rwagent.VaultSSHData{PrivateKey: testPrivateKey}, nil
},
}
withMockClient(t, mock)
nodes := []inspector.Node{
{Host: "10.0.0.1", User: "root"},
{Host: "10.0.0.2", User: "root"},
}
cleanup, err := PrepareNodeKeys(nodes)
if err != nil {
t.Fatalf("PrepareNodeKeys() error = %v", err)
}
defer cleanup()
for i, n := range nodes {
if n.SSHKey == "" {
t.Errorf("node[%d].SSHKey is empty", i)
continue
}
data, err := os.ReadFile(n.SSHKey)
if err != nil {
t.Errorf("node[%d] key file unreadable: %v", i, err)
continue
}
if !strings.Contains(string(data), "BEGIN OPENSSH PRIVATE KEY") {
t.Errorf("node[%d] key file has wrong content", i)
}
}
}
func TestPrepareNodeKeys_deduplication(t *testing.T) {
callCount := 0
mock := &mockClient{
getSSHKey: func(_ context.Context, host, username, format string) (*rwagent.VaultSSHData, error) {
callCount++
return &rwagent.VaultSSHData{PrivateKey: testPrivateKey}, nil
},
}
withMockClient(t, mock)
nodes := []inspector.Node{
{Host: "10.0.0.1", User: "root"},
{Host: "10.0.0.1", User: "root"}, // same host/user
}
cleanup, err := PrepareNodeKeys(nodes)
if err != nil {
t.Fatalf("PrepareNodeKeys() error = %v", err)
}
defer cleanup()
if callCount != 1 {
t.Errorf("expected 1 agent call (dedup), got %d", callCount)
}
if nodes[0].SSHKey != nodes[1].SSHKey {
t.Error("expected same key path for deduplicated nodes")
}
}
func TestPrepareNodeKeys_vaultTarget(t *testing.T) {
var capturedHost, capturedUser string
mock := &mockClient{
getSSHKey: func(_ context.Context, host, username, format string) (*rwagent.VaultSSHData, error) {
capturedHost = host
capturedUser = username
return &rwagent.VaultSSHData{PrivateKey: testPrivateKey}, nil
},
}
withMockClient(t, mock)
nodes := []inspector.Node{
{Host: "10.0.0.1", User: "root", VaultTarget: "sandbox/admin"},
}
cleanup, err := PrepareNodeKeys(nodes)
if err != nil {
t.Fatalf("PrepareNodeKeys() error = %v", err)
}
defer cleanup()
if capturedHost != "sandbox" || capturedUser != "admin" {
t.Errorf("expected host=sandbox user=admin, got host=%s user=%s", capturedHost, capturedUser)
}
}
func TestPrepareNodeKeys_agentNotRunning(t *testing.T) {
mock := &mockClient{
getSSHKey: func(_ context.Context, _, _, _ string) (*rwagent.VaultSSHData, error) {
return nil, rwagent.ErrAgentNotRunning
},
}
withMockClient(t, mock)
nodes := []inspector.Node{{Host: "10.0.0.1", User: "root"}}
_, err := PrepareNodeKeys(nodes)
if err == nil {
t.Fatal("expected error")
}
if !strings.Contains(err.Error(), "not running") {
t.Errorf("expected 'not running' error, got: %s", err)
}
}
func TestPrepareNodeKeys_invalidKey(t *testing.T) {
mock := &mockClient{
getSSHKey: func(_ context.Context, _, _, _ string) (*rwagent.VaultSSHData, error) {
return &rwagent.VaultSSHData{PrivateKey: "garbage"}, nil
},
}
withMockClient(t, mock)
nodes := []inspector.Node{{Host: "10.0.0.1", User: "root"}}
_, err := PrepareNodeKeys(nodes)
if err == nil {
t.Fatal("expected error for invalid key")
}
if !strings.Contains(err.Error(), "invalid key") {
t.Errorf("expected 'invalid key' error, got: %s", err)
}
}
func TestPrepareNodeKeys_cleanupOnError(t *testing.T) {
callNum := 0
mock := &mockClient{
getSSHKey: func(_ context.Context, _, _, _ string) (*rwagent.VaultSSHData, error) {
callNum++
if callNum == 2 {
return nil, &rwagent.AgentError{Code: "AGENT_LOCKED", Message: "locked"}
}
return &rwagent.VaultSSHData{PrivateKey: testPrivateKey}, nil
},
}
withMockClient(t, mock)
nodes := []inspector.Node{
{Host: "10.0.0.1", User: "root"},
{Host: "10.0.0.2", User: "root"},
}
_, err := PrepareNodeKeys(nodes)
if err == nil {
t.Fatal("expected error")
}
// First node's temp file should have been cleaned up
if nodes[0].SSHKey != "" {
if _, statErr := os.Stat(nodes[0].SSHKey); statErr == nil {
t.Error("expected temp key file to be cleaned up on error")
}
}
}
func TestPrepareNodeKeys_emptyNodes(t *testing.T) {
mock := &mockClient{}
withMockClient(t, mock)
cleanup, err := PrepareNodeKeys(nil)
if err != nil {
t.Fatalf("expected no error for empty nodes, got: %v", err)
}
cleanup() // should not panic
}
func TestEnsureVaultEntry_exists(t *testing.T) {
mock := &mockClient{
getSSHKey: func(_ context.Context, _, _, _ string) (*rwagent.VaultSSHData, error) {
return &rwagent.VaultSSHData{PublicKey: "ssh-ed25519 AAAA..."}, nil
},
}
withMockClient(t, mock)
if err := EnsureVaultEntry("sandbox/root"); err != nil {
t.Fatalf("EnsureVaultEntry() error = %v", err)
}
}
func TestEnsureVaultEntry_creates(t *testing.T) {
created := false
mock := &mockClient{
getSSHKey: func(_ context.Context, _, _, _ string) (*rwagent.VaultSSHData, error) {
return nil, &rwagent.AgentError{Code: "NOT_FOUND", Message: "not found"}
},
createSSHEntry: func(_ context.Context, host, username string) (*rwagent.VaultSSHData, error) {
created = true
if host != "sandbox" || username != "root" {
t.Errorf("unexpected create args: %s/%s", host, username)
}
return &rwagent.VaultSSHData{PublicKey: "ssh-ed25519 AAAA..."}, nil
},
}
withMockClient(t, mock)
if err := EnsureVaultEntry("sandbox/root"); err != nil {
t.Fatalf("EnsureVaultEntry() error = %v", err)
}
if !created {
t.Error("expected CreateSSHEntry to be called")
}
}
func TestEnsureVaultEntry_locked(t *testing.T) {
mock := &mockClient{
getSSHKey: func(_ context.Context, _, _, _ string) (*rwagent.VaultSSHData, error) {
return nil, &rwagent.AgentError{Code: "AGENT_LOCKED", Message: "locked"}
},
}
withMockClient(t, mock)
err := EnsureVaultEntry("sandbox/root")
if err == nil {
t.Fatal("expected error")
}
if !strings.Contains(err.Error(), "locked") {
t.Errorf("expected locked error, got: %s", err)
}
}
func TestResolveVaultPublicKey_success(t *testing.T) {
mock := &mockClient{
getSSHKey: func(_ context.Context, _, _, format string) (*rwagent.VaultSSHData, error) {
if format != "pub" {
t.Errorf("expected format=pub, got %s", format)
}
return &rwagent.VaultSSHData{PublicKey: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAA..."}, nil
},
}
withMockClient(t, mock)
key, err := ResolveVaultPublicKey("sandbox/root")
if err != nil {
t.Fatalf("ResolveVaultPublicKey() error = %v", err)
}
if !strings.HasPrefix(key, "ssh-") {
t.Errorf("expected ssh- prefix, got: %s", key)
}
}
func TestResolveVaultPublicKey_invalidFormat(t *testing.T) {
mock := &mockClient{
getSSHKey: func(_ context.Context, _, _, _ string) (*rwagent.VaultSSHData, error) {
return &rwagent.VaultSSHData{PublicKey: "not-a-valid-key"}, nil
},
}
withMockClient(t, mock)
_, err := ResolveVaultPublicKey("sandbox/root")
if err == nil {
t.Fatal("expected error for invalid public key")
}
if !strings.Contains(err.Error(), "invalid public key") {
t.Errorf("expected 'invalid public key' error, got: %s", err)
}
}
func TestResolveVaultPublicKey_notFound(t *testing.T) {
mock := &mockClient{
getSSHKey: func(_ context.Context, _, _, _ string) (*rwagent.VaultSSHData, error) {
return nil, &rwagent.AgentError{Code: "NOT_FOUND", Message: "not found"}
},
}
withMockClient(t, mock)
_, err := ResolveVaultPublicKey("sandbox/root")
if err == nil {
t.Fatal("expected error")
}
}
func TestLoadAgentKeys_emptyNodes(t *testing.T) {
mock := &mockClient{}
withMockClient(t, mock)
if err := LoadAgentKeys(nil); err != nil {
t.Fatalf("expected no error for empty nodes, got: %v", err)
}
}