diff --git a/Makefile b/Makefile index e791a23..5effa04 100644 --- a/Makefile +++ b/Makefile @@ -22,7 +22,7 @@ check-gateway: echo " 3. Run tests: make test-e2e-local"; \ echo ""; \ echo "To run tests against production:"; \ - echo " ORAMA_GATEWAY_URL=http://VPS-IP:6001 make test-e2e"; \ + echo " ORAMA_GATEWAY_URL=https://dbrs.space make test-e2e"; \ exit 1; \ fi @echo "✅ Gateway is running" @@ -36,7 +36,7 @@ test-e2e-local: check-gateway test-e2e-prod: @if [ -z "$$ORAMA_GATEWAY_URL" ]; then \ echo "❌ ORAMA_GATEWAY_URL not set"; \ - echo "Usage: ORAMA_GATEWAY_URL=http://VPS-IP:6001 make test-e2e-prod"; \ + echo "Usage: ORAMA_GATEWAY_URL=https://dbrs.space make test-e2e-prod"; \ exit 1; \ fi @echo "Running E2E tests (including production-only) against $$ORAMA_GATEWAY_URL..." @@ -182,7 +182,7 @@ help: @echo " make test-e2e - Generic E2E tests (auto-discovers config)" @echo "" @echo " Example production test:" - @echo " ORAMA_GATEWAY_URL=http://141.227.165.168:6001 make test-e2e-prod" + @echo " ORAMA_GATEWAY_URL=https://dbrs.space make test-e2e-prod" @echo "" @echo "Development Management (via orama):" @echo " ./bin/orama dev status - Show status of all dev services" diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 49bad75..7cbcf53 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -53,6 +53,8 @@ func main() { cli.HandleProdCommand(args) // Direct production commands (new simplified interface) + case "invite": + cli.HandleProdCommand(append([]string{"invite"}, args...)) case "install": cli.HandleProdCommand(append([]string{"install"}, args...)) case "upgrade": diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md index 80fbc45..0e8a04e 100644 --- a/docs/ARCHITECTURE.md +++ b/docs/ARCHITECTURE.md @@ -353,12 +353,22 @@ Function Invocation: - Refresh token support - Claims-based authorization +### Network Security (WireGuard Mesh) + +All inter-node communication is encrypted via a WireGuard VPN mesh: + +- **WireGuard IPs:** Each node gets a private IP 
(10.0.0.x) used for all cluster traffic
+- **UFW Firewall:** Only public ports are exposed: 22 (SSH), 53 (DNS, nameservers only), 80/443 (HTTP/HTTPS), 51820 (WireGuard UDP)
+- **Internal services** (RQLite 5001/7001, IPFS 4001/4501, Olric 3320/3322, Gateway 6001) are only accessible via WireGuard or localhost
+- **Invite tokens:** Single-use, time-limited tokens for secure node joining. No shared secrets on the CLI
+- **Join flow:** New nodes authenticate via HTTPS (443), establish WireGuard tunnel, then join all services over the encrypted mesh
+
 ### TLS/HTTPS
 
-- Automatic ACME (Let's Encrypt) certificate management
+- Automatic ACME (Let's Encrypt) certificate management via Caddy
 - TLS 1.3 support
 - HTTP/2 enabled
-- Certificate caching
+- On-demand TLS for deployment custom domains
 
 ### Middleware Stack
 
@@ -441,17 +451,25 @@ make test-e2e # Run E2E tests
 ### Production
 
 ```bash
-# First node (creates cluster)
-sudo orama install --vps-ip <VPS_IP> --domain node1.example.com
+# First node (genesis — creates cluster)
+sudo orama install --vps-ip <VPS_IP> --domain node1.example.com --nameserver
 
-# Additional nodes (join cluster)
-sudo orama install --vps-ip <VPS_IP> --domain node2.example.com \
-  --peers /dns4/node1.example.com/tcp/4001/p2p/<PEER_ID> \
-  --join <GENESIS_IP>:7002 \
-  --cluster-secret <CLUSTER_SECRET> \
-  --swarm-key <SWARM_KEY>
+# On the genesis node, generate an invite for a new node
+orama invite
+# Outputs: sudo orama install --join https://node1.example.com --token <TOKEN> --vps-ip <NEW_NODE_IP>
+
+# Additional nodes (join via invite token over HTTPS)
+sudo orama install --join https://node1.example.com --token <TOKEN> \
+  --vps-ip <NEW_NODE_IP> --nameserver
 ```
 
+**Security:** Nodes join via single-use invite tokens over HTTPS. A WireGuard VPN tunnel
+is established before any cluster services start. All inter-node traffic (RQLite, IPFS,
+Olric, LibP2P) flows over the encrypted WireGuard mesh — no cluster ports are exposed
+publicly. **Never use `http://<IP>:6001`** for joining — port 6001 is internal-only and
+blocked by UFW. 
Use the domain (`https://node1.example.com`) or, if DNS is not yet
+configured, use the IP over HTTP port 80 (`http://<IP>`) which goes through Caddy.
+
 ### Docker (Future)
 
 Planned containerization with Docker Compose and Kubernetes support.
diff --git a/docs/DEV_DEPLOY.md b/docs/DEV_DEPLOY.md
index dfd4f82..9606988 100644
--- a/docs/DEV_DEPLOY.md
+++ b/docs/DEV_DEPLOY.md
@@ -95,14 +95,74 @@ To deploy to all nodes, repeat steps 3-5 (dev) or 3-4 (production) for each VPS
 
 ### CLI Flags Reference
 
+#### `orama install`
+
 | Flag | Description |
 |------|-------------|
-| `--branch <name>` | Git branch to pull from (production deployment) |
-| `--no-pull` | Skip git pull, use existing `/home/debros/src` (dev deployment) |
+| `--vps-ip <ip>` | VPS public IP address (required) |
+| `--domain <domain>` | Domain for HTTPS certificates |
+| `--base-domain <domain>` | Base domain for deployment routing (e.g., dbrs.space) |
+| `--nameserver` | Configure this node as a nameserver (CoreDNS + Caddy) |
+| `--join <url>` | Join existing cluster via HTTPS URL (e.g., `https://node1.dbrs.space`) |
+| `--token <token>` | Invite token for joining (from `orama invite` on existing node) |
+| `--branch <name>` | Git branch to use (default: main) |
+| `--no-pull` | Skip git clone/pull, use existing `/home/debros/src` |
+| `--force` | Force reconfiguration even if already installed |
+| `--skip-firewall` | Skip UFW firewall setup |
+| `--skip-checks` | Skip minimum resource checks (RAM/CPU) |
+
+#### `orama invite`
+
+| Flag | Description |
+|------|-------------|
+| `--expiry <duration>` | Token expiry duration (default: 1h) |
+
+#### `orama upgrade`
+
+| Flag | Description |
+|------|-------------|
+| `--branch <name>` | Git branch to pull from |
+| `--no-pull` | Skip git pull, use existing source |
 | `--restart` | Restart all services after upgrade |
-| `--nameserver` | Configure this node as a nameserver (install only) |
-| `--domain <domain>` | Domain for HTTPS certificates (install only) |
-| `--vps-ip <ip>` | VPS public IP address (install only) |
+
+### Node 
Join Flow
+
+```bash
+# 1. Genesis node (first node, creates cluster)
+sudo orama install --vps-ip 1.2.3.4 --domain node1.dbrs.space \
+  --base-domain dbrs.space --nameserver
+
+# 2. On genesis node, generate an invite
+orama invite
+# Output: sudo orama install --join https://node1.dbrs.space --token <TOKEN> --vps-ip <NEW_NODE_IP>
+
+# 3. On the new node, run the printed command
+sudo orama install --join https://node1.dbrs.space --token abc123... \
+  --vps-ip 5.6.7.8 --nameserver
+```
+
+The join flow establishes a WireGuard VPN tunnel before starting cluster services.
+All inter-node communication (RQLite, IPFS, Olric) uses WireGuard IPs (10.0.0.x).
+No cluster ports are ever exposed publicly.
+
+#### DNS Prerequisite
+
+The `--join` URL should use the HTTPS domain of the genesis node (e.g., `https://node1.dbrs.space`).
+For this to work, the domain registrar for `dbrs.space` must have NS records pointing to the genesis
+node's IP so that `node1.dbrs.space` resolves publicly.
+
+**If DNS is not yet configured**, you can use the genesis node's public IP with HTTP as a fallback:
+
+```bash
+sudo orama install --join http://1.2.3.4 --vps-ip 5.6.7.8 --token abc123... --nameserver
+```
+
+This works because Caddy's `:80` block proxies all HTTP traffic to the gateway. However, once DNS
+is properly configured, always use the HTTPS domain URL.
+
+**Important:** Never use `http://<IP>:6001` — port 6001 is the internal gateway and is blocked by
+UFW from external access. The join request goes through Caddy on port 80 (HTTP) or 443 (HTTPS),
+which proxies to the gateway internally. 
## Debugging Production Issues diff --git a/e2e/cluster/namespace_cluster_test.go b/e2e/cluster/namespace_cluster_test.go index df66414..caf2a76 100644 --- a/e2e/cluster/namespace_cluster_test.go +++ b/e2e/cluster/namespace_cluster_test.go @@ -170,9 +170,9 @@ func TestNamespaceCluster_OlricHealth(t *testing.T) { func TestNamespaceCluster_GatewayHealth(t *testing.T) { // Check if gateway binary exists gatewayBinaryPaths := []string{ - "./bin/gateway", - "../bin/gateway", - "/usr/local/bin/orama-gateway", + "./bin/orama", + "../bin/orama", + "/usr/local/bin/orama", } var gatewayBinaryExists bool diff --git a/migrations/013_wireguard_peers.sql b/migrations/013_wireguard_peers.sql new file mode 100644 index 0000000..636f210 --- /dev/null +++ b/migrations/013_wireguard_peers.sql @@ -0,0 +1,9 @@ +-- WireGuard mesh peer tracking +CREATE TABLE IF NOT EXISTS wireguard_peers ( + node_id TEXT PRIMARY KEY, + wg_ip TEXT NOT NULL UNIQUE, + public_key TEXT NOT NULL UNIQUE, + public_ip TEXT NOT NULL, + wg_port INTEGER DEFAULT 51820, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP +); diff --git a/migrations/014_invite_tokens.sql b/migrations/014_invite_tokens.sql new file mode 100644 index 0000000..9538823 --- /dev/null +++ b/migrations/014_invite_tokens.sql @@ -0,0 +1,8 @@ +CREATE TABLE IF NOT EXISTS invite_tokens ( + token TEXT PRIMARY KEY, + created_by TEXT NOT NULL, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + expires_at DATETIME NOT NULL, + used_at DATETIME, + used_by_ip TEXT +); diff --git a/pkg/cli/prod_commands_test.go b/pkg/cli/prod_commands_test.go index c67e617..007e1d1 100644 --- a/pkg/cli/prod_commands_test.go +++ b/pkg/cli/prod_commands_test.go @@ -7,42 +7,32 @@ import ( ) // TestProdCommandFlagParsing verifies that prod command flags are parsed correctly -// Note: The installer now uses --vps-ip presence to determine if it's a first node (no --bootstrap flag) -// First node: has --vps-ip but no --peers or --join -// Joining node: has --vps-ip, --peers, and 
--cluster-secret +// Genesis node: has --vps-ip but no --join or --token +// Joining node: has --vps-ip, --join (HTTPS URL), and --token (invite token) func TestProdCommandFlagParsing(t *testing.T) { tests := []struct { - name string - args []string - expectVPSIP string - expectDomain string - expectPeers string - expectJoin string - expectSecret string - expectBranch string - isFirstNode bool // first node = no peers and no join address + name string + args []string + expectVPSIP string + expectDomain string + expectJoin string + expectToken string + expectBranch string + isFirstNode bool // genesis node = no --join and no --token }{ { - name: "first node (creates new cluster)", - args: []string{"install", "--vps-ip", "10.0.0.1", "--domain", "node-1.example.com"}, - expectVPSIP: "10.0.0.1", + name: "genesis node (creates new cluster)", + args: []string{"install", "--vps-ip", "10.0.0.1", "--domain", "node-1.example.com"}, + expectVPSIP: "10.0.0.1", expectDomain: "node-1.example.com", - isFirstNode: true, + isFirstNode: true, }, { - name: "joining node with peers", - args: []string{"install", "--vps-ip", "10.0.0.2", "--peers", "/ip4/10.0.0.1/tcp/4001/p2p/Qm123", "--cluster-secret", "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"}, - expectVPSIP: "10.0.0.2", - expectPeers: "/ip4/10.0.0.1/tcp/4001/p2p/Qm123", - expectSecret: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", - isFirstNode: false, - }, - { - name: "joining node with join address", - args: []string{"install", "--vps-ip", "10.0.0.3", "--join", "10.0.0.1:7001", "--cluster-secret", "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"}, - expectVPSIP: "10.0.0.3", - expectJoin: "10.0.0.1:7001", - expectSecret: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", + name: "joining node with invite token", + args: []string{"install", "--vps-ip", "10.0.0.2", "--join", "https://node1.dbrs.space", "--token", "abc123def456"}, + expectVPSIP: 
"10.0.0.2", + expectJoin: "https://node1.dbrs.space", + expectToken: "abc123def456", isFirstNode: false, }, { @@ -56,8 +46,7 @@ func TestProdCommandFlagParsing(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // Extract flags manually to verify parsing logic - var vpsIP, domain, peersStr, joinAddr, clusterSecret, branch string + var vpsIP, domain, joinAddr, token, branch string for i, arg := range tt.args { switch arg { @@ -69,17 +58,13 @@ func TestProdCommandFlagParsing(t *testing.T) { if i+1 < len(tt.args) { domain = tt.args[i+1] } - case "--peers": - if i+1 < len(tt.args) { - peersStr = tt.args[i+1] - } case "--join": if i+1 < len(tt.args) { joinAddr = tt.args[i+1] } - case "--cluster-secret": + case "--token": if i+1 < len(tt.args) { - clusterSecret = tt.args[i+1] + token = tt.args[i+1] } case "--branch": if i+1 < len(tt.args) { @@ -88,8 +73,8 @@ func TestProdCommandFlagParsing(t *testing.T) { } } - // First node detection: no peers and no join address - isFirstNode := peersStr == "" && joinAddr == "" + // Genesis node detection: no --join and no --token + isFirstNode := joinAddr == "" && token == "" if vpsIP != tt.expectVPSIP { t.Errorf("expected vpsIP=%q, got %q", tt.expectVPSIP, vpsIP) @@ -97,14 +82,11 @@ func TestProdCommandFlagParsing(t *testing.T) { if domain != tt.expectDomain { t.Errorf("expected domain=%q, got %q", tt.expectDomain, domain) } - if peersStr != tt.expectPeers { - t.Errorf("expected peers=%q, got %q", tt.expectPeers, peersStr) - } if joinAddr != tt.expectJoin { t.Errorf("expected join=%q, got %q", tt.expectJoin, joinAddr) } - if clusterSecret != tt.expectSecret { - t.Errorf("expected clusterSecret=%q, got %q", tt.expectSecret, clusterSecret) + if token != tt.expectToken { + t.Errorf("expected token=%q, got %q", tt.expectToken, token) } if branch != tt.expectBranch { t.Errorf("expected branch=%q, got %q", tt.expectBranch, branch) diff --git a/pkg/cli/production/commands.go b/pkg/cli/production/commands.go 
index d52a0c4..dace129 100644 --- a/pkg/cli/production/commands.go +++ b/pkg/cli/production/commands.go @@ -5,6 +5,7 @@ import ( "os" "github.com/DeBrosOfficial/network/pkg/cli/production/install" + "github.com/DeBrosOfficial/network/pkg/cli/production/invite" "github.com/DeBrosOfficial/network/pkg/cli/production/lifecycle" "github.com/DeBrosOfficial/network/pkg/cli/production/logs" "github.com/DeBrosOfficial/network/pkg/cli/production/migrate" @@ -24,6 +25,8 @@ func HandleCommand(args []string) { subargs := args[1:] switch subcommand { + case "invite": + invite.Handle(subargs) case "install": install.Handle(subargs) case "upgrade": diff --git a/pkg/cli/production/install/flags.go b/pkg/cli/production/install/flags.go index 6fb0b90..42b344a 100644 --- a/pkg/cli/production/install/flags.go +++ b/pkg/cli/production/install/flags.go @@ -17,10 +17,11 @@ type Flags struct { DryRun bool SkipChecks bool Nameserver bool // Make this node a nameserver (runs CoreDNS + Caddy) - JoinAddress string - ClusterSecret string - SwarmKey string - PeersStr string + JoinAddress string // HTTPS URL of existing node (e.g., https://node1.dbrs.space) + Token string // Invite token for joining (from orama invite) + ClusterSecret string // Deprecated: use --token instead + SwarmKey string // Deprecated: use --token instead + PeersStr string // Deprecated: use --token instead // IPFS/Cluster specific info for Peering configuration IPFSPeerID string @@ -28,6 +29,9 @@ type Flags struct { IPFSClusterPeerID string IPFSClusterAddrs string + // Security flags + SkipFirewall bool // Skip UFW firewall setup (for users who manage their own firewall) + // Anyone relay operator flags AnyoneRelay bool // Run as relay operator instead of client AnyoneExit bool // Run as exit relay (legal implications) @@ -57,9 +61,10 @@ func ParseFlags(args []string) (*Flags, error) { fs.BoolVar(&flags.Nameserver, "nameserver", false, "Make this node a nameserver (runs CoreDNS + Caddy)") // Cluster join flags - 
fs.StringVar(&flags.JoinAddress, "join", "", "Join an existing cluster (e.g. 1.2.3.4:7001)") - fs.StringVar(&flags.ClusterSecret, "cluster-secret", "", "Cluster secret for IPFS Cluster (required if joining)") - fs.StringVar(&flags.SwarmKey, "swarm-key", "", "IPFS Swarm key hex (64 chars, last line of swarm.key)") + fs.StringVar(&flags.JoinAddress, "join", "", "Join existing cluster via HTTPS URL (e.g. https://node1.dbrs.space)") + fs.StringVar(&flags.Token, "token", "", "Invite token for joining (from orama invite on existing node)") + fs.StringVar(&flags.ClusterSecret, "cluster-secret", "", "Deprecated: use --token instead") + fs.StringVar(&flags.SwarmKey, "swarm-key", "", "Deprecated: use --token instead") fs.StringVar(&flags.PeersStr, "peers", "", "Comma-separated list of bootstrap peer multiaddrs") // IPFS/Cluster specific info for Peering configuration @@ -68,6 +73,9 @@ func ParseFlags(args []string) (*Flags, error) { fs.StringVar(&flags.IPFSClusterPeerID, "ipfs-cluster-peer", "", "Peer ID of existing IPFS Cluster node") fs.StringVar(&flags.IPFSClusterAddrs, "ipfs-cluster-addrs", "", "Comma-separated multiaddrs of existing IPFS Cluster node") + // Security flags + fs.BoolVar(&flags.SkipFirewall, "skip-firewall", false, "Skip UFW firewall setup (for users who manage their own firewall)") + // Anyone relay operator flags fs.BoolVar(&flags.AnyoneRelay, "anyone-relay", false, "Run as Anyone relay operator (earn rewards)") fs.BoolVar(&flags.AnyoneExit, "anyone-exit", false, "Run as exit relay (requires --anyone-relay, legal implications)") diff --git a/pkg/cli/production/install/orchestrator.go b/pkg/cli/production/install/orchestrator.go index f0d1132..e383c94 100644 --- a/pkg/cli/production/install/orchestrator.go +++ b/pkg/cli/production/install/orchestrator.go @@ -2,14 +2,20 @@ package install import ( "bufio" + "crypto/tls" + "encoding/json" "fmt" + "io" + "net/http" "os" + "os/exec" "path/filepath" "strings" "time" 
"github.com/DeBrosOfficial/network/pkg/cli/utils" "github.com/DeBrosOfficial/network/pkg/environments/production" + joinhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/join" ) // Orchestrator manages the install process @@ -97,9 +103,11 @@ func (o *Orchestrator) Execute() error { return nil } - // Save secrets before installation - if err := o.validator.SaveSecrets(); err != nil { - return err + // Save secrets before installation (only for genesis; join flow gets secrets from response) + if !o.isJoiningNode() { + if err := o.validator.SaveSecrets(); err != nil { + return err + } } // Save preferences for future upgrades (branch + nameserver) @@ -132,33 +140,56 @@ func (o *Orchestrator) Execute() error { return fmt.Errorf("binary installation failed: %w", err) } - // Phase 3: Generate secrets FIRST (before service initialization) + // Branch: genesis node vs joining node + if o.isJoiningNode() { + return o.executeJoinFlow() + } + return o.executeGenesisFlow() +} + +// isJoiningNode returns true if --join and --token are both set +func (o *Orchestrator) isJoiningNode() bool { + return o.flags.JoinAddress != "" && o.flags.Token != "" +} + +// executeGenesisFlow runs the install for the first node in a new cluster +func (o *Orchestrator) executeGenesisFlow() error { + // Phase 3: Generate secrets locally fmt.Printf("\n🔐 Phase 3: Generating secrets...\n") if err := o.setup.Phase3GenerateSecrets(); err != nil { return fmt.Errorf("secret generation failed: %w", err) } - // Phase 4: Generate configs (BEFORE service initialization) + // Phase 6a: WireGuard — self-assign 10.0.0.1 + fmt.Printf("\n🔒 Phase 6a: Setting up WireGuard mesh VPN...\n") + if _, _, err := o.setup.Phase6SetupWireGuard(true); err != nil { + fmt.Fprintf(os.Stderr, " ⚠️ Warning: WireGuard setup failed: %v\n", err) + } else { + fmt.Printf(" ✓ WireGuard configured (10.0.0.1)\n") + } + + // Phase 6b: UFW firewall + fmt.Printf("\n🛡️ Phase 6b: Setting up UFW firewall...\n") + if err := 
o.setup.Phase6bSetupFirewall(o.flags.SkipFirewall); err != nil { + fmt.Fprintf(os.Stderr, " ⚠️ Warning: Firewall setup failed: %v\n", err) + } + + // Phase 4: Generate configs using WG IP (10.0.0.1) as advertise address + // All inter-node communication uses WireGuard IPs, not public IPs fmt.Printf("\n⚙️ Phase 4: Generating configurations...\n") - // Internal gateway always runs HTTP on port 6001 - // When using Caddy (nameserver mode), Caddy handles external HTTPS and proxies to internal gateway - // When not using Caddy, the gateway runs HTTP-only (use a reverse proxy for HTTPS) enableHTTPS := false - if err := o.setup.Phase4GenerateConfigs(o.peers, o.flags.VpsIP, enableHTTPS, o.flags.Domain, o.flags.BaseDomain, o.flags.JoinAddress); err != nil { + genesisWGIP := "10.0.0.1" + if err := o.setup.Phase4GenerateConfigs(o.peers, genesisWGIP, enableHTTPS, o.flags.Domain, o.flags.BaseDomain, ""); err != nil { return fmt.Errorf("configuration generation failed: %w", err) } - // Validate generated configuration if err := o.validator.ValidateGeneratedConfig(); err != nil { return err } - // Phase 2c: Initialize services (after config is in place) + // Phase 2c: Initialize services (use WG IP for IPFS Cluster peer discovery) fmt.Printf("\nPhase 2c: Initializing services...\n") - ipfsPeerInfo := o.buildIPFSPeerInfo() - ipfsClusterPeerInfo := o.buildIPFSClusterPeerInfo() - - if err := o.setup.Phase2cInitializeServices(o.peers, o.flags.VpsIP, ipfsPeerInfo, ipfsClusterPeerInfo); err != nil { + if err := o.setup.Phase2cInitializeServices(o.peers, genesisWGIP, nil, nil); err != nil { return fmt.Errorf("service initialization failed: %w", err) } @@ -168,9 +199,9 @@ func (o *Orchestrator) Execute() error { return fmt.Errorf("service creation failed: %w", err) } - // Seed DNS records after services are running (RQLite must be up) + // Phase 7: Seed DNS records if o.flags.Nameserver && o.flags.BaseDomain != "" { - fmt.Printf("\n🌐 Phase 6: Seeding DNS records...\n") + fmt.Printf("\n🌐 
Phase 7: Seeding DNS records...\n") fmt.Printf(" Waiting for RQLite to start (10s)...\n") time.Sleep(10 * time.Second) if err := o.setup.SeedDNSRecords(o.flags.BaseDomain, o.flags.VpsIP, o.peers); err != nil { @@ -180,18 +211,206 @@ func (o *Orchestrator) Execute() error { } } - // Log completion with actual peer ID o.setup.LogSetupComplete(o.setup.NodePeerID) fmt.Printf("✅ Production installation complete!\n\n") + o.printFirstNodeSecrets() + return nil +} - // For first node, print important secrets and identifiers - if o.validator.IsFirstNode() { - o.printFirstNodeSecrets() +// executeJoinFlow runs the install for a node joining an existing cluster via invite token +func (o *Orchestrator) executeJoinFlow() error { + // Step 1: Generate WG keypair + fmt.Printf("\n🔑 Generating WireGuard keypair...\n") + privKey, pubKey, err := production.GenerateKeyPair() + if err != nil { + return fmt.Errorf("failed to generate WG keypair: %w", err) + } + fmt.Printf(" ✓ WireGuard keypair generated\n") + + // Step 2: Call join endpoint on existing node + fmt.Printf("\n🤝 Requesting cluster join from %s...\n", o.flags.JoinAddress) + joinResp, err := o.callJoinEndpoint(pubKey) + if err != nil { + return fmt.Errorf("join request failed: %w", err) + } + fmt.Printf(" ✓ Join approved — assigned WG IP: %s\n", joinResp.WGIP) + fmt.Printf(" ✓ Received %d WG peers\n", len(joinResp.WGPeers)) + + // Step 3: Configure WireGuard with assigned IP and peers + fmt.Printf("\n🔒 Configuring WireGuard tunnel...\n") + var wgPeers []production.WireGuardPeer + for _, p := range joinResp.WGPeers { + wgPeers = append(wgPeers, production.WireGuardPeer{ + PublicKey: p.PublicKey, + Endpoint: p.Endpoint, + AllowedIP: p.AllowedIP, + }) + } + // Install WG package first + wp := production.NewWireGuardProvisioner(production.WireGuardConfig{}) + if err := wp.Install(); err != nil { + return fmt.Errorf("failed to install wireguard: %w", err) + } + if err := o.setup.EnableWireGuardWithPeers(privKey, joinResp.WGIP, 
wgPeers); err != nil { + return fmt.Errorf("failed to enable WireGuard: %w", err) + } + + // Step 4: Verify WG tunnel + fmt.Printf("\n🔍 Verifying WireGuard tunnel...\n") + if err := o.verifyWGTunnel(joinResp.WGPeers); err != nil { + return fmt.Errorf("WireGuard tunnel verification failed: %w", err) + } + fmt.Printf(" ✓ WireGuard tunnel established\n") + + // Step 5: UFW firewall + fmt.Printf("\n🛡️ Setting up UFW firewall...\n") + if err := o.setup.Phase6bSetupFirewall(o.flags.SkipFirewall); err != nil { + fmt.Fprintf(os.Stderr, " ⚠️ Warning: Firewall setup failed: %v\n", err) + } + + // Step 6: Save secrets from join response + fmt.Printf("\n🔐 Saving cluster secrets...\n") + if err := o.saveSecretsFromJoinResponse(joinResp); err != nil { + return fmt.Errorf("failed to save secrets: %w", err) + } + fmt.Printf(" ✓ Secrets saved\n") + + // Step 7: Generate configs using WG IP as advertise address + // All inter-node communication uses WireGuard IPs, not public IPs + fmt.Printf("\n⚙️ Generating configurations...\n") + enableHTTPS := false + rqliteJoin := joinResp.RQLiteJoinAddress + if err := o.setup.Phase4GenerateConfigs(joinResp.BootstrapPeers, joinResp.WGIP, enableHTTPS, o.flags.Domain, joinResp.BaseDomain, rqliteJoin); err != nil { + return fmt.Errorf("configuration generation failed: %w", err) + } + + if err := o.validator.ValidateGeneratedConfig(); err != nil { + return err + } + + // Step 8: Initialize services with IPFS peer info from join response + fmt.Printf("\nInitializing services...\n") + var ipfsPeerInfo *production.IPFSPeerInfo + if joinResp.IPFSPeer.ID != "" { + ipfsPeerInfo = &production.IPFSPeerInfo{ + PeerID: joinResp.IPFSPeer.ID, + Addrs: joinResp.IPFSPeer.Addrs, + } + } + var ipfsClusterPeerInfo *production.IPFSClusterPeerInfo + if joinResp.IPFSClusterPeer.ID != "" { + ipfsClusterPeerInfo = &production.IPFSClusterPeerInfo{ + PeerID: joinResp.IPFSClusterPeer.ID, + Addrs: joinResp.IPFSClusterPeer.Addrs, + } + } + + if err := 
o.setup.Phase2cInitializeServices(joinResp.BootstrapPeers, joinResp.WGIP, ipfsPeerInfo, ipfsClusterPeerInfo); err != nil { + return fmt.Errorf("service initialization failed: %w", err) + } + + // Step 9: Create systemd services + fmt.Printf("\n🔧 Creating systemd services...\n") + if err := o.setup.Phase5CreateSystemdServices(enableHTTPS); err != nil { + return fmt.Errorf("service creation failed: %w", err) + } + + o.setup.LogSetupComplete(o.setup.NodePeerID) + fmt.Printf("✅ Production installation complete! Joined cluster via %s\n\n", o.flags.JoinAddress) + return nil +} + +// callJoinEndpoint sends the join request to the existing node's HTTPS endpoint +func (o *Orchestrator) callJoinEndpoint(wgPubKey string) (*joinhandlers.JoinResponse, error) { + reqBody := joinhandlers.JoinRequest{ + Token: o.flags.Token, + WGPublicKey: wgPubKey, + PublicIP: o.flags.VpsIP, + } + bodyBytes, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + url := strings.TrimRight(o.flags.JoinAddress, "/") + "/v1/internal/join" + client := &http.Client{ + Timeout: 30 * time.Second, + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, // Self-signed certs during initial setup + }, + }, + } + + resp, err := client.Post(url, "application/json", strings.NewReader(string(bodyBytes))) + if err != nil { + return nil, fmt.Errorf("failed to contact %s: %w", url, err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("join rejected (HTTP %d): %s", resp.StatusCode, string(respBody)) + } + + var joinResp joinhandlers.JoinResponse + if err := json.Unmarshal(respBody, &joinResp); err != nil { + return nil, fmt.Errorf("failed to parse join response: %w", err) + } + + return &joinResp, nil +} + +// saveSecretsFromJoinResponse writes 
cluster secrets received from the join endpoint to disk +func (o *Orchestrator) saveSecretsFromJoinResponse(resp *joinhandlers.JoinResponse) error { + secretsDir := filepath.Join(o.oramaDir, "secrets") + if err := os.MkdirAll(secretsDir, 0700); err != nil { + return fmt.Errorf("failed to create secrets dir: %w", err) + } + + // Write cluster secret + if resp.ClusterSecret != "" { + if err := os.WriteFile(filepath.Join(secretsDir, "cluster-secret"), []byte(resp.ClusterSecret), 0600); err != nil { + return fmt.Errorf("failed to write cluster-secret: %w", err) + } + } + + // Write swarm key + if resp.SwarmKey != "" { + if err := os.WriteFile(filepath.Join(secretsDir, "swarm.key"), []byte(resp.SwarmKey), 0600); err != nil { + return fmt.Errorf("failed to write swarm.key: %w", err) + } } return nil } +// verifyWGTunnel pings a WG peer to verify the tunnel is working +func (o *Orchestrator) verifyWGTunnel(peers []joinhandlers.WGPeerInfo) error { + if len(peers) == 0 { + return fmt.Errorf("no WG peers to verify") + } + + // Extract the IP from the first peer's AllowedIP (e.g. 
"10.0.0.1/32" -> "10.0.0.1") + targetIP := strings.TrimSuffix(peers[0].AllowedIP, "/32") + + // Retry ping for up to 30 seconds + deadline := time.Now().Add(30 * time.Second) + for time.Now().Before(deadline) { + cmd := exec.Command("ping", "-c", "1", "-W", "2", targetIP) + if err := cmd.Run(); err == nil { + return nil + } + time.Sleep(2 * time.Second) + } + + return fmt.Errorf("could not reach %s via WireGuard after 30s", targetIP) +} + func (o *Orchestrator) buildIPFSPeerInfo() *production.IPFSPeerInfo { if o.flags.IPFSPeerID != "" { var addrs []string diff --git a/pkg/cli/production/invite/command.go b/pkg/cli/production/invite/command.go new file mode 100644 index 0000000..57ca730 --- /dev/null +++ b/pkg/cli/production/invite/command.go @@ -0,0 +1,115 @@ +package invite + +import ( + "crypto/rand" + "encoding/hex" + "fmt" + "net/http" + "os" + "strings" + "time" + + "gopkg.in/yaml.v3" +) + +// Handle processes the invite command +func Handle(args []string) { + // Must run on a cluster node with RQLite running locally + domain, err := readNodeDomain() + if err != nil { + fmt.Fprintf(os.Stderr, "Error: could not read node config: %v\n", err) + fmt.Fprintf(os.Stderr, "Make sure you're running this on an installed node.\n") + os.Exit(1) + } + + // Generate random token + tokenBytes := make([]byte, 32) + if _, err := rand.Read(tokenBytes); err != nil { + fmt.Fprintf(os.Stderr, "Error generating token: %v\n", err) + os.Exit(1) + } + token := hex.EncodeToString(tokenBytes) + + // Determine expiry (default 1 hour, --expiry flag for override) + expiry := time.Hour + for i, arg := range args { + if arg == "--expiry" && i+1 < len(args) { + d, err := time.ParseDuration(args[i+1]) + if err != nil { + fmt.Fprintf(os.Stderr, "Invalid expiry duration: %v\n", err) + os.Exit(1) + } + expiry = d + } + } + + expiresAt := time.Now().UTC().Add(expiry).Format("2006-01-02 15:04:05") + + // Get node ID for created_by + nodeID := "unknown" + if hostname, err := os.Hostname(); err == 
nil {
+		nodeID = hostname
+	}
+
+	// Insert token into RQLite via HTTP API
+	if err := insertToken(token, nodeID, expiresAt); err != nil {
+		fmt.Fprintf(os.Stderr, "Error storing invite token: %v\n", err)
+		fmt.Fprintf(os.Stderr, "Make sure RQLite is running on this node.\n")
+		os.Exit(1)
+	}
+
+	// Print the invite command
+	fmt.Printf("\nInvite token created (expires in %s)\n\n", expiry)
+	fmt.Printf("Run this on the new node:\n\n")
+	fmt.Printf("  sudo orama install --join https://%s --token %s --vps-ip <NEW_NODE_IP> --nameserver\n\n", domain, token)
+	fmt.Printf("Replace <NEW_NODE_IP> with the new node's public IP address.\n")
+}
+
+// readNodeDomain reads the domain from the node config file
+func readNodeDomain() (string, error) {
+	configPath := "/home/debros/.orama/configs/node.yaml"
+	data, err := os.ReadFile(configPath)
+	if err != nil {
+		return "", fmt.Errorf("read config: %w", err)
+	}
+
+	var config struct {
+		Node struct {
+			Domain string `yaml:"domain"`
+		} `yaml:"node"`
+	}
+	if err := yaml.Unmarshal(data, &config); err != nil {
+		return "", fmt.Errorf("parse config: %w", err)
+	}
+
+	if config.Node.Domain == "" {
+		return "", fmt.Errorf("node domain not set in config")
+	}
+
+	return config.Node.Domain, nil
+}
+
+// insertToken inserts an invite token into RQLite via HTTP API
+func insertToken(token, createdBy, expiresAt string) error {
+	body := fmt.Sprintf(`[["INSERT INTO invite_tokens (token, created_by, expires_at) VALUES ('%s', '%s', '%s')"]]`,
+		token, createdBy, expiresAt)
+
+	req, err := http.NewRequest("POST", "http://localhost:5001/db/execute", strings.NewReader(body))
+	if err != nil {
+		return err
+	}
+	req.Header.Set("Content-Type", "application/json")
+
+	client := &http.Client{Timeout: 5 * time.Second}
+	resp, err := client.Do(req)
+	if err != nil {
+		return fmt.Errorf("failed to connect to RQLite: %w", err)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		return fmt.Errorf("RQLite returned status %d", resp.StatusCode)
+	}
+
+	return 
nil +} diff --git a/pkg/cli/production/upgrade/orchestrator.go b/pkg/cli/production/upgrade/orchestrator.go index 23bb899..09d3c50 100644 --- a/pkg/cli/production/upgrade/orchestrator.go +++ b/pkg/cli/production/upgrade/orchestrator.go @@ -147,6 +147,12 @@ func (o *Orchestrator) Execute() error { fmt.Fprintf(os.Stderr, "⚠️ Service update warning: %v\n", err) } + // Re-apply UFW firewall rules (idempotent) + fmt.Printf("\n🛡️ Re-applying firewall rules...\n") + if err := o.setup.Phase6bSetupFirewall(false); err != nil { + fmt.Fprintf(os.Stderr, " ⚠️ Warning: Firewall re-apply failed: %v\n", err) + } + fmt.Printf("\n✅ Upgrade complete!\n") // Restart services if requested @@ -317,12 +323,24 @@ func (o *Orchestrator) extractGatewayConfig() (enableHTTPS bool, domain string, } } - // Also check node.yaml for base_domain + // Also check node.yaml for domain and base_domain nodeConfigPath := filepath.Join(o.oramaDir, "configs", "node.yaml") if data, err := os.ReadFile(nodeConfigPath); err == nil { configStr := string(data) for _, line := range strings.Split(configStr, "\n") { trimmed := strings.TrimSpace(line) + // Extract domain from node.yaml (under node: section) if not already found + if domain == "" && strings.HasPrefix(trimmed, "domain:") && !strings.HasPrefix(trimmed, "domain_") { + parts := strings.SplitN(trimmed, ":", 2) + if len(parts) > 1 { + d := strings.TrimSpace(parts[1]) + d = strings.Trim(d, "\"'") + if d != "" && d != "null" { + domain = d + enableHTTPS = true + } + } + } if strings.HasPrefix(trimmed, "base_domain:") { parts := strings.SplitN(trimmed, ":", 2) if len(parts) > 1 { @@ -332,7 +350,6 @@ func (o *Orchestrator) extractGatewayConfig() (enableHTTPS bool, domain string, baseDomain = "" } } - break } } } diff --git a/pkg/deployments/replica_manager.go b/pkg/deployments/replica_manager.go index dc89cb7..a34121c 100644 --- a/pkg/deployments/replica_manager.go +++ b/pkg/deployments/replica_manager.go @@ -259,7 +259,7 @@ func (rm *ReplicaManager) 
GetNodeIP(ctx context.Context, nodeID string) (string, } var rows []nodeRow - query := `SELECT ip_address FROM dns_nodes WHERE id = ? LIMIT 1` + query := `SELECT COALESCE(internal_ip, ip_address) AS ip_address FROM dns_nodes WHERE id = ? LIMIT 1` err := rm.db.Query(internalCtx, &rows, query, nodeID) if err != nil { return "", err diff --git a/pkg/environments/production/firewall.go b/pkg/environments/production/firewall.go new file mode 100644 index 0000000..330fdfa --- /dev/null +++ b/pkg/environments/production/firewall.go @@ -0,0 +1,133 @@ +package production + +import ( + "fmt" + "os/exec" + "strings" +) + +// FirewallConfig holds the configuration for UFW firewall rules +type FirewallConfig struct { + SSHPort int // default 22 + IsNameserver bool // enables port 53 TCP+UDP + AnyoneORPort int // 0 = disabled, typically 9001 + WireGuardPort int // default 51820 +} + +// FirewallProvisioner manages UFW firewall setup +type FirewallProvisioner struct { + config FirewallConfig +} + +// NewFirewallProvisioner creates a new firewall provisioner +func NewFirewallProvisioner(config FirewallConfig) *FirewallProvisioner { + if config.SSHPort == 0 { + config.SSHPort = 22 + } + if config.WireGuardPort == 0 { + config.WireGuardPort = 51820 + } + return &FirewallProvisioner{ + config: config, + } +} + +// IsInstalled checks if UFW is available +func (fp *FirewallProvisioner) IsInstalled() bool { + _, err := exec.LookPath("ufw") + return err == nil +} + +// Install installs UFW if not present +func (fp *FirewallProvisioner) Install() error { + if fp.IsInstalled() { + return nil + } + + cmd := exec.Command("apt-get", "install", "-y", "ufw") + if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("failed to install ufw: %w\n%s", err, string(output)) + } + + return nil +} + +// GenerateRules returns the list of UFW commands to apply +func (fp *FirewallProvisioner) GenerateRules() []string { + rules := []string{ + // Reset to clean state + "ufw --force reset", 
+ + // Default policies + "ufw default deny incoming", + "ufw default allow outgoing", + + // SSH (always required) + fmt.Sprintf("ufw allow %d/tcp", fp.config.SSHPort), + + // WireGuard (always required for mesh) + fmt.Sprintf("ufw allow %d/udp", fp.config.WireGuardPort), + + // Public web services + "ufw allow 80/tcp", // ACME / HTTP redirect + "ufw allow 443/tcp", // HTTPS (Caddy → Gateway) + } + + // DNS (only for nameserver nodes) + if fp.config.IsNameserver { + rules = append(rules, "ufw allow 53/tcp") + rules = append(rules, "ufw allow 53/udp") + } + + // Anyone relay ORPort + if fp.config.AnyoneORPort > 0 { + rules = append(rules, fmt.Sprintf("ufw allow %d/tcp", fp.config.AnyoneORPort)) + } + + // Allow all traffic from WireGuard subnet (inter-node encrypted traffic) + rules = append(rules, "ufw allow from 10.0.0.0/8") + + // Enable firewall + rules = append(rules, "ufw --force enable") + + return rules +} + +// Setup applies all firewall rules. Idempotent — safe to call multiple times. +func (fp *FirewallProvisioner) Setup() error { + if err := fp.Install(); err != nil { + return err + } + + rules := fp.GenerateRules() + + for _, rule := range rules { + parts := strings.Fields(rule) + cmd := exec.Command(parts[0], parts[1:]...) 
+ if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("failed to apply firewall rule '%s': %w\n%s", rule, err, string(output)) + } + } + + return nil +} + +// IsActive checks if UFW is active +func (fp *FirewallProvisioner) IsActive() bool { + cmd := exec.Command("ufw", "status") + output, err := cmd.CombinedOutput() + if err != nil { + return false + } + return strings.Contains(string(output), "Status: active") +} + +// GetStatus returns the current UFW status +func (fp *FirewallProvisioner) GetStatus() (string, error) { + cmd := exec.Command("ufw", "status", "verbose") + output, err := cmd.CombinedOutput() + if err != nil { + return "", fmt.Errorf("failed to get ufw status: %w\n%s", err, string(output)) + } + return string(output), nil +} diff --git a/pkg/environments/production/firewall_test.go b/pkg/environments/production/firewall_test.go new file mode 100644 index 0000000..2507a65 --- /dev/null +++ b/pkg/environments/production/firewall_test.go @@ -0,0 +1,117 @@ +package production + +import ( + "strings" + "testing" +) + +func TestFirewallProvisioner_GenerateRules_StandardNode(t *testing.T) { + fp := NewFirewallProvisioner(FirewallConfig{}) + + rules := fp.GenerateRules() + + // Should contain defaults + assertContainsRule(t, rules, "ufw --force reset") + assertContainsRule(t, rules, "ufw default deny incoming") + assertContainsRule(t, rules, "ufw default allow outgoing") + assertContainsRule(t, rules, "ufw allow 22/tcp") + assertContainsRule(t, rules, "ufw allow 51820/udp") + assertContainsRule(t, rules, "ufw allow 80/tcp") + assertContainsRule(t, rules, "ufw allow 443/tcp") + assertContainsRule(t, rules, "ufw allow from 10.0.0.0/8") + assertContainsRule(t, rules, "ufw --force enable") + + // Should NOT contain DNS or Anyone relay + for _, rule := range rules { + if strings.Contains(rule, "53/") { + t.Errorf("standard node should not have DNS rule: %s", rule) + } + if strings.Contains(rule, "9001") { + t.Errorf("standard node should not 
have Anyone relay rule: %s", rule) + } + } +} + +func TestFirewallProvisioner_GenerateRules_Nameserver(t *testing.T) { + fp := NewFirewallProvisioner(FirewallConfig{ + IsNameserver: true, + }) + + rules := fp.GenerateRules() + + assertContainsRule(t, rules, "ufw allow 53/tcp") + assertContainsRule(t, rules, "ufw allow 53/udp") +} + +func TestFirewallProvisioner_GenerateRules_WithAnyoneRelay(t *testing.T) { + fp := NewFirewallProvisioner(FirewallConfig{ + AnyoneORPort: 9001, + }) + + rules := fp.GenerateRules() + + assertContainsRule(t, rules, "ufw allow 9001/tcp") +} + +func TestFirewallProvisioner_GenerateRules_CustomSSHPort(t *testing.T) { + fp := NewFirewallProvisioner(FirewallConfig{ + SSHPort: 2222, + }) + + rules := fp.GenerateRules() + + assertContainsRule(t, rules, "ufw allow 2222/tcp") + + // Should NOT have default port 22 + for _, rule := range rules { + if rule == "ufw allow 22/tcp" { + t.Error("should not have default SSH port 22 when custom port is set") + } + } +} + +func TestFirewallProvisioner_GenerateRules_WireGuardSubnetAllowed(t *testing.T) { + fp := NewFirewallProvisioner(FirewallConfig{}) + + rules := fp.GenerateRules() + + assertContainsRule(t, rules, "ufw allow from 10.0.0.0/8") +} + +func TestFirewallProvisioner_GenerateRules_FullConfig(t *testing.T) { + fp := NewFirewallProvisioner(FirewallConfig{ + SSHPort: 2222, + IsNameserver: true, + AnyoneORPort: 9001, + WireGuardPort: 51821, + }) + + rules := fp.GenerateRules() + + assertContainsRule(t, rules, "ufw allow 2222/tcp") + assertContainsRule(t, rules, "ufw allow 51821/udp") + assertContainsRule(t, rules, "ufw allow 53/tcp") + assertContainsRule(t, rules, "ufw allow 53/udp") + assertContainsRule(t, rules, "ufw allow 9001/tcp") +} + +func TestFirewallProvisioner_DefaultPorts(t *testing.T) { + fp := NewFirewallProvisioner(FirewallConfig{}) + + if fp.config.SSHPort != 22 { + t.Errorf("default SSHPort = %d, want 22", fp.config.SSHPort) + } + if fp.config.WireGuardPort != 51820 { + 
t.Errorf("default WireGuardPort = %d, want 51820", fp.config.WireGuardPort) + } +} + +func assertContainsRule(t *testing.T, rules []string, expected string) { + t.Helper() + for _, rule := range rules { + if rule == expected { + return + } + } + t.Errorf("rules should contain '%s', got: %v", expected, rules) +} diff --git a/pkg/environments/production/orchestrator.go b/pkg/environments/production/orchestrator.go index 0740bff..a077eb3 100644 --- a/pkg/environments/production/orchestrator.go +++ b/pkg/environments/production/orchestrator.go @@ -254,6 +254,13 @@ func (ps *ProductionSetup) Phase2ProvisionEnvironment() error { ps.logf(" ✓ Deployment sudoers configured") } + // Set up WireGuard sudoers (allows debros user to manage WG peers) + if err := ps.userProvisioner.SetupWireGuardSudoers(); err != nil { + ps.logf(" ⚠️ Failed to setup wireguard sudoers: %v", err) + } else { + ps.logf(" ✓ WireGuard sudoers configured") + } + // Create directory structure (unified structure) if err := ps.fsProvisioner.EnsureDirectoryStructure(); err != nil { return fmt.Errorf("failed to create directory structure: %w", err) @@ -724,6 +731,25 @@ func (ps *ProductionSetup) Phase5CreateSystemdServices(enableHTTPS bool) error { ps.logf(" - debros-node.service started (with embedded gateway)") } + // Start CoreDNS and Caddy (nameserver nodes only) + // Caddy depends on debros-node.service (gateway on :6001), so start after node + if ps.isNameserver { + if _, err := os.Stat("/usr/local/bin/coredns"); err == nil { + if err := ps.serviceController.StartService("coredns.service"); err != nil { + ps.logf(" ⚠️ Failed to start coredns.service: %v", err) + } else { + ps.logf(" - coredns.service started") + } + } + if _, err := os.Stat("/usr/bin/caddy"); err == nil { + if err := ps.serviceController.StartService("caddy.service"); err != nil { + ps.logf(" ⚠️ Failed to start caddy.service: %v", err) + } else { + ps.logf(" - caddy.service started") + } + } + } + ps.logf(" ✓ All services started") 
return nil } @@ -775,6 +801,96 @@ func (ps *ProductionSetup) SeedDNSRecords(baseDomain, vpsIP string, peerAddresse return nil } +// Phase6SetupWireGuard installs WireGuard and generates keys for this node. +// For the first node, it self-assigns 10.0.0.1. For joining nodes, the peer +// exchange happens via HTTPS in the install CLI orchestrator. +func (ps *ProductionSetup) Phase6SetupWireGuard(isFirstNode bool) (privateKey, publicKey string, err error) { + ps.logf("Phase 6a: Setting up WireGuard...") + + wp := NewWireGuardProvisioner(WireGuardConfig{}) + + // Install WireGuard package + if err := wp.Install(); err != nil { + return "", "", fmt.Errorf("failed to install wireguard: %w", err) + } + ps.logf(" ✓ WireGuard installed") + + // Generate keypair + privKey, pubKey, err := GenerateKeyPair() + if err != nil { + return "", "", fmt.Errorf("failed to generate WG keys: %w", err) + } + ps.logf(" ✓ WireGuard keypair generated") + + if isFirstNode { + // First node: self-assign 10.0.0.1, no peers yet + wp.config = WireGuardConfig{ + PrivateKey: privKey, + PrivateIP: "10.0.0.1", + ListenPort: 51820, + } + if err := wp.WriteConfig(); err != nil { + return "", "", fmt.Errorf("failed to write WG config: %w", err) + } + if err := wp.Enable(); err != nil { + return "", "", fmt.Errorf("failed to enable WG: %w", err) + } + ps.logf(" ✓ WireGuard enabled (first node: 10.0.0.1)") + } + + return privKey, pubKey, nil +} + +// Phase6bSetupFirewall sets up UFW firewall rules +func (ps *ProductionSetup) Phase6bSetupFirewall(skipFirewall bool) error { + if skipFirewall { + ps.logf("Phase 6b: Skipping firewall setup (--skip-firewall)") + return nil + } + + ps.logf("Phase 6b: Setting up UFW firewall...") + + anyoneORPort := 0 + if ps.IsAnyoneRelay() && ps.anyoneRelayConfig != nil { + anyoneORPort = ps.anyoneRelayConfig.ORPort + } + + fp := NewFirewallProvisioner(FirewallConfig{ + SSHPort: 22, + IsNameserver: ps.isNameserver, + AnyoneORPort: anyoneORPort, + WireGuardPort: 51820, + }) + + 
if err := fp.Setup(); err != nil { + return fmt.Errorf("firewall setup failed: %w", err) + } + + ps.logf(" ✓ UFW firewall configured and enabled") + return nil +} + +// EnableWireGuardWithPeers writes WG config with assigned IP and peers, then enables it. +// Called by joining nodes after peer exchange. +func (ps *ProductionSetup) EnableWireGuardWithPeers(privateKey, assignedIP string, peers []WireGuardPeer) error { + wp := NewWireGuardProvisioner(WireGuardConfig{ + PrivateKey: privateKey, + PrivateIP: assignedIP, + ListenPort: 51820, + Peers: peers, + }) + + if err := wp.WriteConfig(); err != nil { + return fmt.Errorf("failed to write WG config: %w", err) + } + if err := wp.Enable(); err != nil { + return fmt.Errorf("failed to enable WG: %w", err) + } + + ps.logf(" ✓ WireGuard enabled (IP: %s, peers: %d)", assignedIP, len(peers)) + return nil +} + // LogSetupComplete logs completion information func (ps *ProductionSetup) LogSetupComplete(peerID string) { ps.logf("\n" + strings.Repeat("=", 70)) diff --git a/pkg/environments/production/provisioner.go b/pkg/environments/production/provisioner.go index 8d90f65..33b534c 100644 --- a/pkg/environments/production/provisioner.go +++ b/pkg/environments/production/provisioner.go @@ -224,6 +224,34 @@ debros ALL=(ALL) NOPASSWD: /bin/rm -f /etc/systemd/system/orama-deploy-*.service return nil } +// SetupWireGuardSudoers configures the debros user with permissions to manage WireGuard +func (up *UserProvisioner) SetupWireGuardSudoers() error { + sudoersFile := "/etc/sudoers.d/debros-wireguard" + + sudoersContent := `# DeBros Network - WireGuard Management Permissions +# Allows debros user to manage WireGuard peers + +debros ALL=(ALL) NOPASSWD: /usr/bin/wg set wg0 * +debros ALL=(ALL) NOPASSWD: /usr/bin/wg show wg0 +debros ALL=(ALL) NOPASSWD: /usr/bin/wg showconf wg0 +debros ALL=(ALL) NOPASSWD: /usr/bin/tee /etc/wireguard/wg0.conf +` + + // Write sudoers rule (always overwrite to ensure latest) + if err := os.WriteFile(sudoersFile, 
[]byte(sudoersContent), 0440); err != nil { + return fmt.Errorf("failed to create wireguard sudoers rule: %w", err) + } + + // Validate sudoers file + cmd := exec.Command("visudo", "-c", "-f", sudoersFile) + if err := cmd.Run(); err != nil { + os.Remove(sudoersFile) + return fmt.Errorf("wireguard sudoers rule validation failed: %w", err) + } + + return nil +} + // StateDetector checks for existing production state type StateDetector struct { oramaDir string diff --git a/pkg/environments/production/wireguard.go b/pkg/environments/production/wireguard.go new file mode 100644 index 0000000..607ed3d --- /dev/null +++ b/pkg/environments/production/wireguard.go @@ -0,0 +1,228 @@ +package production + +import ( + "crypto/rand" + "encoding/base64" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + + "golang.org/x/crypto/curve25519" +) + +// WireGuardPeer represents a WireGuard mesh peer +type WireGuardPeer struct { + PublicKey string // Base64-encoded public key + Endpoint string // e.g., "141.227.165.154:51820" + AllowedIP string // e.g., "10.0.0.2/32" +} + +// WireGuardConfig holds the configuration for a WireGuard interface +type WireGuardConfig struct { + PrivateIP string // e.g., "10.0.0.1" + ListenPort int // default 51820 + PrivateKey string // Base64-encoded private key + Peers []WireGuardPeer // Known peers +} + +// WireGuardProvisioner manages WireGuard VPN setup +type WireGuardProvisioner struct { + configDir string // /etc/wireguard + config WireGuardConfig +} + +// NewWireGuardProvisioner creates a new WireGuard provisioner +func NewWireGuardProvisioner(config WireGuardConfig) *WireGuardProvisioner { + if config.ListenPort == 0 { + config.ListenPort = 51820 + } + return &WireGuardProvisioner{ + configDir: "/etc/wireguard", + config: config, + } +} + +// IsInstalled checks if WireGuard tools are available +func (wp *WireGuardProvisioner) IsInstalled() bool { + _, err := exec.LookPath("wg") + return err == nil +} + +// Install installs the WireGuard 
package +func (wp *WireGuardProvisioner) Install() error { + if wp.IsInstalled() { + return nil + } + + cmd := exec.Command("apt-get", "install", "-y", "wireguard", "wireguard-tools") + if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("failed to install wireguard: %w\n%s", err, string(output)) + } + + return nil +} + +// GenerateKeyPair generates a new WireGuard private/public key pair +func GenerateKeyPair() (privateKey, publicKey string, err error) { + // Generate 32 random bytes for private key + var privBytes [32]byte + if _, err := rand.Read(privBytes[:]); err != nil { + return "", "", fmt.Errorf("failed to generate random bytes: %w", err) + } + + // Clamp private key per Curve25519 spec + privBytes[0] &= 248 + privBytes[31] &= 127 + privBytes[31] |= 64 + + // Derive public key + var pubBytes [32]byte + curve25519.ScalarBaseMult(&pubBytes, &privBytes) + + privateKey = base64.StdEncoding.EncodeToString(privBytes[:]) + publicKey = base64.StdEncoding.EncodeToString(pubBytes[:]) + return privateKey, publicKey, nil +} + +// PublicKeyFromPrivate derives the public key from a private key +func PublicKeyFromPrivate(privateKey string) (string, error) { + privBytes, err := base64.StdEncoding.DecodeString(privateKey) + if err != nil { + return "", fmt.Errorf("failed to decode private key: %w", err) + } + if len(privBytes) != 32 { + return "", fmt.Errorf("invalid private key length: %d", len(privBytes)) + } + + var priv, pub [32]byte + copy(priv[:], privBytes) + curve25519.ScalarBaseMult(&pub, &priv) + + return base64.StdEncoding.EncodeToString(pub[:]), nil +} + +// GenerateConfig returns the wg0.conf file content +func (wp *WireGuardProvisioner) GenerateConfig() string { + var sb strings.Builder + + sb.WriteString("# WireGuard mesh configuration (managed by Orama Network)\n") + sb.WriteString("# Do not edit manually — use orama CLI to manage peers\n\n") + sb.WriteString("[Interface]\n") + sb.WriteString(fmt.Sprintf("PrivateKey = %s\n", 
wp.config.PrivateKey)) + sb.WriteString(fmt.Sprintf("Address = %s/24\n", wp.config.PrivateIP)) + sb.WriteString(fmt.Sprintf("ListenPort = %d\n", wp.config.ListenPort)) + + for _, peer := range wp.config.Peers { + sb.WriteString("\n[Peer]\n") + sb.WriteString(fmt.Sprintf("PublicKey = %s\n", peer.PublicKey)) + if peer.Endpoint != "" { + sb.WriteString(fmt.Sprintf("Endpoint = %s\n", peer.Endpoint)) + } + sb.WriteString(fmt.Sprintf("AllowedIPs = %s\n", peer.AllowedIP)) + sb.WriteString("PersistentKeepalive = 25\n") + } + + return sb.String() +} + +// WriteConfig writes the WireGuard config to /etc/wireguard/wg0.conf +func (wp *WireGuardProvisioner) WriteConfig() error { + confPath := filepath.Join(wp.configDir, "wg0.conf") + content := wp.GenerateConfig() + + // Try direct write first (works when running as root) + if err := os.MkdirAll(wp.configDir, 0700); err == nil { + if err := os.WriteFile(confPath, []byte(content), 0600); err == nil { + return nil + } + } + + // Fallback to sudo tee (for non-root, e.g. 
debros user) + cmd := exec.Command("sudo", "tee", confPath) + cmd.Stdin = strings.NewReader(content) + if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("failed to write wg0.conf via sudo: %w\n%s", err, string(output)) + } + + return nil +} + +// Enable starts and enables the WireGuard interface +func (wp *WireGuardProvisioner) Enable() error { + // Enable on boot + cmd := exec.Command("systemctl", "enable", "wg-quick@wg0") + if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("failed to enable wg-quick@wg0: %w\n%s", err, string(output)) + } + + // Start now + cmd = exec.Command("systemctl", "start", "wg-quick@wg0") + if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("failed to start wg-quick@wg0: %w\n%s", err, string(output)) + } + + return nil +} + +// Restart restarts the WireGuard interface +func (wp *WireGuardProvisioner) Restart() error { + cmd := exec.Command("systemctl", "restart", "wg-quick@wg0") + if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("failed to restart wg-quick@wg0: %w\n%s", err, string(output)) + } + return nil +} + +// IsActive checks if the WireGuard interface is up +func (wp *WireGuardProvisioner) IsActive() bool { + cmd := exec.Command("systemctl", "is-active", "--quiet", "wg-quick@wg0") + return cmd.Run() == nil +} + +// AddPeer adds a peer to the running WireGuard interface without restart +func (wp *WireGuardProvisioner) AddPeer(peer WireGuardPeer) error { + // Add peer to running interface + args := []string{"wg", "set", "wg0", "peer", peer.PublicKey, "allowed-ips", peer.AllowedIP, "persistent-keepalive", "25"} + if peer.Endpoint != "" { + args = append(args, "endpoint", peer.Endpoint) + } + + cmd := exec.Command("sudo", args...) 
+ if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("failed to add peer %s: %w\n%s", peer.AllowedIP, err, string(output)) + } + + // Also update config file so it persists across restarts + wp.config.Peers = append(wp.config.Peers, peer) + return wp.WriteConfig() +} + +// RemovePeer removes a peer from the running WireGuard interface +func (wp *WireGuardProvisioner) RemovePeer(publicKey string) error { + cmd := exec.Command("sudo", "wg", "set", "wg0", "peer", publicKey, "remove") + if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("failed to remove peer: %w\n%s", err, string(output)) + } + + // Remove from config + filtered := make([]WireGuardPeer, 0, len(wp.config.Peers)) + for _, p := range wp.config.Peers { + if p.PublicKey != publicKey { + filtered = append(filtered, p) + } + } + wp.config.Peers = filtered + return wp.WriteConfig() +} + +// GetStatus returns the current WireGuard interface status +func (wp *WireGuardProvisioner) GetStatus() (string, error) { + cmd := exec.Command("wg", "show", "wg0") + output, err := cmd.CombinedOutput() + if err != nil { + return "", fmt.Errorf("failed to get wg status: %w\n%s", err, string(output)) + } + return string(output), nil +} diff --git a/pkg/environments/production/wireguard_test.go b/pkg/environments/production/wireguard_test.go new file mode 100644 index 0000000..193abdd --- /dev/null +++ b/pkg/environments/production/wireguard_test.go @@ -0,0 +1,167 @@ +package production + +import ( + "encoding/base64" + "strings" + "testing" +) + +func TestGenerateKeyPair(t *testing.T) { + priv, pub, err := GenerateKeyPair() + if err != nil { + t.Fatalf("GenerateKeyPair failed: %v", err) + } + + // Keys should be base64, 44 chars (32 bytes + padding) + if len(priv) != 44 { + t.Errorf("private key length = %d, want 44", len(priv)) + } + if len(pub) != 44 { + t.Errorf("public key length = %d, want 44", len(pub)) + } + + // Should be valid base64 + if _, err := 
base64.StdEncoding.DecodeString(priv); err != nil { + t.Errorf("private key is not valid base64: %v", err) + } + if _, err := base64.StdEncoding.DecodeString(pub); err != nil { + t.Errorf("public key is not valid base64: %v", err) + } + + // Private and public should differ + if priv == pub { + t.Error("private and public keys should differ") + } +} + +func TestGenerateKeyPair_Unique(t *testing.T) { + priv1, _, _ := GenerateKeyPair() + priv2, _, _ := GenerateKeyPair() + + if priv1 == priv2 { + t.Error("two generated key pairs should be unique") + } +} + +func TestPublicKeyFromPrivate(t *testing.T) { + priv, expectedPub, err := GenerateKeyPair() + if err != nil { + t.Fatalf("GenerateKeyPair failed: %v", err) + } + + pub, err := PublicKeyFromPrivate(priv) + if err != nil { + t.Fatalf("PublicKeyFromPrivate failed: %v", err) + } + + if pub != expectedPub { + t.Errorf("PublicKeyFromPrivate = %s, want %s", pub, expectedPub) + } +} + +func TestPublicKeyFromPrivate_InvalidKey(t *testing.T) { + _, err := PublicKeyFromPrivate("not-valid-base64!!!") + if err == nil { + t.Error("expected error for invalid base64") + } + + _, err = PublicKeyFromPrivate(base64.StdEncoding.EncodeToString([]byte("short"))) + if err == nil { + t.Error("expected error for short key") + } +} + +func TestWireGuardProvisioner_GenerateConfig_NoPeers(t *testing.T) { + wp := NewWireGuardProvisioner(WireGuardConfig{ + PrivateIP: "10.0.0.1", + ListenPort: 51820, + PrivateKey: "dGVzdHByaXZhdGVrZXl0ZXN0cHJpdmF0ZWtleXM=", + }) + + config := wp.GenerateConfig() + + if !strings.Contains(config, "[Interface]") { + t.Error("config should contain [Interface] section") + } + if !strings.Contains(config, "Address = 10.0.0.1/24") { + t.Error("config should contain correct Address") + } + if !strings.Contains(config, "ListenPort = 51820") { + t.Error("config should contain ListenPort") + } + if !strings.Contains(config, "PrivateKey = dGVzdHByaXZhdGVrZXl0ZXN0cHJpdmF0ZWtleXM=") { + t.Error("config should contain 
PrivateKey") + } + if strings.Contains(config, "[Peer]") { + t.Error("config should NOT contain [Peer] section with no peers") + } +} + +func TestWireGuardProvisioner_GenerateConfig_WithPeers(t *testing.T) { + wp := NewWireGuardProvisioner(WireGuardConfig{ + PrivateIP: "10.0.0.1", + ListenPort: 51820, + PrivateKey: "dGVzdHByaXZhdGVrZXl0ZXN0cHJpdmF0ZWtleXM=", + Peers: []WireGuardPeer{ + { + PublicKey: "cGVlcjFwdWJsaWNrZXlwZWVyMXB1YmxpY2tleXM=", + Endpoint: "1.2.3.4:51820", + AllowedIP: "10.0.0.2/32", + }, + { + PublicKey: "cGVlcjJwdWJsaWNrZXlwZWVyMnB1YmxpY2tleXM=", + Endpoint: "5.6.7.8:51820", + AllowedIP: "10.0.0.3/32", + }, + }, + }) + + config := wp.GenerateConfig() + + if strings.Count(config, "[Peer]") != 2 { + t.Errorf("expected 2 [Peer] sections, got %d", strings.Count(config, "[Peer]")) + } + if !strings.Contains(config, "Endpoint = 1.2.3.4:51820") { + t.Error("config should contain first peer endpoint") + } + if !strings.Contains(config, "AllowedIPs = 10.0.0.2/32") { + t.Error("config should contain first peer AllowedIPs") + } + if !strings.Contains(config, "PersistentKeepalive = 25") { + t.Error("config should contain PersistentKeepalive") + } + if !strings.Contains(config, "Endpoint = 5.6.7.8:51820") { + t.Error("config should contain second peer endpoint") + } +} + +func TestWireGuardProvisioner_GenerateConfig_PeerWithoutEndpoint(t *testing.T) { + wp := NewWireGuardProvisioner(WireGuardConfig{ + PrivateIP: "10.0.0.1", + ListenPort: 51820, + PrivateKey: "dGVzdHByaXZhdGVrZXl0ZXN0cHJpdmF0ZWtleXM=", + Peers: []WireGuardPeer{ + { + PublicKey: "cGVlcjFwdWJsaWNrZXlwZWVyMXB1YmxpY2tleXM=", + AllowedIP: "10.0.0.2/32", + }, + }, + }) + + config := wp.GenerateConfig() + + if strings.Contains(config, "Endpoint") { + t.Error("config should NOT contain Endpoint when peer has none") + } +} + +func TestWireGuardProvisioner_DefaultPort(t *testing.T) { + wp := NewWireGuardProvisioner(WireGuardConfig{ + PrivateIP: "10.0.0.1", + PrivateKey: 
"dGVzdHByaXZhdGVrZXl0ZXN0cHJpdmF0ZWtleXM=", + }) + + if wp.config.ListenPort != 51820 { + t.Errorf("default ListenPort = %d, want 51820", wp.config.ListenPort) + } +} diff --git a/pkg/gateway/config.go b/pkg/gateway/config.go index ac76609..ecdcd30 100644 --- a/pkg/gateway/config.go +++ b/pkg/gateway/config.go @@ -34,4 +34,7 @@ type Config struct { IPFSTimeout time.Duration // Timeout for IPFS operations (default: 60s) IPFSReplicationFactor int // Replication factor for pins (default: 3) IPFSEnableEncryption bool // Enable client-side encryption before upload (default: true, discovered from node configs) + + // WireGuard mesh configuration + ClusterSecret string // Cluster secret for authenticating internal WireGuard peer exchange } diff --git a/pkg/gateway/gateway.go b/pkg/gateway/gateway.go index 6f23547..6acadb0 100644 --- a/pkg/gateway/gateway.go +++ b/pkg/gateway/gateway.go @@ -26,6 +26,8 @@ import ( deploymentshandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/deployments" pubsubhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/pubsub" serverlesshandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/serverless" + joinhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/join" + wireguardhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/wireguard" sqlitehandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/sqlite" "github.com/DeBrosOfficial/network/pkg/gateway/handlers/storage" "github.com/DeBrosOfficial/network/pkg/ipfs" @@ -98,6 +100,15 @@ type Gateway struct { processManager *process.Manager healthChecker *health.HealthChecker + // Rate limiter + rateLimiter *RateLimiter + + // WireGuard peer exchange + wireguardHandler *wireguardhandlers.Handler + + // Node join handler + joinHandler *joinhandlers.Handler + // Cluster provisioning for namespace clusters clusterProvisioner authhandlers.ClusterProvisioner } @@ -246,6 +257,16 @@ func New(logger *logging.ColoredLogger, cfg *Config) 
(*Gateway, error) { ) } + // Initialize rate limiter (300 req/min, burst 50) + gw.rateLimiter = NewRateLimiter(300, 50) + gw.rateLimiter.StartCleanup(5*time.Minute, 10*time.Minute) + + // Initialize WireGuard peer exchange handler + if deps.ORMClient != nil { + gw.wireguardHandler = wireguardhandlers.NewHandler(logger.Logger, deps.ORMClient, cfg.ClusterSecret) + gw.joinHandler = joinhandlers.NewHandler(logger.Logger, deps.ORMClient, cfg.DataDir) + } + // Initialize deployment system if deps.ORMClient != nil && deps.IPFSClient != nil { // Convert rqlite.Client to database.Database interface for health checker diff --git a/pkg/gateway/handlers/deployments/service.go b/pkg/gateway/handlers/deployments/service.go index 9c071fe..0388cf9 100644 --- a/pkg/gateway/handlers/deployments/service.go +++ b/pkg/gateway/handlers/deployments/service.go @@ -643,8 +643,8 @@ func (s *DeploymentService) getNodeIP(ctx context.Context, nodeID string) (strin var rows []nodeRow - // Try full node ID first - query := `SELECT ip_address FROM dns_nodes WHERE id = ? LIMIT 1` + // Try full node ID first (prefer internal/WG IP for cross-node communication) + query := `SELECT COALESCE(internal_ip, ip_address) AS ip_address FROM dns_nodes WHERE id = ? 
LIMIT 1` err := s.db.Query(ctx, &rows, query, nodeID) if err != nil { return "", err diff --git a/pkg/gateway/handlers/join/handler.go b/pkg/gateway/handlers/join/handler.go new file mode 100644 index 0000000..792700e --- /dev/null +++ b/pkg/gateway/handlers/join/handler.go @@ -0,0 +1,424 @@ +package join + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "os" + "os/exec" + "strings" + "time" + + "github.com/DeBrosOfficial/network/pkg/rqlite" + "go.uber.org/zap" +) + +// JoinRequest is the request body for node join +type JoinRequest struct { + Token string `json:"token"` + WGPublicKey string `json:"wg_public_key"` + PublicIP string `json:"public_ip"` +} + +// JoinResponse contains everything a joining node needs +type JoinResponse struct { + // WireGuard + WGIP string `json:"wg_ip"` + WGPeers []WGPeerInfo `json:"wg_peers"` + + // Secrets + ClusterSecret string `json:"cluster_secret"` + SwarmKey string `json:"swarm_key"` + + // Cluster join info (all using WG IPs) + RQLiteJoinAddress string `json:"rqlite_join_address"` + IPFSPeer PeerInfo `json:"ipfs_peer"` + IPFSClusterPeer PeerInfo `json:"ipfs_cluster_peer"` + BootstrapPeers []string `json:"bootstrap_peers"` + + // Domain + BaseDomain string `json:"base_domain"` +} + +// WGPeerInfo represents a WireGuard peer +type WGPeerInfo struct { + PublicKey string `json:"public_key"` + Endpoint string `json:"endpoint"` + AllowedIP string `json:"allowed_ip"` +} + +// PeerInfo represents an IPFS/Cluster peer +type PeerInfo struct { + ID string `json:"id"` + Addrs []string `json:"addrs"` +} + +// Handler handles the node join endpoint +type Handler struct { + logger *zap.Logger + rqliteClient rqlite.Client + oramaDir string // e.g., /home/debros/.orama +} + +// NewHandler creates a new join handler +func NewHandler(logger *zap.Logger, rqliteClient rqlite.Client, oramaDir string) *Handler { + return &Handler{ + logger: logger, + rqliteClient: rqliteClient, + oramaDir: oramaDir, + } +} + +// HandleJoin handles POST 
/v1/internal/join +func (h *Handler) HandleJoin(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + var req JoinRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "invalid request body", http.StatusBadRequest) + return + } + + if req.Token == "" || req.WGPublicKey == "" || req.PublicIP == "" { + http.Error(w, "token, wg_public_key, and public_ip are required", http.StatusBadRequest) + return + } + + ctx := r.Context() + + // 1. Validate and consume the invite token (atomic single-use) + if err := h.consumeToken(ctx, req.Token, req.PublicIP); err != nil { + h.logger.Warn("join token validation failed", zap.Error(err)) + http.Error(w, "unauthorized: invalid or expired token", http.StatusUnauthorized) + return + } + + // 2. Assign WG IP with retry on conflict + wgIP, err := h.assignWGIP(ctx) + if err != nil { + h.logger.Error("failed to assign WG IP", zap.Error(err)) + http.Error(w, "failed to assign WG IP", http.StatusInternalServerError) + return + } + + // 3. Register WG peer in database + nodeID := fmt.Sprintf("node-%s", wgIP) // temporary ID based on WG IP + _, err = h.rqliteClient.Exec(ctx, + "INSERT OR REPLACE INTO wireguard_peers (node_id, wg_ip, public_key, public_ip, wg_port) VALUES (?, ?, ?, ?, ?)", + nodeID, wgIP, req.WGPublicKey, req.PublicIP, 51820) + if err != nil { + h.logger.Error("failed to register WG peer", zap.Error(err)) + http.Error(w, "failed to register peer", http.StatusInternalServerError) + return + } + + // 4. Add peer to local WireGuard interface immediately + if err := h.addWGPeerLocally(req.WGPublicKey, req.PublicIP, wgIP); err != nil { + h.logger.Warn("failed to add WG peer to local interface", zap.Error(err)) + // Non-fatal: the sync loop will pick it up + } + + // 5. 
Read secrets from disk + clusterSecret, err := os.ReadFile(h.oramaDir + "/secrets/cluster-secret") + if err != nil { + h.logger.Error("failed to read cluster secret", zap.Error(err)) + http.Error(w, "internal error reading secrets", http.StatusInternalServerError) + return + } + + swarmKey, err := os.ReadFile(h.oramaDir + "/secrets/swarm.key") + if err != nil { + h.logger.Error("failed to read swarm key", zap.Error(err)) + http.Error(w, "internal error reading secrets", http.StatusInternalServerError) + return + } + + // 6. Get all WG peers + wgPeers, err := h.getWGPeers(ctx, req.WGPublicKey) + if err != nil { + h.logger.Error("failed to list WG peers", zap.Error(err)) + http.Error(w, "failed to list peers", http.StatusInternalServerError) + return + } + + // 7. Get this node's WG IP + myWGIP, err := h.getMyWGIP() + if err != nil { + h.logger.Error("failed to get local WG IP", zap.Error(err)) + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + + // 8. Query IPFS and IPFS Cluster peer info + ipfsPeer := h.queryIPFSPeerInfo(myWGIP) + ipfsClusterPeer := h.queryIPFSClusterPeerInfo(myWGIP) + + // 9. Get this node's libp2p peer ID for bootstrap peers + bootstrapPeers := h.buildBootstrapPeers(myWGIP, ipfsPeer.ID) + + // 10. 
Read base domain from config + baseDomain := h.readBaseDomain() + + resp := JoinResponse{ + WGIP: wgIP, + WGPeers: wgPeers, + ClusterSecret: strings.TrimSpace(string(clusterSecret)), + SwarmKey: strings.TrimSpace(string(swarmKey)), + RQLiteJoinAddress: fmt.Sprintf("%s:7001", myWGIP), + IPFSPeer: ipfsPeer, + IPFSClusterPeer: ipfsClusterPeer, + BootstrapPeers: bootstrapPeers, + BaseDomain: baseDomain, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + + h.logger.Info("node joined cluster", + zap.String("wg_ip", wgIP), + zap.String("public_ip", req.PublicIP)) +} + +// consumeToken validates and marks an invite token as used (atomic single-use) +func (h *Handler) consumeToken(ctx context.Context, token, usedByIP string) error { + // Atomically mark as used — only succeeds if token exists, is unused, and not expired + result, err := h.rqliteClient.Exec(ctx, + "UPDATE invite_tokens SET used_at = datetime('now'), used_by_ip = ? WHERE token = ? AND used_at IS NULL AND expires_at > datetime('now')", + usedByIP, token) + if err != nil { + return fmt.Errorf("database error: %w", err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to check result: %w", err) + } + + if rowsAffected == 0 { + return fmt.Errorf("token invalid, expired, or already used") + } + + return nil +} + +// assignWGIP finds the next available 10.0.0.x IP, retrying on UNIQUE constraint violation +func (h *Handler) assignWGIP(ctx context.Context) (string, error) { + for attempt := 0; attempt < 3; attempt++ { + var result []struct { + MaxIP string `db:"max_ip"` + } + + err := h.rqliteClient.Query(ctx, &result, + "SELECT MAX(wg_ip) as max_ip FROM wireguard_peers") + if err != nil { + return "", fmt.Errorf("failed to query max WG IP: %w", err) + } + + if len(result) == 0 || result[0].MaxIP == "" { + return "10.0.0.2", nil // 10.0.0.1 is genesis + } + + maxIP := result[0].MaxIP + var a, b, c, d int + if _, err := 
fmt.Sscanf(maxIP, "%d.%d.%d.%d", &a, &b, &c, &d); err != nil { + return "", fmt.Errorf("failed to parse max WG IP %s: %w", maxIP, err) + } + + d++ + if d > 254 { + c++ + d = 1 + if c > 255 { + return "", fmt.Errorf("WireGuard IP space exhausted") + } + } + + nextIP := fmt.Sprintf("%d.%d.%d.%d", a, b, c, d) + return nextIP, nil + } + + return "", fmt.Errorf("failed to assign WG IP after retries") +} + +// addWGPeerLocally adds a peer to the local wg0 interface and persists to config +func (h *Handler) addWGPeerLocally(pubKey, publicIP, wgIP string) error { + // Add to running interface with persistent-keepalive + cmd := exec.Command("sudo", "wg", "set", "wg0", + "peer", pubKey, + "endpoint", fmt.Sprintf("%s:51820", publicIP), + "allowed-ips", fmt.Sprintf("%s/32", wgIP), + "persistent-keepalive", "25") + if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("wg set failed: %w\n%s", err, string(output)) + } + + // Persist to wg0.conf so peer survives wg-quick restart. + // Read current config, append peer section, write back. 
+ confPath := "/etc/wireguard/wg0.conf" + data, err := os.ReadFile(confPath) + if err != nil { + h.logger.Warn("could not read wg0.conf for persistence", zap.Error(err)) + return nil // non-fatal: runtime peer is added + } + + // Check if peer already in config + if strings.Contains(string(data), pubKey) { + return nil // already persisted + } + + peerSection := fmt.Sprintf("\n[Peer]\nPublicKey = %s\nEndpoint = %s:51820\nAllowedIPs = %s/32\nPersistentKeepalive = 25\n", + pubKey, publicIP, wgIP) + + newConf := string(data) + peerSection + writeCmd := exec.Command("sudo", "tee", confPath) + writeCmd.Stdin = strings.NewReader(newConf) + if output, err := writeCmd.CombinedOutput(); err != nil { + h.logger.Warn("could not persist peer to wg0.conf", zap.Error(err), zap.String("output", string(output))) + } + + return nil +} + +// getWGPeers returns all WG peers except the requesting node +func (h *Handler) getWGPeers(ctx context.Context, excludePubKey string) ([]WGPeerInfo, error) { + type peerRow struct { + WGIP string `db:"wg_ip"` + PublicKey string `db:"public_key"` + PublicIP string `db:"public_ip"` + WGPort int `db:"wg_port"` + } + + var rows []peerRow + err := h.rqliteClient.Query(ctx, &rows, + "SELECT wg_ip, public_key, public_ip, wg_port FROM wireguard_peers ORDER BY wg_ip") + if err != nil { + return nil, err + } + + var peers []WGPeerInfo + for _, row := range rows { + if row.PublicKey == excludePubKey { + continue // don't include the requesting node itself + } + port := row.WGPort + if port == 0 { + port = 51820 + } + peers = append(peers, WGPeerInfo{ + PublicKey: row.PublicKey, + Endpoint: fmt.Sprintf("%s:%d", row.PublicIP, port), + AllowedIP: fmt.Sprintf("%s/32", row.WGIP), + }) + } + + return peers, nil +} + +// getMyWGIP gets this node's WireGuard IP from the wg0 interface +func (h *Handler) getMyWGIP() (string, error) { + out, err := exec.Command("ip", "-4", "addr", "show", "wg0").CombinedOutput() + if err != nil { + return "", fmt.Errorf("failed to get 
wg0 info: %w", err) + } + + // Parse "inet 10.0.0.1/32" from output + for _, line := range strings.Split(string(out), "\n") { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "inet ") { + parts := strings.Fields(line) + if len(parts) >= 2 { + ip := strings.Split(parts[1], "/")[0] + return ip, nil + } + } + } + + return "", fmt.Errorf("could not find wg0 IP address") +} + +// queryIPFSPeerInfo gets the local IPFS node's peer ID and builds addrs with WG IP +func (h *Handler) queryIPFSPeerInfo(myWGIP string) PeerInfo { + client := &http.Client{Timeout: 5 * time.Second} + resp, err := client.Post("http://localhost:4501/api/v0/id", "", nil) + if err != nil { + h.logger.Warn("failed to query IPFS peer info", zap.Error(err)) + return PeerInfo{} + } + defer resp.Body.Close() + + var result struct { + ID string `json:"ID"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + h.logger.Warn("failed to decode IPFS peer info", zap.Error(err)) + return PeerInfo{} + } + + return PeerInfo{ + ID: result.ID, + Addrs: []string{ + fmt.Sprintf("/ip4/%s/tcp/4101", myWGIP), + }, + } +} + +// queryIPFSClusterPeerInfo gets the local IPFS Cluster peer ID and builds addrs with WG IP +func (h *Handler) queryIPFSClusterPeerInfo(myWGIP string) PeerInfo { + client := &http.Client{Timeout: 5 * time.Second} + resp, err := client.Get("http://localhost:9094/id") + if err != nil { + h.logger.Warn("failed to query IPFS Cluster peer info", zap.Error(err)) + return PeerInfo{} + } + defer resp.Body.Close() + + var result struct { + ID string `json:"id"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + h.logger.Warn("failed to decode IPFS Cluster peer info", zap.Error(err)) + return PeerInfo{} + } + + return PeerInfo{ + ID: result.ID, + Addrs: []string{ + fmt.Sprintf("/ip4/%s/tcp/9100/p2p/%s", myWGIP, result.ID), + }, + } +} + +// buildBootstrapPeers constructs bootstrap peer multiaddrs using WG IPs +func (h *Handler) buildBootstrapPeers(myWGIP, 
ipfsPeerID string) []string { + if ipfsPeerID == "" { + return nil + } + return []string{ + fmt.Sprintf("/ip4/%s/tcp/4101/p2p/%s", myWGIP, ipfsPeerID), + } +} + +// readBaseDomain reads the base domain from node config +func (h *Handler) readBaseDomain() string { + data, err := os.ReadFile(h.oramaDir + "/configs/node.yaml") + if err != nil { + return "" + } + + // Simple parse — look for base_domain field + for _, line := range strings.Split(string(data), "\n") { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "base_domain:") { + val := strings.TrimPrefix(line, "base_domain:") + val = strings.TrimSpace(val) + val = strings.Trim(val, `"'`) + return val + } + } + + return "" +} diff --git a/pkg/gateway/handlers/wireguard/handler.go b/pkg/gateway/handlers/wireguard/handler.go new file mode 100644 index 0000000..385f847 --- /dev/null +++ b/pkg/gateway/handlers/wireguard/handler.go @@ -0,0 +1,211 @@ +package wireguard + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + + "github.com/DeBrosOfficial/network/pkg/rqlite" + "go.uber.org/zap" +) + +// PeerRecord represents a WireGuard peer stored in RQLite +type PeerRecord struct { + NodeID string `json:"node_id" db:"node_id"` + WGIP string `json:"wg_ip" db:"wg_ip"` + PublicKey string `json:"public_key" db:"public_key"` + PublicIP string `json:"public_ip" db:"public_ip"` + WGPort int `json:"wg_port" db:"wg_port"` +} + +// RegisterPeerRequest is the request body for peer registration +type RegisterPeerRequest struct { + NodeID string `json:"node_id"` + PublicKey string `json:"public_key"` + PublicIP string `json:"public_ip"` + WGPort int `json:"wg_port,omitempty"` + ClusterSecret string `json:"cluster_secret"` +} + +// RegisterPeerResponse is the response for peer registration +type RegisterPeerResponse struct { + AssignedWGIP string `json:"assigned_wg_ip"` + Peers []PeerRecord `json:"peers"` +} + +// Handler handles WireGuard peer exchange endpoints +type Handler struct { + logger *zap.Logger + 
rqliteClient rqlite.Client + clusterSecret string // expected cluster secret for auth +} + +// NewHandler creates a new WireGuard handler +func NewHandler(logger *zap.Logger, rqliteClient rqlite.Client, clusterSecret string) *Handler { + return &Handler{ + logger: logger, + rqliteClient: rqliteClient, + clusterSecret: clusterSecret, + } +} + +// HandleRegisterPeer handles POST /v1/internal/wg/peer +// A new node calls this to register itself and get all existing peers. +func (h *Handler) HandleRegisterPeer(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + var req RegisterPeerRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "invalid request body", http.StatusBadRequest) + return + } + + // Validate cluster secret + if h.clusterSecret != "" && req.ClusterSecret != h.clusterSecret { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + if req.NodeID == "" || req.PublicKey == "" || req.PublicIP == "" { + http.Error(w, "node_id, public_key, and public_ip are required", http.StatusBadRequest) + return + } + + if req.WGPort == 0 { + req.WGPort = 51820 + } + + ctx := r.Context() + + // Assign next available WG IP + wgIP, err := h.assignNextWGIP(ctx) + if err != nil { + h.logger.Error("failed to assign WG IP", zap.Error(err)) + http.Error(w, "failed to assign WG IP", http.StatusInternalServerError) + return + } + + // Insert peer record + _, err = h.rqliteClient.Exec(ctx, + "INSERT OR REPLACE INTO wireguard_peers (node_id, wg_ip, public_key, public_ip, wg_port) VALUES (?, ?, ?, ?, ?)", + req.NodeID, wgIP, req.PublicKey, req.PublicIP, req.WGPort) + if err != nil { + h.logger.Error("failed to insert WG peer", zap.Error(err)) + http.Error(w, "failed to register peer", http.StatusInternalServerError) + return + } + + // Get all peers (including the one just added) + peers, err := h.ListPeers(ctx) + if err != nil 
{ + h.logger.Error("failed to list WG peers", zap.Error(err)) + http.Error(w, "failed to list peers", http.StatusInternalServerError) + return + } + + resp := RegisterPeerResponse{ + AssignedWGIP: wgIP, + Peers: peers, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + + h.logger.Info("registered WireGuard peer", + zap.String("node_id", req.NodeID), + zap.String("wg_ip", wgIP), + zap.String("public_ip", req.PublicIP)) +} + +// HandleListPeers handles GET /v1/internal/wg/peers +func (h *Handler) HandleListPeers(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + peers, err := h.ListPeers(r.Context()) + if err != nil { + h.logger.Error("failed to list WG peers", zap.Error(err)) + http.Error(w, "failed to list peers", http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(peers) +} + +// HandleRemovePeer handles DELETE /v1/internal/wg/peer?node_id=xxx +func (h *Handler) HandleRemovePeer(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodDelete { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + nodeID := r.URL.Query().Get("node_id") + if nodeID == "" { + http.Error(w, "node_id parameter required", http.StatusBadRequest) + return + } + + _, err := h.rqliteClient.Exec(r.Context(), + "DELETE FROM wireguard_peers WHERE node_id = ?", nodeID) + if err != nil { + h.logger.Error("failed to remove WG peer", zap.Error(err)) + http.Error(w, "failed to remove peer", http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) + h.logger.Info("removed WireGuard peer", zap.String("node_id", nodeID)) +} + +// ListPeers returns all registered WireGuard peers +func (h *Handler) ListPeers(ctx context.Context) ([]PeerRecord, error) { + var peers []PeerRecord + err := h.rqliteClient.Query(ctx, 
&peers, + "SELECT node_id, wg_ip, public_key, public_ip, wg_port FROM wireguard_peers ORDER BY wg_ip") + if err != nil { + return nil, fmt.Errorf("failed to query wireguard_peers: %w", err) + } + return peers, nil +} + +// assignNextWGIP finds the next available 10.0.0.x IP +func (h *Handler) assignNextWGIP(ctx context.Context) (string, error) { + var result []struct { + MaxIP string `db:"max_ip"` + } + + err := h.rqliteClient.Query(ctx, &result, + "SELECT MAX(wg_ip) as max_ip FROM wireguard_peers") + if err != nil { + return "", fmt.Errorf("failed to query max WG IP: %w", err) + } + + if len(result) == 0 || result[0].MaxIP == "" { + return "10.0.0.1", nil + } + + // Parse last octet and increment + maxIP := result[0].MaxIP + var a, b, c, d int + if _, err := fmt.Sscanf(maxIP, "%d.%d.%d.%d", &a, &b, &c, &d); err != nil { + return "", fmt.Errorf("failed to parse max WG IP %s: %w", maxIP, err) + } + + d++ + if d > 254 { + c++ + d = 1 + if c > 255 { + return "", fmt.Errorf("WireGuard IP space exhausted") + } + } + + return fmt.Sprintf("%d.%d.%d.%d", a, b, c, d), nil +} diff --git a/pkg/gateway/middleware.go b/pkg/gateway/middleware.go index 3d3b8ba..d2f0d98 100644 --- a/pkg/gateway/middleware.go +++ b/pkg/gateway/middleware.go @@ -19,15 +19,32 @@ import ( // Note: context keys (ctxKeyAPIKey, ctxKeyJWT, CtxKeyNamespaceOverride) are now defined in context.go -// withMiddleware adds CORS and logging middleware +// withMiddleware adds CORS, security headers, rate limiting, and logging middleware func (g *Gateway) withMiddleware(next http.Handler) http.Handler { - // Order: logging (outermost) -> CORS -> domain routing -> auth -> handler - // Domain routing must come BEFORE auth to handle deployment domains without auth + // Order: logging -> security headers -> rate limit -> CORS -> domain routing -> auth -> handler return g.loggingMiddleware( - g.corsMiddleware( - g.domainRoutingMiddleware( - g.authMiddleware( - g.authorizationMiddleware(next))))) + 
g.securityHeadersMiddleware( + g.rateLimitMiddleware( + g.corsMiddleware( + g.domainRoutingMiddleware( + g.authMiddleware( + g.authorizationMiddleware(next))))))) +} + +// securityHeadersMiddleware adds standard security headers to all responses +func (g *Gateway) securityHeadersMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Content-Type-Options", "nosniff") + w.Header().Set("X-Frame-Options", "DENY") + w.Header().Set("X-XSS-Protection", "0") + w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin") + w.Header().Set("Permissions-Policy", "camera=(), microphone=(), geolocation=()") + // HSTS only when behind TLS (Caddy) + if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" { + w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains") + } + next.ServeHTTP(w, r) + }) } // loggingMiddleware logs basic request info and duration @@ -202,6 +219,16 @@ func isPublicPath(p string) bool { return true } + // WireGuard peer exchange (auth handled by cluster secret in handler) + if strings.HasPrefix(p, "/v1/internal/wg/") { + return true + } + + // Node join endpoint (auth handled by invite token in handler) + if p == "/v1/internal/join" { + return true + } + switch p { case "/health", "/v1/health", "/status", "/v1/status", "/v1/auth/jwks", "/.well-known/jwks.json", "/v1/version", "/v1/auth/login", "/v1/auth/challenge", "/v1/auth/verify", "/v1/auth/register", "/v1/auth/refresh", "/v1/auth/logout", "/v1/auth/api-key", "/v1/auth/simple-key", "/v1/network/status", "/v1/network/peers", "/v1/internal/tls/check", "/v1/internal/acme/present", "/v1/internal/acme/cleanup": return true @@ -912,7 +939,7 @@ func (g *Gateway) proxyCrossNode(w http.ResponseWriter, r *http.Request, deploym db := g.client.Database() internalCtx := client.WithInternalAuth(r.Context()) - query := "SELECT ip_address FROM dns_nodes WHERE id = ? 
// RateLimiter is a per-client-IP token-bucket rate limiter. Each bucket
// refills continuously at `rate` tokens per second up to `burst`, and every
// admitted request spends one token. All access is serialized by mu.
type RateLimiter struct {
	mu      sync.Mutex
	clients map[string]*bucket
	rate    float64 // tokens per second
	burst   int     // max tokens (burst capacity)
}

