mirror of
https://github.com/DeBrosOfficial/orama.git
synced 2026-03-27 22:24:13 +00:00
- replace `rw vault ssh` calls with `rwagent.Client` in PrepareNodeKeys, LoadAgentKeys, EnsureVaultEntry, ResolveVaultPublicKey - add vaultClient interface, newClient func, and wrapAgentError for testability and improved error messages - prefer pre-built systemd dir in installNamespaceTemplates
377 lines
10 KiB
Go
377 lines
10 KiB
Go
package remotessh
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"os"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/DeBrosOfficial/network/pkg/inspector"
|
|
"github.com/DeBrosOfficial/network/pkg/rwagent"
|
|
)
|
|
|
|
// testPrivateKey is a fake OpenSSH private key used as the agent's response in
// the PrepareNodeKeys tests. It carries the PEM-style BEGIN/END framing, but
// the body between them is placeholder text, not real key material.
const testPrivateKey = `-----BEGIN OPENSSH PRIVATE KEY-----
fake-key-data
-----END OPENSSH PRIVATE KEY-----`
|
|
|
|
// mockClient implements vaultClient for testing.
|
|
type mockClient struct {
|
|
getSSHKey func(ctx context.Context, host, username, format string) (*rwagent.VaultSSHData, error)
|
|
createSSHEntry func(ctx context.Context, host, username string) (*rwagent.VaultSSHData, error)
|
|
}
|
|
|
|
// GetSSHKey forwards to the getSSHKey hook.
//
// Reaching this method with no hook configured means the code under test made
// an agent call the test did not expect. Rather than crashing on a bare nil
// function dereference, it panics with a message that names the call and its
// arguments, so the failure is readable.
func (m *mockClient) GetSSHKey(ctx context.Context, host, username, format string) (*rwagent.VaultSSHData, error) {
	if m.getSSHKey == nil {
		panic("mockClient: unexpected GetSSHKey(" + host + ", " + username + ", " + format + ") call: getSSHKey hook not set")
	}
	return m.getSSHKey(ctx, host, username, format)
}
|
|
|
|
// CreateSSHEntry forwards to the createSSHEntry hook.
//
// Reaching this method with no hook configured means the code under test tried
// to create a vault entry the test did not expect. Rather than crashing on a
// bare nil function dereference, it panics with a message naming the call.
func (m *mockClient) CreateSSHEntry(ctx context.Context, host, username string) (*rwagent.VaultSSHData, error) {
	if m.createSSHEntry == nil {
		panic("mockClient: unexpected CreateSSHEntry(" + host + ", " + username + ") call: createSSHEntry hook not set")
	}
	return m.createSSHEntry(ctx, host, username)
}
|
|
|
|
// withMockClient points the package-level newClient factory at mock for the
// rest of test t. Any code that obtains its client through newClient then
// receives mock. The previous factory is restored via t.Cleanup.
//
// Because this mutates shared package state, tests that call it must not run
// in parallel with each other.
func withMockClient(t *testing.T, mock *mockClient) {
	t.Helper()
	saved := newClient
	t.Cleanup(func() {
		newClient = saved
	})
	newClient = func() vaultClient {
		return mock
	}
}
|
|
|
|
// TestParseVaultTarget covers splitting a vault target into host and user:
//   - the first '/' is the separator, so "a/b/c" yields user "b/c";
//   - a target with no '/' is entirely host;
//   - an empty target yields two empty strings.
func TestParseVaultTarget(t *testing.T) {
	cases := []struct {
		target, host, user string
	}{
		{target: "sandbox/root", host: "sandbox", user: "root"},
		{target: "192.168.1.1/ubuntu", host: "192.168.1.1", user: "ubuntu"},
		{target: "my-host/my-user", host: "my-host", user: "my-user"},
		{target: "noslash", host: "noslash", user: ""},
		{target: "a/b/c", host: "a", user: "b/c"},
		{target: "", host: "", user: ""},
	}

	for _, c := range cases {
		t.Run(c.target, func(t *testing.T) {
			gotHost, gotUser := parseVaultTarget(c.target)
			if gotHost != c.host {
				t.Errorf("parseVaultTarget(%q) host = %q, want %q", c.target, gotHost, c.host)
			}
			if gotUser != c.user {
				t.Errorf("parseVaultTarget(%q) user = %q, want %q", c.target, gotUser, c.user)
			}
		})
	}
}
|
|
|
|
// TestWrapAgentError_notRunning checks that rwagent.ErrAgentNotRunning is
// turned into a message that says the agent is not running and points the
// user at "rw agent start".
func TestWrapAgentError_notRunning(t *testing.T) {
	err := wrapAgentError(rwagent.ErrAgentNotRunning, "test action")
	// Guard before calling err.Error(). Otherwise a nil result would crash the
	// test with a nil-interface panic instead of reporting a failure.
	if err == nil {
		t.Fatal("wrapAgentError returned nil for ErrAgentNotRunning")
	}
	msg := err.Error()
	if !strings.Contains(msg, "not running") {
		t.Errorf("expected 'not running' message, got: %s", err)
	}
	if !strings.Contains(msg, "rw agent start") {
		t.Errorf("expected actionable hint, got: %s", err)
	}
}
|
|
|
|
// TestWrapAgentError_locked checks that an AGENT_LOCKED AgentError is turned
// into a message that mentions the lock and points the user at
// "rw agent unlock".
func TestWrapAgentError_locked(t *testing.T) {
	agentErr := &rwagent.AgentError{Code: "AGENT_LOCKED", Message: "agent is locked"}
	err := wrapAgentError(agentErr, "test action")
	// Guard before calling err.Error(). Otherwise a nil result would crash the
	// test with a nil-interface panic instead of reporting a failure.
	if err == nil {
		t.Fatal("wrapAgentError returned nil for an AGENT_LOCKED error")
	}
	msg := err.Error()
	if !strings.Contains(msg, "locked") {
		t.Errorf("expected 'locked' message, got: %s", err)
	}
	if !strings.Contains(msg, "rw agent unlock") {
		t.Errorf("expected actionable hint, got: %s", err)
	}
}
|
|
|
|
// TestWrapAgentError_generic checks that an error with no special handling
// keeps both the action description and the original error text in the
// wrapped message.
func TestWrapAgentError_generic(t *testing.T) {
	err := wrapAgentError(errors.New("some error"), "test action")
	// Guard before calling err.Error(). Otherwise a nil result would crash the
	// test with a nil-interface panic instead of reporting a failure.
	if err == nil {
		t.Fatal("wrapAgentError returned nil for a generic error")
	}
	msg := err.Error()
	if !strings.Contains(msg, "test action") {
		t.Errorf("expected action context, got: %s", err)
	}
	if !strings.Contains(msg, "some error") {
		t.Errorf("expected wrapped error, got: %s", err)
	}
}
|
|
|
|
// TestPrepareNodeKeys_success checks that after a successful PrepareNodeKeys
// call every node has a non-empty SSHKey path. Each path must point to a
// readable file holding the private key served by the (mocked) agent.
func TestPrepareNodeKeys_success(t *testing.T) {
	withMockClient(t, &mockClient{
		getSSHKey: func(context.Context, string, string, string) (*rwagent.VaultSSHData, error) {
			return &rwagent.VaultSSHData{PrivateKey: testPrivateKey}, nil
		},
	})

	nodes := []inspector.Node{
		{Host: "10.0.0.1", User: "root"},
		{Host: "10.0.0.2", User: "root"},
	}

	cleanup, err := PrepareNodeKeys(nodes)
	if err != nil {
		t.Fatalf("PrepareNodeKeys() error = %v", err)
	}
	defer cleanup()

	for i := range nodes {
		keyPath := nodes[i].SSHKey
		if keyPath == "" {
			t.Errorf("node[%d].SSHKey is empty", i)
			continue
		}
		contents, readErr := os.ReadFile(keyPath)
		switch {
		case readErr != nil:
			t.Errorf("node[%d] key file unreadable: %v", i, readErr)
		case !strings.Contains(string(contents), "BEGIN OPENSSH PRIVATE KEY"):
			t.Errorf("node[%d] key file has wrong content", i)
		}
	}
}
|
|
|
|
func TestPrepareNodeKeys_deduplication(t *testing.T) {
|
|
callCount := 0
|
|
mock := &mockClient{
|
|
getSSHKey: func(_ context.Context, host, username, format string) (*rwagent.VaultSSHData, error) {
|
|
callCount++
|
|
return &rwagent.VaultSSHData{PrivateKey: testPrivateKey}, nil
|
|
},
|
|
}
|
|
withMockClient(t, mock)
|
|
|
|
nodes := []inspector.Node{
|
|
{Host: "10.0.0.1", User: "root"},
|
|
{Host: "10.0.0.1", User: "root"}, // same host/user
|
|
}
|
|
|
|
cleanup, err := PrepareNodeKeys(nodes)
|
|
if err != nil {
|
|
t.Fatalf("PrepareNodeKeys() error = %v", err)
|
|
}
|
|
defer cleanup()
|
|
|
|
if callCount != 1 {
|
|
t.Errorf("expected 1 agent call (dedup), got %d", callCount)
|
|
}
|
|
if nodes[0].SSHKey != nodes[1].SSHKey {
|
|
t.Error("expected same key path for deduplicated nodes")
|
|
}
|
|
}
|
|
|
|
// TestPrepareNodeKeys_vaultTarget checks that when a node sets VaultTarget,
// the agent is asked for the host/user parsed from it ("sandbox"/"admin")
// rather than the node's own Host/User.
func TestPrepareNodeKeys_vaultTarget(t *testing.T) {
	var gotHost, gotUser string
	withMockClient(t, &mockClient{
		getSSHKey: func(_ context.Context, host, username, _ string) (*rwagent.VaultSSHData, error) {
			gotHost, gotUser = host, username
			return &rwagent.VaultSSHData{PrivateKey: testPrivateKey}, nil
		},
	})

	nodes := []inspector.Node{
		{Host: "10.0.0.1", User: "root", VaultTarget: "sandbox/admin"},
	}

	cleanup, err := PrepareNodeKeys(nodes)
	if err != nil {
		t.Fatalf("PrepareNodeKeys() error = %v", err)
	}
	defer cleanup()

	if !(gotHost == "sandbox" && gotUser == "admin") {
		t.Errorf("expected host=sandbox user=admin, got host=%s user=%s", gotHost, gotUser)
	}
}
|
|
|
|
// TestPrepareNodeKeys_agentNotRunning checks that PrepareNodeKeys fails with a
// "not running" message when the agent lookup returns ErrAgentNotRunning.
func TestPrepareNodeKeys_agentNotRunning(t *testing.T) {
	withMockClient(t, &mockClient{
		getSSHKey: func(context.Context, string, string, string) (*rwagent.VaultSSHData, error) {
			return nil, rwagent.ErrAgentNotRunning
		},
	})

	_, err := PrepareNodeKeys([]inspector.Node{{Host: "10.0.0.1", User: "root"}})
	switch {
	case err == nil:
		t.Fatal("expected error")
	case !strings.Contains(err.Error(), "not running"):
		t.Errorf("expected 'not running' error, got: %s", err)
	}
}
|
|
|
|
// TestPrepareNodeKeys_invalidKey checks that a private key without the
// expected shape (here the literal "garbage") is rejected with an
// "invalid key" error.
func TestPrepareNodeKeys_invalidKey(t *testing.T) {
	withMockClient(t, &mockClient{
		getSSHKey: func(context.Context, string, string, string) (*rwagent.VaultSSHData, error) {
			return &rwagent.VaultSSHData{PrivateKey: "garbage"}, nil
		},
	})

	_, err := PrepareNodeKeys([]inspector.Node{{Host: "10.0.0.1", User: "root"}})
	switch {
	case err == nil:
		t.Fatal("expected error for invalid key")
	case !strings.Contains(err.Error(), "invalid key"):
		t.Errorf("expected 'invalid key' error, got: %s", err)
	}
}
|
|
|
|
// TestPrepareNodeKeys_cleanupOnError makes the second agent lookup fail with
// AGENT_LOCKED. It then checks that a key file already written for nodes[0]
// does not survive the failed PrepareNodeKeys call.
func TestPrepareNodeKeys_cleanupOnError(t *testing.T) {
	calls := 0
	withMockClient(t, &mockClient{
		getSSHKey: func(context.Context, string, string, string) (*rwagent.VaultSSHData, error) {
			calls++
			if calls != 2 {
				return &rwagent.VaultSSHData{PrivateKey: testPrivateKey}, nil
			}
			return nil, &rwagent.AgentError{Code: "AGENT_LOCKED", Message: "locked"}
		},
	})

	nodes := []inspector.Node{
		{Host: "10.0.0.1", User: "root"},
		{Host: "10.0.0.2", User: "root"},
	}

	if _, err := PrepareNodeKeys(nodes); err == nil {
		t.Fatal("expected error")
	}

	// NOTE(review): this check only runs when nodes[0].SSHKey is still set
	// after the failure. If PrepareNodeKeys clears the path (or never set it),
	// the leak check is silently skipped. It also assumes nodes are processed
	// in order, so nodes[0] is the one whose lookup succeeded — TODO confirm.
	leftover := nodes[0].SSHKey
	if leftover == "" {
		return
	}
	if _, statErr := os.Stat(leftover); statErr == nil {
		t.Error("expected temp key file to be cleaned up on error")
	}
}
|
|
|
|
// TestPrepareNodeKeys_emptyNodes checks three things for a nil node list:
//   - PrepareNodeKeys succeeds;
//   - it makes no agent calls;
//   - it still returns a cleanup function that is safe to call.
func TestPrepareNodeKeys_emptyNodes(t *testing.T) {
	withMockClient(t, &mockClient{
		// No nodes means no keys to fetch, so any agent call is a bug. Report
		// it explicitly instead of relying on a nil-hook panic.
		getSSHKey: func(_ context.Context, host, username, _ string) (*rwagent.VaultSSHData, error) {
			t.Errorf("unexpected GetSSHKey(%q, %q) for empty node list", host, username)
			return nil, errors.New("unexpected agent call")
		},
		createSSHEntry: func(_ context.Context, host, username string) (*rwagent.VaultSSHData, error) {
			t.Errorf("unexpected CreateSSHEntry(%q, %q) for empty node list", host, username)
			return nil, errors.New("unexpected agent call")
		},
	})

	cleanup, err := PrepareNodeKeys(nil)
	if err != nil {
		t.Fatalf("expected no error for empty nodes, got: %v", err)
	}
	// Calling a nil func would panic with an opaque message; fail clearly instead.
	if cleanup == nil {
		t.Fatal("expected a non-nil cleanup function for empty nodes")
	}
	cleanup() // should not panic
}
|
|
|
|
// TestEnsureVaultEntry_exists checks that EnsureVaultEntry succeeds without
// creating anything when the agent already returns an entry for the target.
func TestEnsureVaultEntry_exists(t *testing.T) {
	mock := &mockClient{
		getSSHKey: func(_ context.Context, _, _, _ string) (*rwagent.VaultSSHData, error) {
			return &rwagent.VaultSSHData{PublicKey: "ssh-ed25519 AAAA..."}, nil
		},
		// The entry already exists, so creating one would be wrong. Report it
		// explicitly instead of relying on a nil-hook panic.
		createSSHEntry: func(_ context.Context, host, username string) (*rwagent.VaultSSHData, error) {
			t.Errorf("CreateSSHEntry(%q, %q) called for an existing entry", host, username)
			return nil, errors.New("unexpected CreateSSHEntry call")
		},
	}
	withMockClient(t, mock)

	if err := EnsureVaultEntry("sandbox/root"); err != nil {
		t.Fatalf("EnsureVaultEntry() error = %v", err)
	}
}
|
|
|
|
func TestEnsureVaultEntry_creates(t *testing.T) {
|
|
created := false
|
|
mock := &mockClient{
|
|
getSSHKey: func(_ context.Context, _, _, _ string) (*rwagent.VaultSSHData, error) {
|
|
return nil, &rwagent.AgentError{Code: "NOT_FOUND", Message: "not found"}
|
|
},
|
|
createSSHEntry: func(_ context.Context, host, username string) (*rwagent.VaultSSHData, error) {
|
|
created = true
|
|
if host != "sandbox" || username != "root" {
|
|
t.Errorf("unexpected create args: %s/%s", host, username)
|
|
}
|
|
return &rwagent.VaultSSHData{PublicKey: "ssh-ed25519 AAAA..."}, nil
|
|
},
|
|
}
|
|
withMockClient(t, mock)
|
|
|
|
if err := EnsureVaultEntry("sandbox/root"); err != nil {
|
|
t.Fatalf("EnsureVaultEntry() error = %v", err)
|
|
}
|
|
if !created {
|
|
t.Error("expected CreateSSHEntry to be called")
|
|
}
|
|
}
|
|
|
|
// TestEnsureVaultEntry_locked checks that when the agent reports AGENT_LOCKED,
// EnsureVaultEntry returns an error mentioning "locked". It must not fall
// through to creating an entry.
func TestEnsureVaultEntry_locked(t *testing.T) {
	mock := &mockClient{
		getSSHKey: func(_ context.Context, _, _, _ string) (*rwagent.VaultSSHData, error) {
			return nil, &rwagent.AgentError{Code: "AGENT_LOCKED", Message: "locked"}
		},
		// A locked agent is not a missing entry, so creation must not happen.
		// Report it explicitly instead of relying on a nil-hook panic.
		createSSHEntry: func(_ context.Context, host, username string) (*rwagent.VaultSSHData, error) {
			t.Errorf("CreateSSHEntry(%q, %q) called while agent is locked", host, username)
			return nil, errors.New("unexpected CreateSSHEntry call")
		},
	}
	withMockClient(t, mock)

	err := EnsureVaultEntry("sandbox/root")
	if err == nil {
		t.Fatal("expected error")
	}
	if !strings.Contains(err.Error(), "locked") {
		t.Errorf("expected locked error, got: %s", err)
	}
}
|
|
|
|
// TestResolveVaultPublicKey_success checks that ResolveVaultPublicKey asks the
// agent for the "pub" format and returns a key that starts with "ssh-".
func TestResolveVaultPublicKey_success(t *testing.T) {
	withMockClient(t, &mockClient{
		getSSHKey: func(_ context.Context, _, _, format string) (*rwagent.VaultSSHData, error) {
			if format != "pub" {
				t.Errorf("expected format=pub, got %s", format)
			}
			return &rwagent.VaultSSHData{PublicKey: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAA..."}, nil
		},
	})

	key, err := ResolveVaultPublicKey("sandbox/root")
	if err != nil {
		t.Fatalf("ResolveVaultPublicKey() error = %v", err)
	}
	if hasSSHPrefix := strings.HasPrefix(key, "ssh-"); !hasSSHPrefix {
		t.Errorf("expected ssh- prefix, got: %s", key)
	}
}
|
|
|
|
// TestResolveVaultPublicKey_invalidFormat checks that a malformed public key
// (here "not-a-valid-key") is rejected with an "invalid public key" error.
func TestResolveVaultPublicKey_invalidFormat(t *testing.T) {
	withMockClient(t, &mockClient{
		getSSHKey: func(context.Context, string, string, string) (*rwagent.VaultSSHData, error) {
			return &rwagent.VaultSSHData{PublicKey: "not-a-valid-key"}, nil
		},
	})

	_, err := ResolveVaultPublicKey("sandbox/root")
	switch {
	case err == nil:
		t.Fatal("expected error for invalid public key")
	case !strings.Contains(err.Error(), "invalid public key"):
		t.Errorf("expected 'invalid public key' error, got: %s", err)
	}
}
|
|
|
|
// TestResolveVaultPublicKey_notFound checks that ResolveVaultPublicKey returns
// an error when the agent reports NOT_FOUND for the target.
func TestResolveVaultPublicKey_notFound(t *testing.T) {
	withMockClient(t, &mockClient{
		getSSHKey: func(context.Context, string, string, string) (*rwagent.VaultSSHData, error) {
			return nil, &rwagent.AgentError{Code: "NOT_FOUND", Message: "not found"}
		},
	})

	if _, err := ResolveVaultPublicKey("sandbox/root"); err == nil {
		t.Fatal("expected error")
	}
}
|
|
|
|
// TestLoadAgentKeys_emptyNodes checks that LoadAgentKeys accepts a nil node
// list without error and without making any vault calls.
func TestLoadAgentKeys_emptyNodes(t *testing.T) {
	withMockClient(t, &mockClient{
		// No nodes means nothing to look up, so any vault call is a bug. Report
		// it explicitly instead of relying on a nil-hook panic.
		getSSHKey: func(_ context.Context, host, username, _ string) (*rwagent.VaultSSHData, error) {
			t.Errorf("unexpected GetSSHKey(%q, %q) for empty node list", host, username)
			return nil, errors.New("unexpected agent call")
		},
		createSSHEntry: func(_ context.Context, host, username string) (*rwagent.VaultSSHData, error) {
			t.Errorf("unexpected CreateSSHEntry(%q, %q) for empty node list", host, username)
			return nil, errors.New("unexpected agent call")
		},
	})

	if err := LoadAgentKeys(nil); err != nil {
		t.Fatalf("expected no error for empty nodes, got: %v", err)
	}
}
|