// bucket tracks one client's remaining tokens and the last refill instant.
type bucket struct {
	tokens    float64
	lastCheck time.Time
}

// NewRateLimiter creates a rate limiter. ratePerMinute is the sustained rate;
// burst is the maximum number of requests that can be made in a short window.
func NewRateLimiter(ratePerMinute, burst int) *RateLimiter {
	return &RateLimiter{
		clients: make(map[string]*bucket),
		rate:    float64(ratePerMinute) / 60.0,
		burst:   burst,
	}
}

// Allow reports whether a request from the given IP should be admitted,
// spending one token when it is.
func (rl *RateLimiter) Allow(ip string) bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := time.Now()
	b, seen := rl.clients[ip]
	if !seen {
		// First sighting: start from a full bucket and spend one token.
		rl.clients[ip] = &bucket{tokens: float64(rl.burst) - 1, lastCheck: now}
		return true
	}

	// Credit tokens for the elapsed time, capped at the burst ceiling.
	full := float64(rl.burst)
	b.tokens += now.Sub(b.lastCheck).Seconds() * rl.rate
	if b.tokens > full {
		b.tokens = full
	}
	b.lastCheck = now

	if b.tokens < 1 {
		return false
	}
	b.tokens--
	return true
}
+func (rl *RateLimiter) Cleanup(maxAge time.Duration) { + rl.mu.Lock() + defer rl.mu.Unlock() + + cutoff := time.Now().Add(-maxAge) + for ip, b := range rl.clients { + if b.lastCheck.Before(cutoff) { + delete(rl.clients, ip) + } + } +} + +// StartCleanup runs periodic cleanup in a goroutine. +func (rl *RateLimiter) StartCleanup(interval, maxAge time.Duration) { + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for range ticker.C { + rl.Cleanup(maxAge) + } + }() +} + +// rateLimitMiddleware returns 429 when a client exceeds the rate limit. +// Internal traffic from the WireGuard subnet (10.0.0.0/8) is exempt. +func (g *Gateway) rateLimitMiddleware(next http.Handler) http.Handler { + if g.rateLimiter == nil { + return next + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ip := getClientIP(r) + + // Exempt internal cluster traffic (WireGuard subnet) + if isInternalIP(ip) { + next.ServeHTTP(w, r) + return + } + + if !g.rateLimiter.Allow(ip) { + w.Header().Set("Retry-After", "5") + http.Error(w, "rate limit exceeded", http.StatusTooManyRequests) + return + } + next.ServeHTTP(w, r) + }) +} + +// isInternalIP returns true if the IP is in the WireGuard 10.0.0.0/8 subnet +// or is a loopback address. 
+func isInternalIP(ipStr string) bool { + // Strip port if present + if strings.Contains(ipStr, ":") { + host, _, err := net.SplitHostPort(ipStr) + if err == nil { + ipStr = host + } + } + ip := net.ParseIP(ipStr) + if ip == nil { + return false + } + if ip.IsLoopback() { + return true + } + // 10.0.0.0/8 — WireGuard mesh + _, wgNet, _ := net.ParseCIDR("10.0.0.0/8") + return wgNet.Contains(ip) +} diff --git a/pkg/gateway/rate_limiter_test.go b/pkg/gateway/rate_limiter_test.go new file mode 100644 index 0000000..168f6e8 --- /dev/null +++ b/pkg/gateway/rate_limiter_test.go @@ -0,0 +1,197 @@ +package gateway + +import ( + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" +) + +func TestRateLimiter_AllowsUnderLimit(t *testing.T) { + rl := NewRateLimiter(60, 10) // 1/sec, burst 10 + for i := 0; i < 10; i++ { + if !rl.Allow("1.2.3.4") { + t.Fatalf("request %d should be allowed (within burst)", i) + } + } +} + +func TestRateLimiter_BlocksOverLimit(t *testing.T) { + rl := NewRateLimiter(60, 5) // 1/sec, burst 5 + // Exhaust burst + for i := 0; i < 5; i++ { + rl.Allow("1.2.3.4") + } + if rl.Allow("1.2.3.4") { + t.Fatal("request after burst should be blocked") + } +} + +func TestRateLimiter_RefillsOverTime(t *testing.T) { + rl := NewRateLimiter(6000, 5) // 100/sec, burst 5 + // Exhaust burst + for i := 0; i < 5; i++ { + rl.Allow("1.2.3.4") + } + if rl.Allow("1.2.3.4") { + t.Fatal("should be blocked after burst") + } + // Wait for refill + time.Sleep(100 * time.Millisecond) + if !rl.Allow("1.2.3.4") { + t.Fatal("should be allowed after refill") + } +} + +func TestRateLimiter_PerIPIsolation(t *testing.T) { + rl := NewRateLimiter(60, 2) + // Exhaust IP A + rl.Allow("1.1.1.1") + rl.Allow("1.1.1.1") + if rl.Allow("1.1.1.1") { + t.Fatal("IP A should be blocked") + } + // IP B should still be allowed + if !rl.Allow("2.2.2.2") { + t.Fatal("IP B should be allowed") + } +} + +func TestRateLimiter_Cleanup(t *testing.T) { + rl := NewRateLimiter(60, 10) + rl.Allow("old-ip") 
+ // Force the entry to be old + rl.mu.Lock() + rl.clients["old-ip"].lastCheck = time.Now().Add(-20 * time.Minute) + rl.mu.Unlock() + + rl.Cleanup(10 * time.Minute) + + rl.mu.Lock() + _, exists := rl.clients["old-ip"] + rl.mu.Unlock() + if exists { + t.Fatal("stale entry should have been cleaned up") + } +} + +func TestRateLimiter_ConcurrentAccess(t *testing.T) { + rl := NewRateLimiter(60000, 100) // high limit to avoid false failures + var wg sync.WaitGroup + for i := 0; i < 50; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 10; j++ { + rl.Allow("concurrent-ip") + } + }() + } + wg.Wait() +} + +func TestIsInternalIP(t *testing.T) { + tests := []struct { + ip string + internal bool + }{ + {"10.0.0.1", true}, + {"10.0.0.254", true}, + {"10.255.255.255", true}, + {"127.0.0.1", true}, + {"192.168.1.1", false}, + {"8.8.8.8", false}, + {"141.227.165.168", false}, + } + for _, tt := range tests { + if got := isInternalIP(tt.ip); got != tt.internal { + t.Errorf("isInternalIP(%q) = %v, want %v", tt.ip, got, tt.internal) + } + } +} + +func TestSecurityHeaders(t *testing.T) { + gw := &Gateway{} + handler := gw.securityHeadersMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("X-Forwarded-Proto", "https") + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + expected := map[string]string{ + "X-Content-Type-Options": "nosniff", + "X-Frame-Options": "DENY", + "X-XSS-Protection": "0", + "Referrer-Policy": "strict-origin-when-cross-origin", + "Strict-Transport-Security": "max-age=31536000; includeSubDomains", + } + + for header, want := range expected { + if got := w.Header().Get(header); got != want { + t.Errorf("header %s = %q, want %q", header, got, want) + } + } +} + +func TestSecurityHeaders_NoHSTS_WithoutTLS(t *testing.T) { + gw := &Gateway{} + handler := gw.securityHeadersMiddleware(http.HandlerFunc(func(w 
http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + if got := w.Header().Get("Strict-Transport-Security"); got != "" { + t.Errorf("HSTS should not be set without TLS, got %q", got) + } +} + +func TestRateLimitMiddleware_Returns429(t *testing.T) { + gw := &Gateway{rateLimiter: NewRateLimiter(60, 1)} + handler := gw.rateLimitMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // First request should pass + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "8.8.8.8:1234" + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("first request should be 200, got %d", w.Code) + } + + // Second request should be rate limited + w = httptest.NewRecorder() + handler.ServeHTTP(w, req) + if w.Code != http.StatusTooManyRequests { + t.Fatalf("second request should be 429, got %d", w.Code) + } + if w.Header().Get("Retry-After") == "" { + t.Fatal("should have Retry-After header") + } +} + +func TestRateLimitMiddleware_ExemptsInternalTraffic(t *testing.T) { + gw := &Gateway{rateLimiter: NewRateLimiter(60, 1)} + handler := gw.rateLimitMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Internal IP should never be rate limited + for i := 0; i < 10; i++ { + req := httptest.NewRequest("GET", "/test", nil) + req.RemoteAddr = "10.0.0.1:1234" + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("internal request %d should be 200, got %d", i, w.Code) + } + } +} diff --git a/pkg/gateway/routes.go b/pkg/gateway/routes.go index 10aca3a..0864e2a 100644 --- a/pkg/gateway/routes.go +++ b/pkg/gateway/routes.go @@ -24,6 +24,18 @@ func (g *Gateway) Routes() http.Handler { mux.HandleFunc("/v1/internal/acme/present", 
g.acmePresentHandler) mux.HandleFunc("/v1/internal/acme/cleanup", g.acmeCleanupHandler) + // WireGuard peer exchange (internal, cluster-secret auth) + if g.wireguardHandler != nil { + mux.HandleFunc("/v1/internal/wg/peer", g.wireguardHandler.HandleRegisterPeer) + mux.HandleFunc("/v1/internal/wg/peers", g.wireguardHandler.HandleListPeers) + mux.HandleFunc("/v1/internal/wg/peer/remove", g.wireguardHandler.HandleRemovePeer) + } + + // Node join endpoint (token-authenticated, no middleware auth needed) + if g.joinHandler != nil { + mux.HandleFunc("/v1/internal/join", g.joinHandler.HandleJoin) + } + // auth endpoints mux.HandleFunc("/v1/auth/jwks", g.authService.JWKSHandler) mux.HandleFunc("/.well-known/jwks.json", g.authService.JWKSHandler) diff --git a/pkg/namespace/cluster_manager.go b/pkg/namespace/cluster_manager.go index 0950f10..f8b777b 100644 --- a/pkg/namespace/cluster_manager.go +++ b/pkg/namespace/cluster_manager.go @@ -226,8 +226,8 @@ func (cm *ClusterManager) startRQLiteCluster(ctx context.Context, cluster *Names NodeID: nodes[0].NodeID, HTTPPort: portBlocks[0].RQLiteHTTPPort, RaftPort: portBlocks[0].RQLiteRaftPort, - HTTPAdvAddress: fmt.Sprintf("%s:%d", nodes[0].IPAddress, portBlocks[0].RQLiteHTTPPort), - RaftAdvAddress: fmt.Sprintf("%s:%d", nodes[0].IPAddress, portBlocks[0].RQLiteRaftPort), + HTTPAdvAddress: fmt.Sprintf("%s:%d", nodes[0].InternalIP, portBlocks[0].RQLiteHTTPPort), + RaftAdvAddress: fmt.Sprintf("%s:%d", nodes[0].InternalIP, portBlocks[0].RQLiteRaftPort), IsLeader: true, } @@ -254,8 +254,8 @@ func (cm *ClusterManager) startRQLiteCluster(ctx context.Context, cluster *Names NodeID: nodes[i].NodeID, HTTPPort: portBlocks[i].RQLiteHTTPPort, RaftPort: portBlocks[i].RQLiteRaftPort, - HTTPAdvAddress: fmt.Sprintf("%s:%d", nodes[i].IPAddress, portBlocks[i].RQLiteHTTPPort), - RaftAdvAddress: fmt.Sprintf("%s:%d", nodes[i].IPAddress, portBlocks[i].RQLiteRaftPort), + HTTPAdvAddress: fmt.Sprintf("%s:%d", nodes[i].InternalIP, portBlocks[i].RQLiteHTTPPort), 
+ RaftAdvAddress: fmt.Sprintf("%s:%d", nodes[i].InternalIP, portBlocks[i].RQLiteRaftPort), JoinAddresses: []string{leaderRaftAddr}, IsLeader: false, } @@ -288,7 +288,7 @@ func (cm *ClusterManager) startOlricCluster(ctx context.Context, cluster *Namesp // Build peer addresses (all nodes) peerAddresses := make([]string, len(nodes)) for i, node := range nodes { - peerAddresses[i] = fmt.Sprintf("%s:%d", node.IPAddress, portBlocks[i].OlricMemberlistPort) + peerAddresses[i] = fmt.Sprintf("%s:%d", node.InternalIP, portBlocks[i].OlricMemberlistPort) } // Start all Olric instances @@ -299,7 +299,7 @@ func (cm *ClusterManager) startOlricCluster(ctx context.Context, cluster *Namesp HTTPPort: portBlocks[i].OlricHTTPPort, MemberlistPort: portBlocks[i].OlricMemberlistPort, BindAddr: "0.0.0.0", - AdvertiseAddr: node.IPAddress, + AdvertiseAddr: node.InternalIP, PeerAddresses: peerAddresses, } diff --git a/pkg/namespace/node_selector.go b/pkg/namespace/node_selector.go index 5314a13..013adff 100644 --- a/pkg/namespace/node_selector.go +++ b/pkg/namespace/node_selector.go @@ -23,6 +23,7 @@ type ClusterNodeSelector struct { type NodeCapacity struct { NodeID string `json:"node_id"` IPAddress string `json:"ip_address"` + InternalIP string `json:"internal_ip"` // WireGuard IP for inter-node communication DeploymentCount int `json:"deployment_count"` AllocatedPorts int `json:"allocated_ports"` AvailablePorts int `json:"available_ports"` @@ -59,7 +60,7 @@ func (cns *ClusterNodeSelector) SelectNodesForCluster(ctx context.Context, nodeC // Filter nodes that have capacity for namespace instances eligibleNodes := make([]NodeCapacity, 0) for _, node := range activeNodes { - capacity, err := cns.getNodeCapacity(internalCtx, node.NodeID, node.IPAddress) + capacity, err := cns.getNodeCapacity(internalCtx, node.NodeID, node.IPAddress, node.InternalIP) if err != nil { cns.logger.Warn("Failed to get node capacity, skipping", zap.String("node_id", node.NodeID), @@ -117,8 +118,9 @@ func (cns 
*ClusterNodeSelector) SelectNodesForCluster(ctx context.Context, nodeC // nodeInfo is used for querying active nodes type nodeInfo struct { - NodeID string `db:"id"` - IPAddress string `db:"ip_address"` + NodeID string `db:"id"` + IPAddress string `db:"ip_address"` + InternalIP string `db:"internal_ip"` } // getActiveNodes retrieves all active nodes from dns_nodes table @@ -128,7 +130,7 @@ func (cns *ClusterNodeSelector) getActiveNodes(ctx context.Context) ([]nodeInfo, var results []nodeInfo query := ` - SELECT id, ip_address FROM dns_nodes + SELECT id, ip_address, COALESCE(internal_ip, ip_address) as internal_ip FROM dns_nodes WHERE status = 'active' AND last_seen > ? ORDER BY id ` @@ -148,7 +150,7 @@ func (cns *ClusterNodeSelector) getActiveNodes(ctx context.Context) ([]nodeInfo, } // getNodeCapacity calculates capacity metrics for a single node -func (cns *ClusterNodeSelector) getNodeCapacity(ctx context.Context, nodeID, ipAddress string) (*NodeCapacity, error) { +func (cns *ClusterNodeSelector) getNodeCapacity(ctx context.Context, nodeID, ipAddress, internalIP string) (*NodeCapacity, error) { // Get deployment count deploymentCount, err := cns.getDeploymentCount(ctx, nodeID) if err != nil { @@ -209,6 +211,7 @@ func (cns *ClusterNodeSelector) getNodeCapacity(ctx context.Context, nodeID, ipA capacity := &NodeCapacity{ NodeID: nodeID, IPAddress: ipAddress, + InternalIP: internalIP, DeploymentCount: deploymentCount, AllocatedPorts: allocatedPorts, AvailablePorts: availablePorts, @@ -365,7 +368,7 @@ func (cns *ClusterNodeSelector) GetNodeByID(ctx context.Context, nodeID string) internalCtx := client.WithInternalAuth(ctx) var results []nodeInfo - query := `SELECT id, ip_address FROM dns_nodes WHERE id = ? LIMIT 1` + query := `SELECT id, ip_address, COALESCE(internal_ip, ip_address) as internal_ip FROM dns_nodes WHERE id = ? 
LIMIT 1` err := cns.db.Query(internalCtx, &results, query, nodeID) if err != nil { return nil, &ClusterError{ diff --git a/pkg/node/dns_registration.go b/pkg/node/dns_registration.go index 28555d8..9d93ee1 100644 --- a/pkg/node/dns_registration.go +++ b/pkg/node/dns_registration.go @@ -29,8 +29,11 @@ func (n *Node) registerDNSNode(ctx context.Context) error { ipAddress = "127.0.0.1" } - // Get internal IP (same as external for now, or could use private network IP) + // Get internal IP from WireGuard interface (for cross-node communication over VPN) internalIP := ipAddress + if wgIP, err := n.getWireGuardIP(); err == nil && wgIP != "" { + internalIP = wgIP + } // Determine region (defaulting to "local" for now, could be from cloud metadata in future) region := "local" @@ -297,6 +300,24 @@ func (n *Node) cleanupStaleNodeRecords(ctx context.Context) { } } +// getWireGuardIP returns the IPv4 address assigned to the wg0 interface, if any +func (n *Node) getWireGuardIP() (string, error) { + iface, err := net.InterfaceByName("wg0") + if err != nil { + return "", err + } + addrs, err := iface.Addrs() + if err != nil { + return "", err + } + for _, addr := range addrs { + if ipnet, ok := addr.(*net.IPNet); ok && ipnet.IP.To4() != nil { + return ipnet.IP.String(), nil + } + } + return "", fmt.Errorf("no IPv4 address on wg0") +} + // getNodeIPAddress attempts to determine the node's external IP address func (n *Node) getNodeIPAddress() (string, error) { // Try to detect external IP by connecting to a public server diff --git a/pkg/node/gateway.go b/pkg/node/gateway.go index 3d4100a..4052a4d 100644 --- a/pkg/node/gateway.go +++ b/pkg/node/gateway.go @@ -33,6 +33,15 @@ func (n *Node) startHTTPGateway(ctx context.Context) error { return err } + // DataDir in node config is ~/.orama/data; the orama dir is the parent + oramaDir := filepath.Join(os.ExpandEnv(n.config.Node.DataDir), "..") + + // Read cluster secret for WireGuard peer exchange auth + clusterSecret := "" + if 
secretBytes, err := os.ReadFile(filepath.Join(oramaDir, "secrets", "cluster-secret")); err == nil { + clusterSecret = string(secretBytes) + } + gwCfg := &gateway.Config{ ListenAddr: n.config.HTTPGateway.ListenAddr, ClientNamespace: n.config.HTTPGateway.ClientNamespace, @@ -45,6 +54,8 @@ func (n *Node) startHTTPGateway(ctx context.Context) error { IPFSAPIURL: n.config.HTTPGateway.IPFSAPIURL, IPFSTimeout: n.config.HTTPGateway.IPFSTimeout, BaseDomain: n.config.HTTPGateway.BaseDomain, + DataDir: oramaDir, + ClusterSecret: clusterSecret, } apiGateway, err := gateway.New(gatewayLogger, gwCfg) diff --git a/pkg/node/node.go b/pkg/node/node.go index 3daf968..7e801eb 100644 --- a/pkg/node/node.go +++ b/pkg/node/node.go @@ -103,6 +103,9 @@ func (n *Node) Start(ctx context.Context) error { return fmt.Errorf("failed to start RQLite: %w", err) } + // Sync WireGuard peers from RQLite (if WG is active on this node) + n.startWireGuardSyncLoop(ctx) + // Register this node in dns_nodes table for deployment routing if err := n.registerDNSNode(ctx); err != nil { n.logger.ComponentWarn(logging.ComponentNode, "Failed to register DNS node", zap.Error(err)) diff --git a/pkg/node/wireguard_sync.go b/pkg/node/wireguard_sync.go new file mode 100644 index 0000000..a984666 --- /dev/null +++ b/pkg/node/wireguard_sync.go @@ -0,0 +1,223 @@ +package node + +import ( + "context" + "fmt" + "net" + "os/exec" + "strings" + "time" + + "github.com/DeBrosOfficial/network/pkg/environments/production" + "github.com/DeBrosOfficial/network/pkg/logging" + "go.uber.org/zap" +) + +// syncWireGuardPeers reads all peers from RQLite and reconciles the local +// WireGuard interface so it matches the cluster state. This is called on +// startup after RQLite is ready and periodically thereafter. 
+func (n *Node) syncWireGuardPeers(ctx context.Context) error { + if n.rqliteAdapter == nil { + return fmt.Errorf("rqlite adapter not initialized") + } + + // Check if WireGuard is installed and active + if _, err := exec.LookPath("wg"); err != nil { + n.logger.ComponentInfo(logging.ComponentNode, "WireGuard not installed, skipping peer sync") + return nil + } + + // Check if wg0 interface exists + out, err := exec.CommandContext(ctx, "sudo", "wg", "show", "wg0").CombinedOutput() + if err != nil { + n.logger.ComponentInfo(logging.ComponentNode, "WireGuard interface wg0 not active, skipping peer sync") + return nil + } + + // Parse current peers from wg show output + currentPeers := parseWGShowPeers(string(out)) + localPubKey := parseWGShowLocalKey(string(out)) + + // Query all peers from RQLite + db := n.rqliteAdapter.GetSQLDB() + rows, err := db.QueryContext(ctx, + "SELECT node_id, wg_ip, public_key, public_ip, wg_port FROM wireguard_peers ORDER BY wg_ip") + if err != nil { + return fmt.Errorf("failed to query wireguard_peers: %w", err) + } + defer rows.Close() + + // Build desired peer set (excluding self) + desiredPeers := make(map[string]production.WireGuardPeer) + for rows.Next() { + var nodeID, wgIP, pubKey, pubIP string + var wgPort int + if err := rows.Scan(&nodeID, &wgIP, &pubKey, &pubIP, &wgPort); err != nil { + continue + } + if pubKey == localPubKey { + continue // skip self + } + if wgPort == 0 { + wgPort = 51820 + } + desiredPeers[pubKey] = production.WireGuardPeer{ + PublicKey: pubKey, + Endpoint: fmt.Sprintf("%s:%d", pubIP, wgPort), + AllowedIP: wgIP + "/32", + } + } + + wp := &production.WireGuardProvisioner{} + + // Add missing peers + for pubKey, peer := range desiredPeers { + if _, exists := currentPeers[pubKey]; !exists { + if err := wp.AddPeer(peer); err != nil { + n.logger.ComponentWarn(logging.ComponentNode, "failed to add WG peer", + zap.String("public_key", pubKey[:8]+"..."), + zap.Error(err)) + } else { + 
n.logger.ComponentInfo(logging.ComponentNode, "added WG peer", + zap.String("allowed_ip", peer.AllowedIP)) + } + } + } + + // Remove peers not in the desired set + for pubKey := range currentPeers { + if _, exists := desiredPeers[pubKey]; !exists { + if err := wp.RemovePeer(pubKey); err != nil { + n.logger.ComponentWarn(logging.ComponentNode, "failed to remove stale WG peer", + zap.String("public_key", pubKey[:8]+"..."), + zap.Error(err)) + } else { + n.logger.ComponentInfo(logging.ComponentNode, "removed stale WG peer", + zap.String("public_key", pubKey[:8]+"...")) + } + } + } + + n.logger.ComponentInfo(logging.ComponentNode, "WireGuard peer sync completed", + zap.Int("desired_peers", len(desiredPeers)), + zap.Int("current_peers", len(currentPeers))) + + return nil +} + +// ensureWireGuardSelfRegistered ensures this node's WireGuard info is in the +// wireguard_peers table. Without this, joining nodes get an empty peer list +// from the /v1/internal/join endpoint and can't establish WG tunnels. 
+func (n *Node) ensureWireGuardSelfRegistered(ctx context.Context) { + if n.rqliteAdapter == nil { + return + } + + // Check if wg0 is active + out, err := exec.CommandContext(ctx, "sudo", "wg", "show", "wg0").CombinedOutput() + if err != nil { + return // WG not active, nothing to register + } + + // Get local public key + localPubKey := parseWGShowLocalKey(string(out)) + if localPubKey == "" { + return + } + + // Get WG IP from interface + wgIP := "" + iface, err := net.InterfaceByName("wg0") + if err != nil { + return + } + addrs, err := iface.Addrs() + if err != nil { + return + } + for _, addr := range addrs { + if ipnet, ok := addr.(*net.IPNet); ok && ipnet.IP.To4() != nil { + wgIP = ipnet.IP.String() + break + } + } + if wgIP == "" { + return + } + + // Get public IP + publicIP, err := n.getNodeIPAddress() + if err != nil { + return + } + + nodeID := n.GetPeerID() + if nodeID == "" { + nodeID = fmt.Sprintf("node-%s", wgIP) + } + + db := n.rqliteAdapter.GetSQLDB() + _, err = db.ExecContext(ctx, + "INSERT OR REPLACE INTO wireguard_peers (node_id, wg_ip, public_key, public_ip, wg_port) VALUES (?, ?, ?, ?, ?)", + nodeID, wgIP, localPubKey, publicIP, 51820) + if err != nil { + n.logger.ComponentWarn(logging.ComponentNode, "Failed to self-register WG peer", zap.Error(err)) + } else { + n.logger.ComponentInfo(logging.ComponentNode, "WireGuard self-registered", + zap.String("wg_ip", wgIP), + zap.String("public_key", localPubKey[:8]+"...")) + } +} + +// startWireGuardSyncLoop runs syncWireGuardPeers periodically +func (n *Node) startWireGuardSyncLoop(ctx context.Context) { + // Ensure this node is registered in wireguard_peers (critical for join flow) + n.ensureWireGuardSelfRegistered(ctx) + + // Run initial sync + if err := n.syncWireGuardPeers(ctx); err != nil { + n.logger.ComponentWarn(logging.ComponentNode, "initial WireGuard peer sync failed", zap.Error(err)) + } + + // Periodic sync every 60 seconds + go func() { + ticker := time.NewTicker(60 * time.Second) + 
defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := n.syncWireGuardPeers(ctx); err != nil { + n.logger.ComponentWarn(logging.ComponentNode, "WireGuard peer sync failed", zap.Error(err)) + } + } + } + }() +} + +// parseWGShowPeers extracts public keys of current peers from `wg show wg0` output +func parseWGShowPeers(output string) map[string]struct{} { + peers := make(map[string]struct{}) + for _, line := range strings.Split(output, "\n") { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "peer:") { + key := strings.TrimSpace(strings.TrimPrefix(line, "peer:")) + if key != "" { + peers[key] = struct{}{} + } + } + } + return peers +} + +// parseWGShowLocalKey extracts the local public key from `wg show wg0` output +func parseWGShowLocalKey(output string) string { + for _, line := range strings.Split(output, "\n") { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "public key:") { + return strings.TrimSpace(strings.TrimPrefix(line, "public key:")) + } + } + return "" +} diff --git a/pkg/rqlite/cluster.go b/pkg/rqlite/cluster.go index 4b3b172..dfb89d1 100644 --- a/pkg/rqlite/cluster.go +++ b/pkg/rqlite/cluster.go @@ -38,6 +38,13 @@ func (r *RQLiteManager) waitForMinClusterSizeBeforeStart(ctx context.Context, rq } requiredRemotePeers := r.config.MinClusterSize - 1 + + // Genesis node (single-node cluster) doesn't need to wait for peers + if requiredRemotePeers <= 0 { + r.logger.Info("Genesis node, skipping peer discovery wait") + return nil + } + _ = r.discoveryService.TriggerPeerExchange(ctx) checkInterval := 2 * time.Second diff --git a/pkg/rqlite/process.go b/pkg/rqlite/process.go index 283034b..d6e5c00 100644 --- a/pkg/rqlite/process.go +++ b/pkg/rqlite/process.go @@ -17,8 +17,55 @@ import ( "go.uber.org/zap" ) +// killOrphanedRQLite kills any orphaned rqlited process still holding the port. +// This can happen when the parent node process crashes and rqlited keeps running. 
+func (r *RQLiteManager) killOrphanedRQLite() { + // Check if port is already in use by querying the status endpoint + url := fmt.Sprintf("http://localhost:%d/status", r.config.RQLitePort) + client := &http.Client{Timeout: 2 * time.Second} + resp, err := client.Get(url) + if err != nil { + return // Port not in use, nothing to clean up + } + resp.Body.Close() + + // Port is in use — find and kill the orphaned process + r.logger.Warn("Found orphaned rqlited process on port, killing it", + zap.Int("port", r.config.RQLitePort)) + + // Use fuser to find and kill the process holding the port + cmd := exec.Command("fuser", "-k", fmt.Sprintf("%d/tcp", r.config.RQLitePort)) + if err := cmd.Run(); err != nil { + r.logger.Warn("fuser failed, trying lsof", zap.Error(err)) + // Fallback: use lsof + out, err := exec.Command("lsof", "-ti", fmt.Sprintf(":%d", r.config.RQLitePort)).Output() + if err == nil { + for _, pidStr := range strings.Split(strings.TrimSpace(string(out)), "\n") { + if pidStr != "" { + killCmd := exec.Command("kill", "-9", pidStr) + killCmd.Run() + } + } + } + } + + // Wait for port to be released + for i := 0; i < 10; i++ { + time.Sleep(500 * time.Millisecond) + resp, err := client.Get(url) + if err != nil { + return // Port released + } + resp.Body.Close() + } + r.logger.Warn("Could not release port from orphaned process") +} + // launchProcess starts the RQLite process with appropriate arguments func (r *RQLiteManager) launchProcess(ctx context.Context, rqliteDataDir string) error { + // Kill any orphaned rqlited from a previous crash + r.killOrphanedRQLite() + // Build RQLite command args := []string{ "-http-addr", fmt.Sprintf("0.0.0.0:%d", r.config.RQLitePort), @@ -180,21 +227,30 @@ func (r *RQLiteManager) waitForReady(ctx context.Context) error { // waitForSQLAvailable waits until a simple query succeeds func (r *RQLiteManager) waitForSQLAvailable(ctx context.Context) error { + r.logger.Info("Waiting for SQL to become available...") ticker := 
time.NewTicker(1 * time.Second) defer ticker.Stop() + attempts := 0 for { select { case <-ctx.Done(): + r.logger.Error("waitForSQLAvailable timed out", zap.Int("attempts", attempts)) return ctx.Err() case <-ticker.C: + attempts++ if r.connection == nil { + r.logger.Warn("connection is nil in waitForSQLAvailable") continue } _, err := r.connection.QueryOne("SELECT 1") if err == nil { + r.logger.Info("SQL is available", zap.Int("attempts", attempts)) return nil } + if attempts <= 5 || attempts%10 == 0 { + r.logger.Debug("SQL not yet available", zap.Int("attempt", attempts), zap.Error(err)) + } } } } diff --git a/scripts/build-coredns.sh b/scripts/build-coredns.sh index a10889c..c0ac5d1 100755 --- a/scripts/build-coredns.sh +++ b/scripts/build-coredns.sh @@ -74,6 +74,9 @@ on:github.com/coredns/caddy/onevent sign:sign view:view +# Response Rate Limiting (DNS amplification protection) +rrl:rrl + # Custom RQLite plugin rqlite:github.com/DeBrosOfficial/network/pkg/coredns/rqlite EOF