Added wireguard and updated installation process and added more tests

This commit is contained in:
anonpenguin23 2026-01-30 15:30:18 +02:00
parent dcaf695fbc
commit 4acea72467
38 changed files with 2700 additions and 119 deletions

View File

@ -22,7 +22,7 @@ check-gateway:
echo " 3. Run tests: make test-e2e-local"; \ echo " 3. Run tests: make test-e2e-local"; \
echo ""; \ echo ""; \
echo "To run tests against production:"; \ echo "To run tests against production:"; \
echo " ORAMA_GATEWAY_URL=http://VPS-IP:6001 make test-e2e"; \ echo " ORAMA_GATEWAY_URL=https://dbrs.space make test-e2e"; \
exit 1; \ exit 1; \
fi fi
@echo "✅ Gateway is running" @echo "✅ Gateway is running"
@ -36,7 +36,7 @@ test-e2e-local: check-gateway
test-e2e-prod: test-e2e-prod:
@if [ -z "$$ORAMA_GATEWAY_URL" ]; then \ @if [ -z "$$ORAMA_GATEWAY_URL" ]; then \
echo "❌ ORAMA_GATEWAY_URL not set"; \ echo "❌ ORAMA_GATEWAY_URL not set"; \
echo "Usage: ORAMA_GATEWAY_URL=http://VPS-IP:6001 make test-e2e-prod"; \ echo "Usage: ORAMA_GATEWAY_URL=https://dbrs.space make test-e2e-prod"; \
exit 1; \ exit 1; \
fi fi
@echo "Running E2E tests (including production-only) against $$ORAMA_GATEWAY_URL..." @echo "Running E2E tests (including production-only) against $$ORAMA_GATEWAY_URL..."
@ -182,7 +182,7 @@ help:
@echo " make test-e2e - Generic E2E tests (auto-discovers config)" @echo " make test-e2e - Generic E2E tests (auto-discovers config)"
@echo "" @echo ""
@echo " Example production test:" @echo " Example production test:"
@echo " ORAMA_GATEWAY_URL=http://141.227.165.168:6001 make test-e2e-prod" @echo " ORAMA_GATEWAY_URL=https://dbrs.space make test-e2e-prod"
@echo "" @echo ""
@echo "Development Management (via orama):" @echo "Development Management (via orama):"
@echo " ./bin/orama dev status - Show status of all dev services" @echo " ./bin/orama dev status - Show status of all dev services"

View File

@ -53,6 +53,8 @@ func main() {
cli.HandleProdCommand(args) cli.HandleProdCommand(args)
// Direct production commands (new simplified interface) // Direct production commands (new simplified interface)
case "invite":
cli.HandleProdCommand(append([]string{"invite"}, args...))
case "install": case "install":
cli.HandleProdCommand(append([]string{"install"}, args...)) cli.HandleProdCommand(append([]string{"install"}, args...))
case "upgrade": case "upgrade":

View File

@ -353,12 +353,22 @@ Function Invocation:
- Refresh token support - Refresh token support
- Claims-based authorization - Claims-based authorization
### Network Security (WireGuard Mesh)
All inter-node communication is encrypted via a WireGuard VPN mesh:
- **WireGuard IPs:** Each node gets a private IP (10.0.0.x) used for all cluster traffic
- **UFW Firewall:** Only public ports are exposed: 22 (SSH), 53 (DNS, nameservers only), 80/443 (HTTP/HTTPS), 51820 (WireGuard UDP)
- **Internal services** (RQLite 5001/7001, IPFS 4001/4501, Olric 3320/3322, Gateway 6001) are only accessible via WireGuard or localhost
- **Invite tokens:** Single-use, time-limited tokens for secure node joining. No shared secrets on the CLI
- **Join flow:** New nodes authenticate via HTTPS (443), establish WireGuard tunnel, then join all services over the encrypted mesh
### TLS/HTTPS ### TLS/HTTPS
- Automatic ACME (Let's Encrypt) certificate management - Automatic ACME (Let's Encrypt) certificate management via Caddy
- TLS 1.3 support - TLS 1.3 support
- HTTP/2 enabled - HTTP/2 enabled
- Certificate caching - On-demand TLS for deployment custom domains
### Middleware Stack ### Middleware Stack
@ -441,17 +451,25 @@ make test-e2e # Run E2E tests
### Production ### Production
```bash ```bash
# First node (creates cluster) # First node (genesis — creates cluster)
sudo orama install --vps-ip <IP> --domain node1.example.com sudo orama install --vps-ip <IP> --domain node1.example.com --nameserver
# Additional nodes (join cluster) # On the genesis node, generate an invite for a new node
sudo orama install --vps-ip <IP> --domain node2.example.com \ orama invite
--peers /dns4/node1.example.com/tcp/4001/p2p/<PEER_ID> \ # Outputs: sudo orama install --join https://node1.example.com --token <TOKEN> --vps-ip <NEW_IP>
--join <node1-ip>:7002 \
--cluster-secret <secret> \ # Additional nodes (join via invite token over HTTPS)
--swarm-key <key> sudo orama install --join https://node1.example.com --token <TOKEN> \
--vps-ip <IP> --nameserver
``` ```
**Security:** Nodes join via single-use invite tokens over HTTPS. A WireGuard VPN tunnel
is established before any cluster services start. All inter-node traffic (RQLite, IPFS,
Olric, LibP2P) flows over the encrypted WireGuard mesh — no cluster ports are exposed
publicly. **Never use `http://<ip>:6001`** for joining — port 6001 is internal-only and
blocked by UFW. Use the domain (`https://node1.example.com`) or, if DNS is not yet
configured, use the IP over HTTP port 80 (`http://<ip>`) which goes through Caddy.
### Docker (Future) ### Docker (Future)
Planned containerization with Docker Compose and Kubernetes support. Planned containerization with Docker Compose and Kubernetes support.

View File

@ -95,14 +95,74 @@ To deploy to all nodes, repeat steps 3-5 (dev) or 3-4 (production) for each VPS
### CLI Flags Reference ### CLI Flags Reference
#### `orama install`
| Flag | Description | | Flag | Description |
|------|-------------| |------|-------------|
| `--branch <branch>` | Git branch to pull from (production deployment) | | `--vps-ip <ip>` | VPS public IP address (required) |
| `--no-pull` | Skip git pull, use existing `/home/debros/src` (dev deployment) | | `--domain <domain>` | Domain for HTTPS certificates |
| `--base-domain <domain>` | Base domain for deployment routing (e.g., dbrs.space) |
| `--nameserver` | Configure this node as a nameserver (CoreDNS + Caddy) |
| `--join <url>` | Join existing cluster via HTTPS URL (e.g., `https://node1.dbrs.space`) |
| `--token <token>` | Invite token for joining (from `orama invite` on existing node) |
| `--branch <branch>` | Git branch to use (default: main) |
| `--no-pull` | Skip git clone/pull, use existing `/home/debros/src` |
| `--force` | Force reconfiguration even if already installed |
| `--skip-firewall` | Skip UFW firewall setup |
| `--skip-checks` | Skip minimum resource checks (RAM/CPU) |
#### `orama invite`
| Flag | Description |
|------|-------------|
| `--expiry <duration>` | Token expiry duration (default: 1h) |
#### `orama upgrade`
| Flag | Description |
|------|-------------|
| `--branch <branch>` | Git branch to pull from |
| `--no-pull` | Skip git pull, use existing source |
| `--restart` | Restart all services after upgrade | | `--restart` | Restart all services after upgrade |
| `--nameserver` | Configure this node as a nameserver (install only) |
| `--domain <domain>` | Domain for HTTPS certificates (install only) | ### Node Join Flow
| `--vps-ip <ip>` | VPS public IP address (install only) |
```bash
# 1. Genesis node (first node, creates cluster)
sudo orama install --vps-ip 1.2.3.4 --domain node1.dbrs.space \
--base-domain dbrs.space --nameserver
# 2. On genesis node, generate an invite
orama invite
# Output: sudo orama install --join https://node1.dbrs.space --token <TOKEN> --vps-ip <IP>
# 3. On the new node, run the printed command
sudo orama install --join https://node1.dbrs.space --token abc123... \
--vps-ip 5.6.7.8 --nameserver
```
The join flow establishes a WireGuard VPN tunnel before starting cluster services.
All inter-node communication (RQLite, IPFS, Olric) uses WireGuard IPs (10.0.0.x).
No cluster ports are ever exposed publicly.
#### DNS Prerequisite
The `--join` URL should use the HTTPS domain of the genesis node (e.g., `https://node1.dbrs.space`).
For this to work, the registrar for `dbrs.space` must delegate the zone to the genesis node — NS records
naming it as a nameserver, plus glue A records for its IP — so that `node1.dbrs.space` resolves publicly.
**If DNS is not yet configured**, you can use the genesis node's public IP with HTTP as a fallback:
```bash
sudo orama install --join http://1.2.3.4 --vps-ip 5.6.7.8 --token abc123... --nameserver
```
This works because Caddy's `:80` block proxies all HTTP traffic to the gateway. However, once DNS
is properly configured, always use the HTTPS domain URL.
**Important:** Never use `http://<ip>:6001` — port 6001 is the internal gateway and is blocked by
UFW from external access. The join request goes through Caddy on port 80 (HTTP) or 443 (HTTPS),
which proxies to the gateway internally.
## Debugging Production Issues ## Debugging Production Issues

View File

@ -170,9 +170,9 @@ func TestNamespaceCluster_OlricHealth(t *testing.T) {
func TestNamespaceCluster_GatewayHealth(t *testing.T) { func TestNamespaceCluster_GatewayHealth(t *testing.T) {
// Check if gateway binary exists // Check if gateway binary exists
gatewayBinaryPaths := []string{ gatewayBinaryPaths := []string{
"./bin/gateway", "./bin/orama",
"../bin/gateway", "../bin/orama",
"/usr/local/bin/orama-gateway", "/usr/local/bin/orama",
} }
var gatewayBinaryExists bool var gatewayBinaryExists bool

View File

@ -0,0 +1,9 @@
-- WireGuard mesh peer tracking
-- One row per cluster node; holds the addressing and key material other
-- nodes need to add this node as a WireGuard peer.
CREATE TABLE IF NOT EXISTS wireguard_peers (
    node_id TEXT PRIMARY KEY,                     -- cluster node identifier
    wg_ip TEXT NOT NULL UNIQUE,                   -- private mesh IP (10.0.0.x); unique per node
    public_key TEXT NOT NULL UNIQUE,              -- node's WireGuard public key; unique per node
    public_ip TEXT NOT NULL,                      -- node's public endpoint address
    wg_port INTEGER DEFAULT 51820,                -- UDP listen port (WireGuard default)
    created_at DATETIME DEFAULT CURRENT_TIMESTAMP -- row creation time
);

View File

@ -0,0 +1,8 @@
-- Single-use, time-limited invite tokens used by new nodes to join the
-- cluster (minted by `orama invite`, redeemed via the join endpoint).
CREATE TABLE IF NOT EXISTS invite_tokens (
    token TEXT PRIMARY KEY,                        -- hex-encoded random token (the invite itself)
    created_by TEXT NOT NULL,                      -- identifier of the minting node (hostname)
    created_at DATETIME DEFAULT CURRENT_TIMESTAMP, -- when the token was minted
    expires_at DATETIME NOT NULL,                  -- UTC cutoff after which the token is invalid
    used_at DATETIME,                              -- when the token was redeemed (NULL = still unused)
    used_by_ip TEXT                                -- public IP of the redeeming node, if any
);

View File

@ -7,42 +7,32 @@ import (
) )
// TestProdCommandFlagParsing verifies that prod command flags are parsed correctly // TestProdCommandFlagParsing verifies that prod command flags are parsed correctly
// Note: The installer now uses --vps-ip presence to determine if it's a first node (no --bootstrap flag) // Genesis node: has --vps-ip but no --join or --token
// First node: has --vps-ip but no --peers or --join // Joining node: has --vps-ip, --join (HTTPS URL), and --token (invite token)
// Joining node: has --vps-ip, --peers, and --cluster-secret
func TestProdCommandFlagParsing(t *testing.T) { func TestProdCommandFlagParsing(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
args []string args []string
expectVPSIP string expectVPSIP string
expectDomain string expectDomain string
expectPeers string expectJoin string
expectJoin string expectToken string
expectSecret string expectBranch string
expectBranch string isFirstNode bool // genesis node = no --join and no --token
isFirstNode bool // first node = no peers and no join address
}{ }{
{ {
name: "first node (creates new cluster)", name: "genesis node (creates new cluster)",
args: []string{"install", "--vps-ip", "10.0.0.1", "--domain", "node-1.example.com"}, args: []string{"install", "--vps-ip", "10.0.0.1", "--domain", "node-1.example.com"},
expectVPSIP: "10.0.0.1", expectVPSIP: "10.0.0.1",
expectDomain: "node-1.example.com", expectDomain: "node-1.example.com",
isFirstNode: true, isFirstNode: true,
}, },
{ {
name: "joining node with peers", name: "joining node with invite token",
args: []string{"install", "--vps-ip", "10.0.0.2", "--peers", "/ip4/10.0.0.1/tcp/4001/p2p/Qm123", "--cluster-secret", "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"}, args: []string{"install", "--vps-ip", "10.0.0.2", "--join", "https://node1.dbrs.space", "--token", "abc123def456"},
expectVPSIP: "10.0.0.2", expectVPSIP: "10.0.0.2",
expectPeers: "/ip4/10.0.0.1/tcp/4001/p2p/Qm123", expectJoin: "https://node1.dbrs.space",
expectSecret: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", expectToken: "abc123def456",
isFirstNode: false,
},
{
name: "joining node with join address",
args: []string{"install", "--vps-ip", "10.0.0.3", "--join", "10.0.0.1:7001", "--cluster-secret", "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"},
expectVPSIP: "10.0.0.3",
expectJoin: "10.0.0.1:7001",
expectSecret: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
isFirstNode: false, isFirstNode: false,
}, },
{ {
@ -56,8 +46,7 @@ func TestProdCommandFlagParsing(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
// Extract flags manually to verify parsing logic var vpsIP, domain, joinAddr, token, branch string
var vpsIP, domain, peersStr, joinAddr, clusterSecret, branch string
for i, arg := range tt.args { for i, arg := range tt.args {
switch arg { switch arg {
@ -69,17 +58,13 @@ func TestProdCommandFlagParsing(t *testing.T) {
if i+1 < len(tt.args) { if i+1 < len(tt.args) {
domain = tt.args[i+1] domain = tt.args[i+1]
} }
case "--peers":
if i+1 < len(tt.args) {
peersStr = tt.args[i+1]
}
case "--join": case "--join":
if i+1 < len(tt.args) { if i+1 < len(tt.args) {
joinAddr = tt.args[i+1] joinAddr = tt.args[i+1]
} }
case "--cluster-secret": case "--token":
if i+1 < len(tt.args) { if i+1 < len(tt.args) {
clusterSecret = tt.args[i+1] token = tt.args[i+1]
} }
case "--branch": case "--branch":
if i+1 < len(tt.args) { if i+1 < len(tt.args) {
@ -88,8 +73,8 @@ func TestProdCommandFlagParsing(t *testing.T) {
} }
} }
// First node detection: no peers and no join address // Genesis node detection: no --join and no --token
isFirstNode := peersStr == "" && joinAddr == "" isFirstNode := joinAddr == "" && token == ""
if vpsIP != tt.expectVPSIP { if vpsIP != tt.expectVPSIP {
t.Errorf("expected vpsIP=%q, got %q", tt.expectVPSIP, vpsIP) t.Errorf("expected vpsIP=%q, got %q", tt.expectVPSIP, vpsIP)
@ -97,14 +82,11 @@ func TestProdCommandFlagParsing(t *testing.T) {
if domain != tt.expectDomain { if domain != tt.expectDomain {
t.Errorf("expected domain=%q, got %q", tt.expectDomain, domain) t.Errorf("expected domain=%q, got %q", tt.expectDomain, domain)
} }
if peersStr != tt.expectPeers {
t.Errorf("expected peers=%q, got %q", tt.expectPeers, peersStr)
}
if joinAddr != tt.expectJoin { if joinAddr != tt.expectJoin {
t.Errorf("expected join=%q, got %q", tt.expectJoin, joinAddr) t.Errorf("expected join=%q, got %q", tt.expectJoin, joinAddr)
} }
if clusterSecret != tt.expectSecret { if token != tt.expectToken {
t.Errorf("expected clusterSecret=%q, got %q", tt.expectSecret, clusterSecret) t.Errorf("expected token=%q, got %q", tt.expectToken, token)
} }
if branch != tt.expectBranch { if branch != tt.expectBranch {
t.Errorf("expected branch=%q, got %q", tt.expectBranch, branch) t.Errorf("expected branch=%q, got %q", tt.expectBranch, branch)

View File

@ -5,6 +5,7 @@ import (
"os" "os"
"github.com/DeBrosOfficial/network/pkg/cli/production/install" "github.com/DeBrosOfficial/network/pkg/cli/production/install"
"github.com/DeBrosOfficial/network/pkg/cli/production/invite"
"github.com/DeBrosOfficial/network/pkg/cli/production/lifecycle" "github.com/DeBrosOfficial/network/pkg/cli/production/lifecycle"
"github.com/DeBrosOfficial/network/pkg/cli/production/logs" "github.com/DeBrosOfficial/network/pkg/cli/production/logs"
"github.com/DeBrosOfficial/network/pkg/cli/production/migrate" "github.com/DeBrosOfficial/network/pkg/cli/production/migrate"
@ -24,6 +25,8 @@ func HandleCommand(args []string) {
subargs := args[1:] subargs := args[1:]
switch subcommand { switch subcommand {
case "invite":
invite.Handle(subargs)
case "install": case "install":
install.Handle(subargs) install.Handle(subargs)
case "upgrade": case "upgrade":

View File

@ -17,10 +17,11 @@ type Flags struct {
DryRun bool DryRun bool
SkipChecks bool SkipChecks bool
Nameserver bool // Make this node a nameserver (runs CoreDNS + Caddy) Nameserver bool // Make this node a nameserver (runs CoreDNS + Caddy)
JoinAddress string JoinAddress string // HTTPS URL of existing node (e.g., https://node1.dbrs.space)
ClusterSecret string Token string // Invite token for joining (from orama invite)
SwarmKey string ClusterSecret string // Deprecated: use --token instead
PeersStr string SwarmKey string // Deprecated: use --token instead
PeersStr string // Deprecated: use --token instead
// IPFS/Cluster specific info for Peering configuration // IPFS/Cluster specific info for Peering configuration
IPFSPeerID string IPFSPeerID string
@ -28,6 +29,9 @@ type Flags struct {
IPFSClusterPeerID string IPFSClusterPeerID string
IPFSClusterAddrs string IPFSClusterAddrs string
// Security flags
SkipFirewall bool // Skip UFW firewall setup (for users who manage their own firewall)
// Anyone relay operator flags // Anyone relay operator flags
AnyoneRelay bool // Run as relay operator instead of client AnyoneRelay bool // Run as relay operator instead of client
AnyoneExit bool // Run as exit relay (legal implications) AnyoneExit bool // Run as exit relay (legal implications)
@ -57,9 +61,10 @@ func ParseFlags(args []string) (*Flags, error) {
fs.BoolVar(&flags.Nameserver, "nameserver", false, "Make this node a nameserver (runs CoreDNS + Caddy)") fs.BoolVar(&flags.Nameserver, "nameserver", false, "Make this node a nameserver (runs CoreDNS + Caddy)")
// Cluster join flags // Cluster join flags
fs.StringVar(&flags.JoinAddress, "join", "", "Join an existing cluster (e.g. 1.2.3.4:7001)") fs.StringVar(&flags.JoinAddress, "join", "", "Join existing cluster via HTTPS URL (e.g. https://node1.dbrs.space)")
fs.StringVar(&flags.ClusterSecret, "cluster-secret", "", "Cluster secret for IPFS Cluster (required if joining)") fs.StringVar(&flags.Token, "token", "", "Invite token for joining (from orama invite on existing node)")
fs.StringVar(&flags.SwarmKey, "swarm-key", "", "IPFS Swarm key hex (64 chars, last line of swarm.key)") fs.StringVar(&flags.ClusterSecret, "cluster-secret", "", "Deprecated: use --token instead")
fs.StringVar(&flags.SwarmKey, "swarm-key", "", "Deprecated: use --token instead")
fs.StringVar(&flags.PeersStr, "peers", "", "Comma-separated list of bootstrap peer multiaddrs") fs.StringVar(&flags.PeersStr, "peers", "", "Comma-separated list of bootstrap peer multiaddrs")
// IPFS/Cluster specific info for Peering configuration // IPFS/Cluster specific info for Peering configuration
@ -68,6 +73,9 @@ func ParseFlags(args []string) (*Flags, error) {
fs.StringVar(&flags.IPFSClusterPeerID, "ipfs-cluster-peer", "", "Peer ID of existing IPFS Cluster node") fs.StringVar(&flags.IPFSClusterPeerID, "ipfs-cluster-peer", "", "Peer ID of existing IPFS Cluster node")
fs.StringVar(&flags.IPFSClusterAddrs, "ipfs-cluster-addrs", "", "Comma-separated multiaddrs of existing IPFS Cluster node") fs.StringVar(&flags.IPFSClusterAddrs, "ipfs-cluster-addrs", "", "Comma-separated multiaddrs of existing IPFS Cluster node")
// Security flags
fs.BoolVar(&flags.SkipFirewall, "skip-firewall", false, "Skip UFW firewall setup (for users who manage their own firewall)")
// Anyone relay operator flags // Anyone relay operator flags
fs.BoolVar(&flags.AnyoneRelay, "anyone-relay", false, "Run as Anyone relay operator (earn rewards)") fs.BoolVar(&flags.AnyoneRelay, "anyone-relay", false, "Run as Anyone relay operator (earn rewards)")
fs.BoolVar(&flags.AnyoneExit, "anyone-exit", false, "Run as exit relay (requires --anyone-relay, legal implications)") fs.BoolVar(&flags.AnyoneExit, "anyone-exit", false, "Run as exit relay (requires --anyone-relay, legal implications)")

View File

@ -2,14 +2,20 @@ package install
import ( import (
"bufio" "bufio"
"crypto/tls"
"encoding/json"
"fmt" "fmt"
"io"
"net/http"
"os" "os"
"os/exec"
"path/filepath" "path/filepath"
"strings" "strings"
"time" "time"
"github.com/DeBrosOfficial/network/pkg/cli/utils" "github.com/DeBrosOfficial/network/pkg/cli/utils"
"github.com/DeBrosOfficial/network/pkg/environments/production" "github.com/DeBrosOfficial/network/pkg/environments/production"
joinhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/join"
) )
// Orchestrator manages the install process // Orchestrator manages the install process
@ -97,9 +103,11 @@ func (o *Orchestrator) Execute() error {
return nil return nil
} }
// Save secrets before installation // Save secrets before installation (only for genesis; join flow gets secrets from response)
if err := o.validator.SaveSecrets(); err != nil { if !o.isJoiningNode() {
return err if err := o.validator.SaveSecrets(); err != nil {
return err
}
} }
// Save preferences for future upgrades (branch + nameserver) // Save preferences for future upgrades (branch + nameserver)
@ -132,33 +140,56 @@ func (o *Orchestrator) Execute() error {
return fmt.Errorf("binary installation failed: %w", err) return fmt.Errorf("binary installation failed: %w", err)
} }
// Phase 3: Generate secrets FIRST (before service initialization) // Branch: genesis node vs joining node
if o.isJoiningNode() {
return o.executeJoinFlow()
}
return o.executeGenesisFlow()
}
// isJoiningNode reports whether this install should join an existing
// cluster rather than create a new one: both --join and --token must
// have been supplied.
func (o *Orchestrator) isJoiningNode() bool {
	hasJoinURL := o.flags.JoinAddress != ""
	hasToken := o.flags.Token != ""
	return hasJoinURL && hasToken
}
// executeGenesisFlow runs the install for the first node in a new cluster
func (o *Orchestrator) executeGenesisFlow() error {
// Phase 3: Generate secrets locally
fmt.Printf("\n🔐 Phase 3: Generating secrets...\n") fmt.Printf("\n🔐 Phase 3: Generating secrets...\n")
if err := o.setup.Phase3GenerateSecrets(); err != nil { if err := o.setup.Phase3GenerateSecrets(); err != nil {
return fmt.Errorf("secret generation failed: %w", err) return fmt.Errorf("secret generation failed: %w", err)
} }
// Phase 4: Generate configs (BEFORE service initialization) // Phase 6a: WireGuard — self-assign 10.0.0.1
fmt.Printf("\n🔒 Phase 6a: Setting up WireGuard mesh VPN...\n")
if _, _, err := o.setup.Phase6SetupWireGuard(true); err != nil {
fmt.Fprintf(os.Stderr, " ⚠️ Warning: WireGuard setup failed: %v\n", err)
} else {
fmt.Printf(" ✓ WireGuard configured (10.0.0.1)\n")
}
// Phase 6b: UFW firewall
fmt.Printf("\n🛡 Phase 6b: Setting up UFW firewall...\n")
if err := o.setup.Phase6bSetupFirewall(o.flags.SkipFirewall); err != nil {
fmt.Fprintf(os.Stderr, " ⚠️ Warning: Firewall setup failed: %v\n", err)
}
// Phase 4: Generate configs using WG IP (10.0.0.1) as advertise address
// All inter-node communication uses WireGuard IPs, not public IPs
fmt.Printf("\n⚙ Phase 4: Generating configurations...\n") fmt.Printf("\n⚙ Phase 4: Generating configurations...\n")
// Internal gateway always runs HTTP on port 6001
// When using Caddy (nameserver mode), Caddy handles external HTTPS and proxies to internal gateway
// When not using Caddy, the gateway runs HTTP-only (use a reverse proxy for HTTPS)
enableHTTPS := false enableHTTPS := false
if err := o.setup.Phase4GenerateConfigs(o.peers, o.flags.VpsIP, enableHTTPS, o.flags.Domain, o.flags.BaseDomain, o.flags.JoinAddress); err != nil { genesisWGIP := "10.0.0.1"
if err := o.setup.Phase4GenerateConfigs(o.peers, genesisWGIP, enableHTTPS, o.flags.Domain, o.flags.BaseDomain, ""); err != nil {
return fmt.Errorf("configuration generation failed: %w", err) return fmt.Errorf("configuration generation failed: %w", err)
} }
// Validate generated configuration
if err := o.validator.ValidateGeneratedConfig(); err != nil { if err := o.validator.ValidateGeneratedConfig(); err != nil {
return err return err
} }
// Phase 2c: Initialize services (after config is in place) // Phase 2c: Initialize services (use WG IP for IPFS Cluster peer discovery)
fmt.Printf("\nPhase 2c: Initializing services...\n") fmt.Printf("\nPhase 2c: Initializing services...\n")
ipfsPeerInfo := o.buildIPFSPeerInfo() if err := o.setup.Phase2cInitializeServices(o.peers, genesisWGIP, nil, nil); err != nil {
ipfsClusterPeerInfo := o.buildIPFSClusterPeerInfo()
if err := o.setup.Phase2cInitializeServices(o.peers, o.flags.VpsIP, ipfsPeerInfo, ipfsClusterPeerInfo); err != nil {
return fmt.Errorf("service initialization failed: %w", err) return fmt.Errorf("service initialization failed: %w", err)
} }
@ -168,9 +199,9 @@ func (o *Orchestrator) Execute() error {
return fmt.Errorf("service creation failed: %w", err) return fmt.Errorf("service creation failed: %w", err)
} }
// Seed DNS records after services are running (RQLite must be up) // Phase 7: Seed DNS records
if o.flags.Nameserver && o.flags.BaseDomain != "" { if o.flags.Nameserver && o.flags.BaseDomain != "" {
fmt.Printf("\n🌐 Phase 6: Seeding DNS records...\n") fmt.Printf("\n🌐 Phase 7: Seeding DNS records...\n")
fmt.Printf(" Waiting for RQLite to start (10s)...\n") fmt.Printf(" Waiting for RQLite to start (10s)...\n")
time.Sleep(10 * time.Second) time.Sleep(10 * time.Second)
if err := o.setup.SeedDNSRecords(o.flags.BaseDomain, o.flags.VpsIP, o.peers); err != nil { if err := o.setup.SeedDNSRecords(o.flags.BaseDomain, o.flags.VpsIP, o.peers); err != nil {
@ -180,18 +211,206 @@ func (o *Orchestrator) Execute() error {
} }
} }
// Log completion with actual peer ID
o.setup.LogSetupComplete(o.setup.NodePeerID) o.setup.LogSetupComplete(o.setup.NodePeerID)
fmt.Printf("✅ Production installation complete!\n\n") fmt.Printf("✅ Production installation complete!\n\n")
o.printFirstNodeSecrets()
return nil
}
// For first node, print important secrets and identifiers // executeJoinFlow runs the install for a node joining an existing cluster via invite token
if o.validator.IsFirstNode() { func (o *Orchestrator) executeJoinFlow() error {
o.printFirstNodeSecrets() // Step 1: Generate WG keypair
fmt.Printf("\n🔑 Generating WireGuard keypair...\n")
privKey, pubKey, err := production.GenerateKeyPair()
if err != nil {
return fmt.Errorf("failed to generate WG keypair: %w", err)
}
fmt.Printf(" ✓ WireGuard keypair generated\n")
// Step 2: Call join endpoint on existing node
fmt.Printf("\n🤝 Requesting cluster join from %s...\n", o.flags.JoinAddress)
joinResp, err := o.callJoinEndpoint(pubKey)
if err != nil {
return fmt.Errorf("join request failed: %w", err)
}
fmt.Printf(" ✓ Join approved — assigned WG IP: %s\n", joinResp.WGIP)
fmt.Printf(" ✓ Received %d WG peers\n", len(joinResp.WGPeers))
// Step 3: Configure WireGuard with assigned IP and peers
fmt.Printf("\n🔒 Configuring WireGuard tunnel...\n")
var wgPeers []production.WireGuardPeer
for _, p := range joinResp.WGPeers {
wgPeers = append(wgPeers, production.WireGuardPeer{
PublicKey: p.PublicKey,
Endpoint: p.Endpoint,
AllowedIP: p.AllowedIP,
})
}
// Install WG package first
wp := production.NewWireGuardProvisioner(production.WireGuardConfig{})
if err := wp.Install(); err != nil {
return fmt.Errorf("failed to install wireguard: %w", err)
}
if err := o.setup.EnableWireGuardWithPeers(privKey, joinResp.WGIP, wgPeers); err != nil {
return fmt.Errorf("failed to enable WireGuard: %w", err)
}
// Step 4: Verify WG tunnel
fmt.Printf("\n🔍 Verifying WireGuard tunnel...\n")
if err := o.verifyWGTunnel(joinResp.WGPeers); err != nil {
return fmt.Errorf("WireGuard tunnel verification failed: %w", err)
}
fmt.Printf(" ✓ WireGuard tunnel established\n")
// Step 5: UFW firewall
fmt.Printf("\n🛡 Setting up UFW firewall...\n")
if err := o.setup.Phase6bSetupFirewall(o.flags.SkipFirewall); err != nil {
fmt.Fprintf(os.Stderr, " ⚠️ Warning: Firewall setup failed: %v\n", err)
}
// Step 6: Save secrets from join response
fmt.Printf("\n🔐 Saving cluster secrets...\n")
if err := o.saveSecretsFromJoinResponse(joinResp); err != nil {
return fmt.Errorf("failed to save secrets: %w", err)
}
fmt.Printf(" ✓ Secrets saved\n")
// Step 7: Generate configs using WG IP as advertise address
// All inter-node communication uses WireGuard IPs, not public IPs
fmt.Printf("\n⚙ Generating configurations...\n")
enableHTTPS := false
rqliteJoin := joinResp.RQLiteJoinAddress
if err := o.setup.Phase4GenerateConfigs(joinResp.BootstrapPeers, joinResp.WGIP, enableHTTPS, o.flags.Domain, joinResp.BaseDomain, rqliteJoin); err != nil {
return fmt.Errorf("configuration generation failed: %w", err)
}
if err := o.validator.ValidateGeneratedConfig(); err != nil {
return err
}
// Step 8: Initialize services with IPFS peer info from join response
fmt.Printf("\nInitializing services...\n")
var ipfsPeerInfo *production.IPFSPeerInfo
if joinResp.IPFSPeer.ID != "" {
ipfsPeerInfo = &production.IPFSPeerInfo{
PeerID: joinResp.IPFSPeer.ID,
Addrs: joinResp.IPFSPeer.Addrs,
}
}
var ipfsClusterPeerInfo *production.IPFSClusterPeerInfo
if joinResp.IPFSClusterPeer.ID != "" {
ipfsClusterPeerInfo = &production.IPFSClusterPeerInfo{
PeerID: joinResp.IPFSClusterPeer.ID,
Addrs: joinResp.IPFSClusterPeer.Addrs,
}
}
if err := o.setup.Phase2cInitializeServices(joinResp.BootstrapPeers, joinResp.WGIP, ipfsPeerInfo, ipfsClusterPeerInfo); err != nil {
return fmt.Errorf("service initialization failed: %w", err)
}
// Step 9: Create systemd services
fmt.Printf("\n🔧 Creating systemd services...\n")
if err := o.setup.Phase5CreateSystemdServices(enableHTTPS); err != nil {
return fmt.Errorf("service creation failed: %w", err)
}
o.setup.LogSetupComplete(o.setup.NodePeerID)
fmt.Printf("✅ Production installation complete! Joined cluster via %s\n\n", o.flags.JoinAddress)
return nil
}
// callJoinEndpoint sends the join request to the existing node's join API
// and returns the parsed response.
//
// The request body carries the invite token, this node's WireGuard public
// key, and its public IP. The target URL is derived from the --join flag
// (e.g. https://node1.dbrs.space -> https://node1.dbrs.space/v1/internal/join).
func (o *Orchestrator) callJoinEndpoint(wgPubKey string) (*joinhandlers.JoinResponse, error) {
	reqBody := joinhandlers.JoinRequest{
		Token:       o.flags.Token,
		WGPublicKey: wgPubKey,
		PublicIP:    o.flags.VpsIP,
	}
	bodyBytes, err := json.Marshal(reqBody)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	url := strings.TrimRight(o.flags.JoinAddress, "/") + "/v1/internal/join"

	client := &http.Client{
		Timeout: 30 * time.Second,
		Transport: &http.Transport{
			TLSClientConfig: &tls.Config{
				// Even without chain verification, refuse legacy TLS versions.
				MinVersion: tls.VersionTLS12,
				// NOTE(security): the target node may still be serving a
				// self-signed certificate during initial setup, so chain
				// verification is skipped. The invite token is single-use
				// and short-lived, which narrows the MITM window, but this
				// should be tightened once ACME certificates are in place.
				InsecureSkipVerify: true,
			},
		},
	}

	resp, err := client.Post(url, "application/json", strings.NewReader(string(bodyBytes)))
	if err != nil {
		return nil, fmt.Errorf("failed to contact %s: %w", url, err)
	}
	defer resp.Body.Close()

	// Cap the response size (1 MiB) so a misbehaving or hostile endpoint
	// cannot make us buffer an arbitrarily large body.
	respBody, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("join rejected (HTTP %d): %s", resp.StatusCode, string(respBody))
	}

	var joinResp joinhandlers.JoinResponse
	if err := json.Unmarshal(respBody, &joinResp); err != nil {
		return nil, fmt.Errorf("failed to parse join response: %w", err)
	}
	return &joinResp, nil
}
// saveSecretsFromJoinResponse writes cluster secrets received from the join endpoint to disk
func (o *Orchestrator) saveSecretsFromJoinResponse(resp *joinhandlers.JoinResponse) error {
secretsDir := filepath.Join(o.oramaDir, "secrets")
if err := os.MkdirAll(secretsDir, 0700); err != nil {
return fmt.Errorf("failed to create secrets dir: %w", err)
}
// Write cluster secret
if resp.ClusterSecret != "" {
if err := os.WriteFile(filepath.Join(secretsDir, "cluster-secret"), []byte(resp.ClusterSecret), 0600); err != nil {
return fmt.Errorf("failed to write cluster-secret: %w", err)
}
}
// Write swarm key
if resp.SwarmKey != "" {
if err := os.WriteFile(filepath.Join(secretsDir, "swarm.key"), []byte(resp.SwarmKey), 0600); err != nil {
return fmt.Errorf("failed to write swarm.key: %w", err)
}
} }
return nil return nil
} }
// verifyWGTunnel confirms the WireGuard tunnel is working by pinging the
// first peer's mesh IP, retrying for up to 30 seconds.
//
// Returns an error if no peers were supplied or the peer never answers.
func (o *Orchestrator) verifyWGTunnel(peers []joinhandlers.WGPeerInfo) error {
	if len(peers) == 0 {
		return fmt.Errorf("no WG peers to verify")
	}

	// AllowedIP is CIDR notation (typically "10.0.0.1/32"). Strip the mask
	// whatever its width — the original TrimSuffix("/32") would leave
	// "10.0.0.1/24" intact and make the ping fail.
	targetIP, _, _ := strings.Cut(peers[0].AllowedIP, "/")

	// Retry for up to 30 seconds: the tunnel can take a few seconds to
	// come up after the interface is brought online.
	deadline := time.Now().Add(30 * time.Second)
	for time.Now().Before(deadline) {
		// -c 1: single probe; -W 2: two-second per-probe timeout.
		cmd := exec.Command("ping", "-c", "1", "-W", "2", targetIP)
		if err := cmd.Run(); err == nil {
			return nil
		}
		time.Sleep(2 * time.Second)
	}
	return fmt.Errorf("could not reach %s via WireGuard after 30s", targetIP)
}
func (o *Orchestrator) buildIPFSPeerInfo() *production.IPFSPeerInfo { func (o *Orchestrator) buildIPFSPeerInfo() *production.IPFSPeerInfo {
if o.flags.IPFSPeerID != "" { if o.flags.IPFSPeerID != "" {
var addrs []string var addrs []string

View File

@ -0,0 +1,115 @@
package invite
import (
"crypto/rand"
"encoding/hex"
"fmt"
"net/http"
"os"
"strings"
"time"
"gopkg.in/yaml.v3"
)
// Handle implements the `orama invite` command: it mints a random one-time
// token, stores it in the local RQLite store with an expiry (default 1h,
// overridable via `--expiry <duration>`), and prints the install command a
// new node should run to join the cluster.
//
// It must run on an installed cluster node: the node's domain is read from
// the local config, and RQLite must be reachable locally.
func Handle(args []string) {
	domain, err := readNodeDomain()
	if err != nil {
		fmt.Fprintf(os.Stderr, "Error: could not read node config: %v\n", err)
		fmt.Fprintf(os.Stderr, "Make sure you're running this on an installed node.\n")
		os.Exit(1)
	}

	// 32 random bytes -> 64 hex characters; unguessable invite token.
	raw := make([]byte, 32)
	if _, err := rand.Read(raw); err != nil {
		fmt.Fprintf(os.Stderr, "Error generating token: %v\n", err)
		os.Exit(1)
	}
	token := hex.EncodeToString(raw)

	// Token lifetime: one hour unless overridden; when --expiry is passed
	// more than once, the last occurrence wins.
	expiry := time.Hour
	for i := 0; i+1 < len(args); i++ {
		if args[i] != "--expiry" {
			continue
		}
		parsed, err := time.ParseDuration(args[i+1])
		if err != nil {
			fmt.Fprintf(os.Stderr, "Invalid expiry duration: %v\n", err)
			os.Exit(1)
		}
		expiry = parsed
	}
	expiresAt := time.Now().UTC().Add(expiry).Format("2006-01-02 15:04:05")

	// Record which node issued the token; "unknown" when hostname lookup fails.
	createdBy := "unknown"
	if hostname, err := os.Hostname(); err == nil {
		createdBy = hostname
	}

	if err := insertToken(token, createdBy, expiresAt); err != nil {
		fmt.Fprintf(os.Stderr, "Error storing invite token: %v\n", err)
		fmt.Fprintf(os.Stderr, "Make sure RQLite is running on this node.\n")
		os.Exit(1)
	}

	fmt.Printf("\nInvite token created (expires in %s)\n\n", expiry)
	fmt.Printf("Run this on the new node:\n\n")
	fmt.Printf(" sudo orama install --join https://%s --token %s --vps-ip <NEW_NODE_IP> --nameserver\n\n", domain, token)
	fmt.Printf("Replace <NEW_NODE_IP> with the new node's public IP address.\n")
}
// readNodeDomain loads /home/debros/.orama/configs/node.yaml and returns the
// node.domain value. It errors when the file is missing or unparsable, or
// when the domain field is empty — both indicate this host is not an
// installed node.
func readNodeDomain() (string, error) {
	const configPath = "/home/debros/.orama/configs/node.yaml"
	raw, err := os.ReadFile(configPath)
	if err != nil {
		return "", fmt.Errorf("read config: %w", err)
	}
	// Only the node.domain key is needed; ignore the rest of the file.
	var cfg struct {
		Node struct {
			Domain string `yaml:"domain"`
		} `yaml:"node"`
	}
	if err := yaml.Unmarshal(raw, &cfg); err != nil {
		return "", fmt.Errorf("parse config: %w", err)
	}
	if cfg.Node.Domain == "" {
		return "", fmt.Errorf("node domain not set in config")
	}
	return cfg.Node.Domain, nil
}
// insertToken stores an invite token in the cluster's RQLite store via its
// HTTP execute API.
//
// The statement is sent in rqlite's parameterized form
// ([["SQL", arg1, arg2, ...]]) and the request body is built with
// encoding/json, so values are never spliced into the SQL string. The
// original built the JSON/SQL with fmt.Sprintf, which would have produced
// broken (or injectable) statements if createdBy — the hostname — ever
// contained a quote or other special character.
//
// NOTE(review): the endpoint is hard-coded to localhost:5001; rqlite's
// default HTTP API port is 4001 (5001 is commonly the IPFS API). Confirm
// this matches the node's actual deployment.
func insertToken(token, createdBy, expiresAt string) error {
	stmt := []any{
		"INSERT INTO invite_tokens (token, created_by, expires_at) VALUES (?, ?, ?)",
		token, createdBy, expiresAt,
	}
	body, err := json.Marshal([][]any{stmt})
	if err != nil {
		return fmt.Errorf("encode statement: %w", err)
	}
	req, err := http.NewRequest("POST", "http://localhost:5001/db/execute", strings.NewReader(string(body)))
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/json")
	client := &http.Client{Timeout: 5 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return fmt.Errorf("failed to connect to RQLite: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("RQLite returned status %d", resp.StatusCode)
	}
	return nil
}

View File

@ -147,6 +147,12 @@ func (o *Orchestrator) Execute() error {
fmt.Fprintf(os.Stderr, "⚠️ Service update warning: %v\n", err) fmt.Fprintf(os.Stderr, "⚠️ Service update warning: %v\n", err)
} }
// Re-apply UFW firewall rules (idempotent)
fmt.Printf("\n🛡 Re-applying firewall rules...\n")
if err := o.setup.Phase6bSetupFirewall(false); err != nil {
fmt.Fprintf(os.Stderr, " ⚠️ Warning: Firewall re-apply failed: %v\n", err)
}
fmt.Printf("\n✅ Upgrade complete!\n") fmt.Printf("\n✅ Upgrade complete!\n")
// Restart services if requested // Restart services if requested
@ -317,12 +323,24 @@ func (o *Orchestrator) extractGatewayConfig() (enableHTTPS bool, domain string,
} }
} }
// Also check node.yaml for base_domain // Also check node.yaml for domain and base_domain
nodeConfigPath := filepath.Join(o.oramaDir, "configs", "node.yaml") nodeConfigPath := filepath.Join(o.oramaDir, "configs", "node.yaml")
if data, err := os.ReadFile(nodeConfigPath); err == nil { if data, err := os.ReadFile(nodeConfigPath); err == nil {
configStr := string(data) configStr := string(data)
for _, line := range strings.Split(configStr, "\n") { for _, line := range strings.Split(configStr, "\n") {
trimmed := strings.TrimSpace(line) trimmed := strings.TrimSpace(line)
// Extract domain from node.yaml (under node: section) if not already found
if domain == "" && strings.HasPrefix(trimmed, "domain:") && !strings.HasPrefix(trimmed, "domain_") {
parts := strings.SplitN(trimmed, ":", 2)
if len(parts) > 1 {
d := strings.TrimSpace(parts[1])
d = strings.Trim(d, "\"'")
if d != "" && d != "null" {
domain = d
enableHTTPS = true
}
}
}
if strings.HasPrefix(trimmed, "base_domain:") { if strings.HasPrefix(trimmed, "base_domain:") {
parts := strings.SplitN(trimmed, ":", 2) parts := strings.SplitN(trimmed, ":", 2)
if len(parts) > 1 { if len(parts) > 1 {
@ -332,7 +350,6 @@ func (o *Orchestrator) extractGatewayConfig() (enableHTTPS bool, domain string,
baseDomain = "" baseDomain = ""
} }
} }
break
} }
} }
} }

View File

@ -259,7 +259,7 @@ func (rm *ReplicaManager) GetNodeIP(ctx context.Context, nodeID string) (string,
} }
var rows []nodeRow var rows []nodeRow
query := `SELECT ip_address FROM dns_nodes WHERE id = ? LIMIT 1` query := `SELECT COALESCE(internal_ip, ip_address) AS ip_address FROM dns_nodes WHERE id = ? LIMIT 1`
err := rm.db.Query(internalCtx, &rows, query, nodeID) err := rm.db.Query(internalCtx, &rows, query, nodeID)
if err != nil { if err != nil {
return "", err return "", err

View File

@ -0,0 +1,133 @@
package production
import (
"fmt"
"os/exec"
"strings"
)
// FirewallConfig holds the configuration for UFW firewall rules
type FirewallConfig struct {
	SSHPort       int  // default 22
	IsNameserver  bool // enables port 53 TCP+UDP
	AnyoneORPort  int  // 0 = disabled, typically 9001
	WireGuardPort int  // default 51820
}

// FirewallProvisioner manages UFW firewall setup
type FirewallProvisioner struct {
	config FirewallConfig
}

// NewFirewallProvisioner returns a provisioner for the given config, filling
// in the default SSH (22) and WireGuard (51820) ports when left zero.
func NewFirewallProvisioner(config FirewallConfig) *FirewallProvisioner {
	if config.SSHPort == 0 {
		config.SSHPort = 22
	}
	if config.WireGuardPort == 0 {
		config.WireGuardPort = 51820
	}
	return &FirewallProvisioner{config: config}
}

// IsInstalled reports whether the ufw binary is available on PATH.
func (fp *FirewallProvisioner) IsInstalled() bool {
	if _, err := exec.LookPath("ufw"); err != nil {
		return false
	}
	return true
}

// Install installs UFW via apt-get; it is a no-op when already present.
func (fp *FirewallProvisioner) Install() error {
	if fp.IsInstalled() {
		return nil
	}
	out, err := exec.Command("apt-get", "install", "-y", "ufw").CombinedOutput()
	if err != nil {
		return fmt.Errorf("failed to install ufw: %w\n%s", err, string(out))
	}
	return nil
}

// GenerateRules returns the ordered ufw commands to apply: a full reset,
// deny-inbound/allow-outbound defaults, explicit allows for SSH, WireGuard,
// HTTP(S), optional DNS (nameserver nodes) and Anyone relay ORPort, the
// 10/8 private range, and finally enable.
func (fp *FirewallProvisioner) GenerateRules() []string {
	cfg := fp.config
	rules := make([]string, 0, 13)
	rules = append(rules,
		"ufw --force reset",          // start from a known-clean state
		"ufw default deny incoming",  // deny-by-default inbound
		"ufw default allow outgoing", // unrestricted outbound
		fmt.Sprintf("ufw allow %d/tcp", cfg.SSHPort),       // SSH (always required)
		fmt.Sprintf("ufw allow %d/udp", cfg.WireGuardPort), // WireGuard mesh transport
		"ufw allow 80/tcp",  // ACME / HTTP redirect
		"ufw allow 443/tcp", // HTTPS (Caddy → Gateway)
	)
	// DNS is only exposed on nameserver nodes.
	if cfg.IsNameserver {
		rules = append(rules, "ufw allow 53/tcp", "ufw allow 53/udp")
	}
	// Anyone relay ORPort, only when this node runs a relay.
	if cfg.AnyoneORPort > 0 {
		rules = append(rules, fmt.Sprintf("ufw allow %d/tcp", cfg.AnyoneORPort))
	}
	// Inter-node encrypted traffic over the private range.
	// NOTE(review): 10.0.0.0/8 is much wider than the /24 the WG config
	// assigns; kept as-is to preserve behavior — confirm intent.
	rules = append(rules,
		"ufw allow from 10.0.0.0/8",
		"ufw --force enable",
	)
	return rules
}

// Setup installs UFW if needed and applies every generated rule in order.
// Idempotent — safe to call multiple times.
func (fp *FirewallProvisioner) Setup() error {
	if err := fp.Install(); err != nil {
		return err
	}
	for _, rule := range fp.GenerateRules() {
		argv := strings.Fields(rule)
		out, err := exec.Command(argv[0], argv[1:]...).CombinedOutput()
		if err != nil {
			return fmt.Errorf("failed to apply firewall rule '%s': %w\n%s", rule, err, string(out))
		}
	}
	return nil
}

// IsActive reports whether `ufw status` says the firewall is active.
func (fp *FirewallProvisioner) IsActive() bool {
	out, err := exec.Command("ufw", "status").CombinedOutput()
	return err == nil && strings.Contains(string(out), "Status: active")
}

// GetStatus returns the output of `ufw status verbose`.
func (fp *FirewallProvisioner) GetStatus() (string, error) {
	out, err := exec.Command("ufw", "status", "verbose").CombinedOutput()
	if err != nil {
		return "", fmt.Errorf("failed to get ufw status: %w\n%s", err, string(out))
	}
	return string(out), nil
}

View File

@ -0,0 +1,117 @@
package production
import (
"strings"
"testing"
)
// TestFirewallProvisioner_GenerateRules_StandardNode verifies the baseline
// ruleset of a default node and that DNS/Anyone-relay rules are absent.
func TestFirewallProvisioner_GenerateRules_StandardNode(t *testing.T) {
	rules := NewFirewallProvisioner(FirewallConfig{}).GenerateRules()
	baseline := []string{
		"ufw --force reset",
		"ufw default deny incoming",
		"ufw default allow outgoing",
		"ufw allow 22/tcp",
		"ufw allow 51820/udp",
		"ufw allow 80/tcp",
		"ufw allow 443/tcp",
		"ufw allow from 10.0.0.0/8",
		"ufw --force enable",
	}
	for _, want := range baseline {
		assertContainsRule(t, rules, want)
	}
	// Optional rules must not leak into the default configuration.
	for _, rule := range rules {
		if strings.Contains(rule, "53/") {
			t.Errorf("standard node should not have DNS rule: %s", rule)
		}
		if strings.Contains(rule, "9001") {
			t.Errorf("standard node should not have Anyone relay rule: %s", rule)
		}
	}
}
// TestFirewallProvisioner_GenerateRules_Nameserver verifies that nameserver
// nodes open port 53 over both TCP and UDP.
func TestFirewallProvisioner_GenerateRules_Nameserver(t *testing.T) {
	rules := NewFirewallProvisioner(FirewallConfig{IsNameserver: true}).GenerateRules()
	for _, want := range []string{"ufw allow 53/tcp", "ufw allow 53/udp"} {
		assertContainsRule(t, rules, want)
	}
}
// TestFirewallProvisioner_GenerateRules_WithAnyoneRelay verifies the ORPort
// rule appears when an Anyone relay port is configured.
func TestFirewallProvisioner_GenerateRules_WithAnyoneRelay(t *testing.T) {
	rules := NewFirewallProvisioner(FirewallConfig{AnyoneORPort: 9001}).GenerateRules()
	assertContainsRule(t, rules, "ufw allow 9001/tcp")
}
// TestFirewallProvisioner_GenerateRules_CustomSSHPort verifies a custom SSH
// port replaces — rather than supplements — the default port-22 rule.
func TestFirewallProvisioner_GenerateRules_CustomSSHPort(t *testing.T) {
	rules := NewFirewallProvisioner(FirewallConfig{SSHPort: 2222}).GenerateRules()
	assertContainsRule(t, rules, "ufw allow 2222/tcp")
	for i := range rules {
		if rules[i] == "ufw allow 22/tcp" {
			t.Error("should not have default SSH port 22 when custom port is set")
		}
	}
}
func TestFirewallProvisioner_GenerateRules_WireGuardSubnetAllowed(t *testing.T) {
fp := NewFirewallProvisioner(FirewallConfig{})
rules := fp.GenerateRules()
assertContainsRule(t, rules, "ufw allow from 10.0.0.0/8")
}
// TestFirewallProvisioner_GenerateRules_FullConfig verifies all optional
// knobs (custom SSH/WG ports, nameserver DNS, Anyone ORPort) together.
func TestFirewallProvisioner_GenerateRules_FullConfig(t *testing.T) {
	rules := NewFirewallProvisioner(FirewallConfig{
		SSHPort:       2222,
		IsNameserver:  true,
		AnyoneORPort:  9001,
		WireGuardPort: 51821,
	}).GenerateRules()
	expected := []string{
		"ufw allow 2222/tcp",
		"ufw allow 51821/udp",
		"ufw allow 53/tcp",
		"ufw allow 53/udp",
		"ufw allow 9001/tcp",
	}
	for _, want := range expected {
		assertContainsRule(t, rules, want)
	}
}
// TestFirewallProvisioner_DefaultPorts verifies the constructor's defaults.
func TestFirewallProvisioner_DefaultPorts(t *testing.T) {
	fp := NewFirewallProvisioner(FirewallConfig{})
	if got := fp.config.SSHPort; got != 22 {
		t.Errorf("default SSHPort = %d, want 22", got)
	}
	if got := fp.config.WireGuardPort; got != 51820 {
		t.Errorf("default WireGuardPort = %d, want 51820", got)
	}
}
// assertContainsRule fails the test when rules has no exact match for expected.
func assertContainsRule(t *testing.T, rules []string, expected string) {
	t.Helper()
	for i := range rules {
		if rules[i] == expected {
			return
		}
	}
	t.Errorf("rules should contain '%s', got: %v", expected, rules)
}

View File

@ -254,6 +254,13 @@ func (ps *ProductionSetup) Phase2ProvisionEnvironment() error {
ps.logf(" ✓ Deployment sudoers configured") ps.logf(" ✓ Deployment sudoers configured")
} }
// Set up WireGuard sudoers (allows debros user to manage WG peers)
if err := ps.userProvisioner.SetupWireGuardSudoers(); err != nil {
ps.logf(" ⚠️ Failed to setup wireguard sudoers: %v", err)
} else {
ps.logf(" ✓ WireGuard sudoers configured")
}
// Create directory structure (unified structure) // Create directory structure (unified structure)
if err := ps.fsProvisioner.EnsureDirectoryStructure(); err != nil { if err := ps.fsProvisioner.EnsureDirectoryStructure(); err != nil {
return fmt.Errorf("failed to create directory structure: %w", err) return fmt.Errorf("failed to create directory structure: %w", err)
@ -724,6 +731,25 @@ func (ps *ProductionSetup) Phase5CreateSystemdServices(enableHTTPS bool) error {
ps.logf(" - debros-node.service started (with embedded gateway)") ps.logf(" - debros-node.service started (with embedded gateway)")
} }
// Start CoreDNS and Caddy (nameserver nodes only)
// Caddy depends on debros-node.service (gateway on :6001), so start after node
if ps.isNameserver {
if _, err := os.Stat("/usr/local/bin/coredns"); err == nil {
if err := ps.serviceController.StartService("coredns.service"); err != nil {
ps.logf(" ⚠️ Failed to start coredns.service: %v", err)
} else {
ps.logf(" - coredns.service started")
}
}
if _, err := os.Stat("/usr/bin/caddy"); err == nil {
if err := ps.serviceController.StartService("caddy.service"); err != nil {
ps.logf(" ⚠️ Failed to start caddy.service: %v", err)
} else {
ps.logf(" - caddy.service started")
}
}
}
ps.logf(" ✓ All services started") ps.logf(" ✓ All services started")
return nil return nil
} }
@ -775,6 +801,96 @@ func (ps *ProductionSetup) SeedDNSRecords(baseDomain, vpsIP string, peerAddresse
return nil return nil
} }
// Phase6SetupWireGuard installs WireGuard and generates a keypair for this
// node. The first node bootstraps the mesh by self-assigning 10.0.0.1 with
// no peers; joining nodes only get their keys here — their mesh IP and peer
// list arrive later via the HTTPS peer exchange in the install orchestrator.
// Returns the base64-encoded private and public keys.
func (ps *ProductionSetup) Phase6SetupWireGuard(isFirstNode bool) (privateKey, publicKey string, err error) {
	ps.logf("Phase 6a: Setting up WireGuard...")
	wp := NewWireGuardProvisioner(WireGuardConfig{})
	if err := wp.Install(); err != nil {
		return "", "", fmt.Errorf("failed to install wireguard: %w", err)
	}
	ps.logf(" ✓ WireGuard installed")
	privKey, pubKey, keyErr := GenerateKeyPair()
	if keyErr != nil {
		return "", "", fmt.Errorf("failed to generate WG keys: %w", keyErr)
	}
	ps.logf(" ✓ WireGuard keypair generated")
	if !isFirstNode {
		// Joining node: defer config/enable until peers are exchanged.
		return privKey, pubKey, nil
	}
	// First node: self-assign the well-known mesh origin address.
	wp.config = WireGuardConfig{
		PrivateKey: privKey,
		PrivateIP:  "10.0.0.1",
		ListenPort: 51820,
	}
	if err := wp.WriteConfig(); err != nil {
		return "", "", fmt.Errorf("failed to write WG config: %w", err)
	}
	if err := wp.Enable(); err != nil {
		return "", "", fmt.Errorf("failed to enable WG: %w", err)
	}
	ps.logf(" ✓ WireGuard enabled (first node: 10.0.0.1)")
	return privKey, pubKey, nil
}
// Phase6bSetupFirewall applies the UFW ruleset for this node. Pass
// skipFirewall=true (the --skip-firewall flag) to leave the host firewall
// untouched. DNS ports are opened only for nameserver nodes, and the Anyone
// relay ORPort only when a relay is configured.
func (ps *ProductionSetup) Phase6bSetupFirewall(skipFirewall bool) error {
	if skipFirewall {
		ps.logf("Phase 6b: Skipping firewall setup (--skip-firewall)")
		return nil
	}
	ps.logf("Phase 6b: Setting up UFW firewall...")
	// Only open the relay ORPort when this node actually runs a relay.
	orPort := 0
	if ps.IsAnyoneRelay() && ps.anyoneRelayConfig != nil {
		orPort = ps.anyoneRelayConfig.ORPort
	}
	// NOTE(review): SSH is assumed to be on port 22; a host with a custom
	// sshd port would be locked out after enable — confirm before relying
	// on this in such environments.
	provisioner := NewFirewallProvisioner(FirewallConfig{
		SSHPort:       22,
		IsNameserver:  ps.isNameserver,
		AnyoneORPort:  orPort,
		WireGuardPort: 51820,
	})
	if err := provisioner.Setup(); err != nil {
		return fmt.Errorf("firewall setup failed: %w", err)
	}
	ps.logf(" ✓ UFW firewall configured and enabled")
	return nil
}
// EnableWireGuardWithPeers writes the WireGuard config with the node's
// assigned mesh IP and the known peer list, then enables the interface.
// Called by joining nodes once the HTTPS peer exchange has completed.
func (ps *ProductionSetup) EnableWireGuardWithPeers(privateKey, assignedIP string, peers []WireGuardPeer) error {
	provisioner := NewWireGuardProvisioner(WireGuardConfig{
		PrivateKey: privateKey,
		PrivateIP:  assignedIP,
		ListenPort: 51820,
		Peers:      peers,
	})
	if err := provisioner.WriteConfig(); err != nil {
		return fmt.Errorf("failed to write WG config: %w", err)
	}
	if err := provisioner.Enable(); err != nil {
		return fmt.Errorf("failed to enable WG: %w", err)
	}
	ps.logf(" ✓ WireGuard enabled (IP: %s, peers: %d)", assignedIP, len(peers))
	return nil
}
// LogSetupComplete logs completion information // LogSetupComplete logs completion information
func (ps *ProductionSetup) LogSetupComplete(peerID string) { func (ps *ProductionSetup) LogSetupComplete(peerID string) {
ps.logf("\n" + strings.Repeat("=", 70)) ps.logf("\n" + strings.Repeat("=", 70))

View File

@ -224,6 +224,34 @@ debros ALL=(ALL) NOPASSWD: /bin/rm -f /etc/systemd/system/orama-deploy-*.service
return nil return nil
} }
// SetupWireGuardSudoers grants the debros user passwordless sudo for the
// narrow set of WireGuard commands the provisioner shells out to
// (wg set/show/showconf on wg0, and tee for rewriting wg0.conf). The rule
// file is always overwritten so upgrades pick up the latest whitelist.
func (up *UserProvisioner) SetupWireGuardSudoers() error {
	const sudoersFile = "/etc/sudoers.d/debros-wireguard"
	const sudoersContent = `# DeBros Network - WireGuard Management Permissions
# Allows debros user to manage WireGuard peers
debros ALL=(ALL) NOPASSWD: /usr/bin/wg set wg0 *
debros ALL=(ALL) NOPASSWD: /usr/bin/wg show wg0
debros ALL=(ALL) NOPASSWD: /usr/bin/wg showconf wg0
debros ALL=(ALL) NOPASSWD: /usr/bin/tee /etc/wireguard/wg0.conf
`
	if err := os.WriteFile(sudoersFile, []byte(sudoersContent), 0440); err != nil {
		return fmt.Errorf("failed to create wireguard sudoers rule: %w", err)
	}
	// A syntactically broken sudoers drop-in can break sudo host-wide, so
	// validate with visudo and remove the file again if the check fails.
	if err := exec.Command("visudo", "-c", "-f", sudoersFile).Run(); err != nil {
		os.Remove(sudoersFile)
		return fmt.Errorf("wireguard sudoers rule validation failed: %w", err)
	}
	return nil
}
// StateDetector checks for existing production state // StateDetector checks for existing production state
type StateDetector struct { type StateDetector struct {
oramaDir string oramaDir string

View File

@ -0,0 +1,228 @@
package production
import (
"crypto/rand"
"encoding/base64"
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"golang.org/x/crypto/curve25519"
)
// WireGuardPeer represents a WireGuard mesh peer
type WireGuardPeer struct {
	PublicKey string // Base64-encoded public key
	Endpoint  string // public UDP endpoint, e.g. "141.227.165.154:51820"
	AllowedIP string // mesh address in CIDR form, e.g. "10.0.0.2/32"
}

// WireGuardConfig holds the configuration for a WireGuard interface
type WireGuardConfig struct {
	PrivateIP  string          // this node's mesh address, e.g. "10.0.0.1"
	ListenPort int             // UDP listen port; 0 means the default 51820
	PrivateKey string          // Base64-encoded private key
	Peers      []WireGuardPeer // known mesh peers
}

// WireGuardProvisioner manages WireGuard VPN setup
type WireGuardProvisioner struct {
	configDir string // where wg0.conf lives (/etc/wireguard)
	config    WireGuardConfig
}

// NewWireGuardProvisioner returns a provisioner for the given config,
// defaulting the listen port to 51820 when unset.
func NewWireGuardProvisioner(config WireGuardConfig) *WireGuardProvisioner {
	const defaultListenPort = 51820
	if config.ListenPort == 0 {
		config.ListenPort = defaultListenPort
	}
	return &WireGuardProvisioner{
		configDir: "/etc/wireguard",
		config:    config,
	}
}
// IsInstalled reports whether the WireGuard userland tool ("wg") is on PATH.
func (wp *WireGuardProvisioner) IsInstalled() bool {
	if _, err := exec.LookPath("wg"); err != nil {
		return false
	}
	return true
}
// Install installs the wireguard packages via apt-get; it is a no-op when
// the tools are already present.
func (wp *WireGuardProvisioner) Install() error {
	if wp.IsInstalled() {
		return nil
	}
	out, err := exec.Command("apt-get", "install", "-y", "wireguard", "wireguard-tools").CombinedOutput()
	if err != nil {
		return fmt.Errorf("failed to install wireguard: %w\n%s", err, string(out))
	}
	return nil
}
// GenerateKeyPair creates a fresh WireGuard keypair: a clamped Curve25519
// private key and its derived public key, both base64 (std) encoded.
func GenerateKeyPair() (privateKey, publicKey string, err error) {
	var priv [32]byte
	if _, err := rand.Read(priv[:]); err != nil {
		return "", "", fmt.Errorf("failed to generate random bytes: %w", err)
	}
	// Clamp the scalar per the Curve25519/X25519 spec (RFC 7748).
	priv[0] &= 248
	priv[31] = (priv[31] & 127) | 64
	// Public key = scalar * base point.
	var pub [32]byte
	curve25519.ScalarBaseMult(&pub, &priv)
	privateKey = base64.StdEncoding.EncodeToString(priv[:])
	publicKey = base64.StdEncoding.EncodeToString(pub[:])
	return privateKey, publicKey, nil
}
// PublicKeyFromPrivate derives the base64 public key for a base64-encoded
// 32-byte Curve25519 private key; it errors on bad encoding or wrong length.
func PublicKeyFromPrivate(privateKey string) (string, error) {
	raw, err := base64.StdEncoding.DecodeString(privateKey)
	if err != nil {
		return "", fmt.Errorf("failed to decode private key: %w", err)
	}
	if len(raw) != 32 {
		return "", fmt.Errorf("invalid private key length: %d", len(raw))
	}
	var priv, pub [32]byte
	copy(priv[:], raw)
	curve25519.ScalarBaseMult(&pub, &priv)
	return base64.StdEncoding.EncodeToString(pub[:]), nil
}
// GenerateConfig renders the wg0.conf file content: an [Interface] section
// with key/address/port, followed by one [Peer] section per known peer
// (Endpoint omitted when unknown, keepalive fixed at 25s).
func (wp *WireGuardProvisioner) GenerateConfig() string {
	var b strings.Builder
	b.WriteString("# WireGuard mesh configuration (managed by Orama Network)\n")
	b.WriteString("# Do not edit manually — use orama CLI to manage peers\n\n")
	fmt.Fprintf(&b, "[Interface]\nPrivateKey = %s\nAddress = %s/24\nListenPort = %d\n",
		wp.config.PrivateKey, wp.config.PrivateIP, wp.config.ListenPort)
	for _, p := range wp.config.Peers {
		fmt.Fprintf(&b, "\n[Peer]\nPublicKey = %s\n", p.PublicKey)
		if p.Endpoint != "" {
			// Endpoint is optional: peers behind NAT may have none.
			fmt.Fprintf(&b, "Endpoint = %s\n", p.Endpoint)
		}
		fmt.Fprintf(&b, "AllowedIPs = %s\nPersistentKeepalive = 25\n", p.AllowedIP)
	}
	return b.String()
}
// WriteConfig renders the interface config and writes it to
// /etc/wireguard/wg0.conf.
//
// It first tries a direct 0600 write (works when running as root) and falls
// back to `sudo tee` for the unprivileged debros user. Fix over the
// original: when tee *creates* the file it does so with umask defaults
// (typically 0644), leaving the interface's private key world-readable; the
// fallback now attempts to tighten the mode to 0600 afterwards. The chmod
// is best-effort because the restricted sudoers whitelist may not permit
// it — failure there must not break an otherwise successful write.
func (wp *WireGuardProvisioner) WriteConfig() error {
	confPath := filepath.Join(wp.configDir, "wg0.conf")
	content := wp.GenerateConfig()
	// Fast path: direct write as root; 0600 since the file holds a private key.
	if err := os.MkdirAll(wp.configDir, 0700); err == nil {
		if err := os.WriteFile(confPath, []byte(content), 0600); err == nil {
			return nil
		}
	}
	// Fallback to sudo tee (for non-root, e.g. debros user).
	cmd := exec.Command("sudo", "tee", confPath)
	cmd.Stdin = strings.NewReader(content)
	if output, err := cmd.CombinedOutput(); err != nil {
		return fmt.Errorf("failed to write wg0.conf via sudo: %w\n%s", err, string(output))
	}
	// Best-effort: tee may have just created the file world-readable.
	// Ignored on purpose — sudoers may only whitelist wg/tee.
	_ = exec.Command("sudo", "chmod", "600", confPath).Run()
	return nil
}
// Enable marks wg-quick@wg0 to start on boot, then brings it up immediately.
func (wp *WireGuardProvisioner) Enable() error {
	for _, action := range []string{"enable", "start"} {
		out, err := exec.Command("systemctl", action, "wg-quick@wg0").CombinedOutput()
		if err != nil {
			return fmt.Errorf("failed to %s wg-quick@wg0: %w\n%s", action, err, string(out))
		}
	}
	return nil
}
// Restart restarts the wg-quick@wg0 unit (e.g. after a config rewrite).
func (wp *WireGuardProvisioner) Restart() error {
	out, err := exec.Command("systemctl", "restart", "wg-quick@wg0").CombinedOutput()
	if err != nil {
		return fmt.Errorf("failed to restart wg-quick@wg0: %w\n%s", err, string(out))
	}
	return nil
}
// IsActive reports whether systemd considers wg-quick@wg0 active.
func (wp *WireGuardProvisioner) IsActive() bool {
	err := exec.Command("systemctl", "is-active", "--quiet", "wg-quick@wg0").Run()
	return err == nil
}
// AddPeer adds (or updates) a peer on the running wg0 interface without a
// restart, then persists the peer list to wg0.conf so it survives reboots.
//
// `wg set ... peer <key>` is an upsert on the live interface, so re-adding
// a known public key updates it in place. Fix over the original: the cached
// in-memory list now mirrors that — an existing entry with the same
// PublicKey is replaced rather than appended, which keeps repeated AddPeer
// calls (e.g. a node re-joining) from accumulating duplicate [Peer]
// sections in the persisted config file.
func (wp *WireGuardProvisioner) AddPeer(peer WireGuardPeer) error {
	args := []string{"wg", "set", "wg0", "peer", peer.PublicKey, "allowed-ips", peer.AllowedIP, "persistent-keepalive", "25"}
	if peer.Endpoint != "" {
		args = append(args, "endpoint", peer.Endpoint)
	}
	cmd := exec.Command("sudo", args...)
	if output, err := cmd.CombinedOutput(); err != nil {
		return fmt.Errorf("failed to add peer %s: %w\n%s", peer.AllowedIP, err, string(output))
	}
	// Upsert into the cached peer list so WriteConfig persists latest state.
	replaced := false
	for i := range wp.config.Peers {
		if wp.config.Peers[i].PublicKey == peer.PublicKey {
			wp.config.Peers[i] = peer
			replaced = true
			break
		}
	}
	if !replaced {
		wp.config.Peers = append(wp.config.Peers, peer)
	}
	return wp.WriteConfig()
}
// RemovePeer drops a peer from the running wg0 interface and from the
// persisted wg0.conf.
func (wp *WireGuardProvisioner) RemovePeer(publicKey string) error {
	out, err := exec.Command("sudo", "wg", "set", "wg0", "peer", publicKey, "remove").CombinedOutput()
	if err != nil {
		return fmt.Errorf("failed to remove peer: %w\n%s", err, string(out))
	}
	// Filter the cached list, then persist the remaining peers.
	kept := make([]WireGuardPeer, 0, len(wp.config.Peers))
	for _, p := range wp.config.Peers {
		if p.PublicKey != publicKey {
			kept = append(kept, p)
		}
	}
	wp.config.Peers = kept
	return wp.WriteConfig()
}
// GetStatus returns the output of `wg show wg0`.
func (wp *WireGuardProvisioner) GetStatus() (string, error) {
	out, err := exec.Command("wg", "show", "wg0").CombinedOutput()
	if err != nil {
		return "", fmt.Errorf("failed to get wg status: %w\n%s", err, string(out))
	}
	return string(out), nil
}

View File

@ -0,0 +1,167 @@
package production
import (
"encoding/base64"
"strings"
"testing"
)
// TestGenerateKeyPair checks key length, base64 validity, and priv != pub.
func TestGenerateKeyPair(t *testing.T) {
	priv, pub, err := GenerateKeyPair()
	if err != nil {
		t.Fatalf("GenerateKeyPair failed: %v", err)
	}
	// 32 raw bytes -> 44 base64 characters including padding.
	if got := len(priv); got != 44 {
		t.Errorf("private key length = %d, want 44", got)
	}
	if got := len(pub); got != 44 {
		t.Errorf("public key length = %d, want 44", got)
	}
	if _, err := base64.StdEncoding.DecodeString(priv); err != nil {
		t.Errorf("private key is not valid base64: %v", err)
	}
	if _, err := base64.StdEncoding.DecodeString(pub); err != nil {
		t.Errorf("public key is not valid base64: %v", err)
	}
	if priv == pub {
		t.Error("private and public keys should differ")
	}
}
// TestGenerateKeyPair_Unique ensures two generated keypairs differ.
//
// Fix over the original: the error returns were discarded, so a failing RNG
// would have produced two empty private keys and a misleading "should be
// unique" failure; errors are now checked and reported directly.
func TestGenerateKeyPair_Unique(t *testing.T) {
	priv1, _, err := GenerateKeyPair()
	if err != nil {
		t.Fatalf("GenerateKeyPair failed: %v", err)
	}
	priv2, _, err := GenerateKeyPair()
	if err != nil {
		t.Fatalf("GenerateKeyPair failed: %v", err)
	}
	if priv1 == priv2 {
		t.Error("two generated key pairs should be unique")
	}
}
// TestPublicKeyFromPrivate checks derivation matches GenerateKeyPair's output.
func TestPublicKeyFromPrivate(t *testing.T) {
	priv, expectedPub, err := GenerateKeyPair()
	if err != nil {
		t.Fatalf("GenerateKeyPair failed: %v", err)
	}
	derived, err := PublicKeyFromPrivate(priv)
	if err != nil {
		t.Fatalf("PublicKeyFromPrivate failed: %v", err)
	}
	if derived != expectedPub {
		t.Errorf("PublicKeyFromPrivate = %s, want %s", derived, expectedPub)
	}
}
// TestPublicKeyFromPrivate_InvalidKey covers bad base64 and wrong-length input.
func TestPublicKeyFromPrivate_InvalidKey(t *testing.T) {
	if _, err := PublicKeyFromPrivate("not-valid-base64!!!"); err == nil {
		t.Error("expected error for invalid base64")
	}
	short := base64.StdEncoding.EncodeToString([]byte("short"))
	if _, err := PublicKeyFromPrivate(short); err == nil {
		t.Error("expected error for short key")
	}
}
// TestWireGuardProvisioner_GenerateConfig_NoPeers verifies the [Interface]
// section fields and that no [Peer] section is emitted without peers.
func TestWireGuardProvisioner_GenerateConfig_NoPeers(t *testing.T) {
	config := NewWireGuardProvisioner(WireGuardConfig{
		PrivateIP:  "10.0.0.1",
		ListenPort: 51820,
		PrivateKey: "dGVzdHByaXZhdGVrZXl0ZXN0cHJpdmF0ZWtleXM=",
	}).GenerateConfig()
	required := []struct{ substr, msg string }{
		{"[Interface]", "config should contain [Interface] section"},
		{"Address = 10.0.0.1/24", "config should contain correct Address"},
		{"ListenPort = 51820", "config should contain ListenPort"},
		{"PrivateKey = dGVzdHByaXZhdGVrZXl0ZXN0cHJpdmF0ZWtleXM=", "config should contain PrivateKey"},
	}
	for _, c := range required {
		if !strings.Contains(config, c.substr) {
			t.Error(c.msg)
		}
	}
	if strings.Contains(config, "[Peer]") {
		t.Error("config should NOT contain [Peer] section with no peers")
	}
}
func TestWireGuardProvisioner_GenerateConfig_WithPeers(t *testing.T) {
wp := NewWireGuardProvisioner(WireGuardConfig{
PrivateIP: "10.0.0.1",
ListenPort: 51820,
PrivateKey: "dGVzdHByaXZhdGVrZXl0ZXN0cHJpdmF0ZWtleXM=",
Peers: []WireGuardPeer{
{
PublicKey: "cGVlcjFwdWJsaWNrZXlwZWVyMXB1YmxpY2tleXM=",
Endpoint: "1.2.3.4:51820",
AllowedIP: "10.0.0.2/32",
},
{
PublicKey: "cGVlcjJwdWJsaWNrZXlwZWVyMnB1YmxpY2tleXM=",
Endpoint: "5.6.7.8:51820",
AllowedIP: "10.0.0.3/32",
},
},
})
config := wp.GenerateConfig()
if strings.Count(config, "[Peer]") != 2 {
t.Errorf("expected 2 [Peer] sections, got %d", strings.Count(config, "[Peer]"))
}
if !strings.Contains(config, "Endpoint = 1.2.3.4:51820") {
t.Error("config should contain first peer endpoint")
}
if !strings.Contains(config, "AllowedIPs = 10.0.0.2/32") {
t.Error("config should contain first peer AllowedIPs")
}
if !strings.Contains(config, "PersistentKeepalive = 25") {
t.Error("config should contain PersistentKeepalive")
}
if !strings.Contains(config, "Endpoint = 5.6.7.8:51820") {
t.Error("config should contain second peer endpoint")
}
}
func TestWireGuardProvisioner_GenerateConfig_PeerWithoutEndpoint(t *testing.T) {
wp := NewWireGuardProvisioner(WireGuardConfig{
PrivateIP: "10.0.0.1",
ListenPort: 51820,
PrivateKey: "dGVzdHByaXZhdGVrZXl0ZXN0cHJpdmF0ZWtleXM=",
Peers: []WireGuardPeer{
{
PublicKey: "cGVlcjFwdWJsaWNrZXlwZWVyMXB1YmxpY2tleXM=",
AllowedIP: "10.0.0.2/32",
},
},
})
config := wp.GenerateConfig()
if strings.Contains(config, "Endpoint") {
t.Error("config should NOT contain Endpoint when peer has none")
}
}
// TestWireGuardProvisioner_DefaultPort verifies the constructor defaults the
// listen port to 51820 when unset.
func TestWireGuardProvisioner_DefaultPort(t *testing.T) {
	provisioner := NewWireGuardProvisioner(WireGuardConfig{
		PrivateIP:  "10.0.0.1",
		PrivateKey: "dGVzdHByaXZhdGVrZXl0ZXN0cHJpdmF0ZWtleXM=",
	})
	if got := provisioner.config.ListenPort; got != 51820 {
		t.Errorf("default ListenPort = %d, want 51820", got)
	}
}

View File

@ -34,4 +34,7 @@ type Config struct {
IPFSTimeout time.Duration // Timeout for IPFS operations (default: 60s) IPFSTimeout time.Duration // Timeout for IPFS operations (default: 60s)
IPFSReplicationFactor int // Replication factor for pins (default: 3) IPFSReplicationFactor int // Replication factor for pins (default: 3)
IPFSEnableEncryption bool // Enable client-side encryption before upload (default: true, discovered from node configs) IPFSEnableEncryption bool // Enable client-side encryption before upload (default: true, discovered from node configs)
// WireGuard mesh configuration
ClusterSecret string // Cluster secret for authenticating internal WireGuard peer exchange
} }

View File

@ -26,6 +26,8 @@ import (
deploymentshandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/deployments" deploymentshandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/deployments"
pubsubhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/pubsub" pubsubhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/pubsub"
serverlesshandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/serverless" serverlesshandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/serverless"
joinhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/join"
wireguardhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/wireguard"
sqlitehandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/sqlite" sqlitehandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/sqlite"
"github.com/DeBrosOfficial/network/pkg/gateway/handlers/storage" "github.com/DeBrosOfficial/network/pkg/gateway/handlers/storage"
"github.com/DeBrosOfficial/network/pkg/ipfs" "github.com/DeBrosOfficial/network/pkg/ipfs"
@ -98,6 +100,15 @@ type Gateway struct {
processManager *process.Manager processManager *process.Manager
healthChecker *health.HealthChecker healthChecker *health.HealthChecker
// Rate limiter
rateLimiter *RateLimiter
// WireGuard peer exchange
wireguardHandler *wireguardhandlers.Handler
// Node join handler
joinHandler *joinhandlers.Handler
// Cluster provisioning for namespace clusters // Cluster provisioning for namespace clusters
clusterProvisioner authhandlers.ClusterProvisioner clusterProvisioner authhandlers.ClusterProvisioner
} }
@ -246,6 +257,16 @@ func New(logger *logging.ColoredLogger, cfg *Config) (*Gateway, error) {
) )
} }
// Initialize rate limiter (300 req/min, burst 50)
gw.rateLimiter = NewRateLimiter(300, 50)
gw.rateLimiter.StartCleanup(5*time.Minute, 10*time.Minute)
// Initialize WireGuard peer exchange handler
if deps.ORMClient != nil {
gw.wireguardHandler = wireguardhandlers.NewHandler(logger.Logger, deps.ORMClient, cfg.ClusterSecret)
gw.joinHandler = joinhandlers.NewHandler(logger.Logger, deps.ORMClient, cfg.DataDir)
}
// Initialize deployment system // Initialize deployment system
if deps.ORMClient != nil && deps.IPFSClient != nil { if deps.ORMClient != nil && deps.IPFSClient != nil {
// Convert rqlite.Client to database.Database interface for health checker // Convert rqlite.Client to database.Database interface for health checker

View File

@ -643,8 +643,8 @@ func (s *DeploymentService) getNodeIP(ctx context.Context, nodeID string) (strin
var rows []nodeRow var rows []nodeRow
// Try full node ID first // Try full node ID first (prefer internal/WG IP for cross-node communication)
query := `SELECT ip_address FROM dns_nodes WHERE id = ? LIMIT 1` query := `SELECT COALESCE(internal_ip, ip_address) AS ip_address FROM dns_nodes WHERE id = ? LIMIT 1`
err := s.db.Query(ctx, &rows, query, nodeID) err := s.db.Query(ctx, &rows, query, nodeID)
if err != nil { if err != nil {
return "", err return "", err

View File

@ -0,0 +1,424 @@
package join
import (
"context"
"encoding/json"
"fmt"
"net/http"
"os"
"os/exec"
"strings"
"time"
"github.com/DeBrosOfficial/network/pkg/rqlite"
"go.uber.org/zap"
)
// JoinRequest is the request body for POST /v1/internal/join.
type JoinRequest struct {
	Token       string `json:"token"`         // single-use invite token minted by an existing node
	WGPublicKey string `json:"wg_public_key"` // joining node's WireGuard public key
	PublicIP    string `json:"public_ip"`     // joining node's publicly reachable address
}

// JoinResponse contains everything a joining node needs to become a full
// cluster member: its mesh identity, shared secrets, and the addresses of the
// services it must join (all expressed over WireGuard IPs).
type JoinResponse struct {
	// WireGuard
	WGIP    string       `json:"wg_ip"`    // mesh IP assigned to the joining node
	WGPeers []WGPeerInfo `json:"wg_peers"` // every existing peer except the joiner itself
	// Secrets
	ClusterSecret string `json:"cluster_secret"` // IPFS Cluster shared secret
	SwarmKey      string `json:"swarm_key"`      // private IPFS swarm key
	// Cluster join info (all using WG IPs)
	RQLiteJoinAddress string   `json:"rqlite_join_address"`
	IPFSPeer          PeerInfo `json:"ipfs_peer"`
	IPFSClusterPeer   PeerInfo `json:"ipfs_cluster_peer"`
	BootstrapPeers    []string `json:"bootstrap_peers"`
	// Domain
	BaseDomain string `json:"base_domain"` // cluster base domain read from node config
}

// WGPeerInfo represents a single WireGuard peer entry the joiner should add
// to its own wg0 configuration.
type WGPeerInfo struct {
	PublicKey string `json:"public_key"`
	Endpoint  string `json:"endpoint"`   // public_ip:wg_port
	AllowedIP string `json:"allowed_ip"` // peer's mesh IP as a /32
}

// PeerInfo represents an IPFS or IPFS Cluster peer identity plus the
// multiaddrs (over the WG mesh) at which it can be reached.
type PeerInfo struct {
	ID    string   `json:"id"`
	Addrs []string `json:"addrs"`
}
// Handler handles the node join endpoint (POST /v1/internal/join).
// Authentication is performed inside the handler via single-use invite
// tokens, not by the gateway middleware.
type Handler struct {
	logger       *zap.Logger
	rqliteClient rqlite.Client
	oramaDir     string // e.g., /home/debros/.orama; secrets and configs live beneath it
}

// NewHandler creates a new join handler.
// oramaDir is the node's data directory root (secrets/ and configs/ are read
// from it at join time).
func NewHandler(logger *zap.Logger, rqliteClient rqlite.Client, oramaDir string) *Handler {
	return &Handler{
		logger:       logger,
		rqliteClient: rqliteClient,
		oramaDir:     oramaDir,
	}
}
// HandleJoin handles POST /v1/internal/join.
//
// A new node presents a single-use invite token plus its WireGuard public key
// and public IP; in exchange it receives its assigned mesh IP, the current
// peer list, the cluster secrets, and the join addresses for RQLite, IPFS and
// IPFS Cluster — all expressed over WireGuard mesh IPs.
func (h *Handler) HandleJoin(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var req JoinRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "invalid request body", http.StatusBadRequest)
		return
	}
	if req.Token == "" || req.WGPublicKey == "" || req.PublicIP == "" {
		http.Error(w, "token, wg_public_key, and public_ip are required", http.StatusBadRequest)
		return
	}

	ctx := r.Context()

	// 1. Validate and consume the invite token (atomic single-use).
	// NOTE(review): the token is burned here even if a later step fails, so an
	// aborted join requires a fresh invite — confirm this is intended.
	if err := h.consumeToken(ctx, req.Token, req.PublicIP); err != nil {
		h.logger.Warn("join token validation failed", zap.Error(err))
		http.Error(w, "unauthorized: invalid or expired token", http.StatusUnauthorized)
		return
	}

	// 2. Assign WG IP with retry on conflict
	wgIP, err := h.assignWGIP(ctx)
	if err != nil {
		h.logger.Error("failed to assign WG IP", zap.Error(err))
		http.Error(w, "failed to assign WG IP", http.StatusInternalServerError)
		return
	}

	// 3. Register WG peer in database
	nodeID := fmt.Sprintf("node-%s", wgIP) // temporary ID based on WG IP
	_, err = h.rqliteClient.Exec(ctx,
		"INSERT OR REPLACE INTO wireguard_peers (node_id, wg_ip, public_key, public_ip, wg_port) VALUES (?, ?, ?, ?, ?)",
		nodeID, wgIP, req.WGPublicKey, req.PublicIP, 51820)
	if err != nil {
		h.logger.Error("failed to register WG peer", zap.Error(err))
		http.Error(w, "failed to register peer", http.StatusInternalServerError)
		return
	}

	// 4. Add peer to local WireGuard interface immediately
	if err := h.addWGPeerLocally(req.WGPublicKey, req.PublicIP, wgIP); err != nil {
		h.logger.Warn("failed to add WG peer to local interface", zap.Error(err))
		// Non-fatal: the sync loop will pick it up
	}

	// 5. Read secrets from disk
	clusterSecret, err := os.ReadFile(h.oramaDir + "/secrets/cluster-secret")
	if err != nil {
		h.logger.Error("failed to read cluster secret", zap.Error(err))
		http.Error(w, "internal error reading secrets", http.StatusInternalServerError)
		return
	}
	swarmKey, err := os.ReadFile(h.oramaDir + "/secrets/swarm.key")
	if err != nil {
		h.logger.Error("failed to read swarm key", zap.Error(err))
		http.Error(w, "internal error reading secrets", http.StatusInternalServerError)
		return
	}

	// 6. Get all WG peers (excluding the joiner itself, keyed by its pubkey)
	wgPeers, err := h.getWGPeers(ctx, req.WGPublicKey)
	if err != nil {
		h.logger.Error("failed to list WG peers", zap.Error(err))
		http.Error(w, "failed to list peers", http.StatusInternalServerError)
		return
	}

	// 7. Get this node's WG IP (parsed from the local wg0 interface)
	myWGIP, err := h.getMyWGIP()
	if err != nil {
		h.logger.Error("failed to get local WG IP", zap.Error(err))
		http.Error(w, "internal error", http.StatusInternalServerError)
		return
	}

	// 8. Query IPFS and IPFS Cluster peer info (best-effort; empty IDs on failure)
	ipfsPeer := h.queryIPFSPeerInfo(myWGIP)
	ipfsClusterPeer := h.queryIPFSClusterPeerInfo(myWGIP)

	// 9. Get this node's libp2p peer ID for bootstrap peers
	bootstrapPeers := h.buildBootstrapPeers(myWGIP, ipfsPeer.ID)

	// 10. Read base domain from config
	baseDomain := h.readBaseDomain()

	resp := JoinResponse{
		WGIP:              wgIP,
		WGPeers:           wgPeers,
		ClusterSecret:     strings.TrimSpace(string(clusterSecret)),
		SwarmKey:          strings.TrimSpace(string(swarmKey)),
		RQLiteJoinAddress: fmt.Sprintf("%s:7001", myWGIP),
		IPFSPeer:          ipfsPeer,
		IPFSClusterPeer:   ipfsClusterPeer,
		BootstrapPeers:    bootstrapPeers,
		BaseDomain:        baseDomain,
	}

	w.Header().Set("Content-Type", "application/json")
	// NOTE(review): Encode error is ignored; headers are already committed by
	// this point, but a write failure would go unlogged — consider logging it.
	json.NewEncoder(w).Encode(resp)

	h.logger.Info("node joined cluster",
		zap.String("wg_ip", wgIP),
		zap.String("public_ip", req.PublicIP))
}
// consumeToken atomically marks an invite token as used, enforcing single-use
// semantics: the UPDATE only matches a token that exists, is still unused, and
// has not yet expired, so two concurrent joins cannot both consume it.
func (h *Handler) consumeToken(ctx context.Context, token, usedByIP string) error {
	const consumeQuery = "UPDATE invite_tokens SET used_at = datetime('now'), used_by_ip = ? WHERE token = ? AND used_at IS NULL AND expires_at > datetime('now')"

	res, err := h.rqliteClient.Exec(ctx, consumeQuery, usedByIP, token)
	if err != nil {
		return fmt.Errorf("database error: %w", err)
	}

	n, err := res.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to check result: %w", err)
	}
	if n == 0 {
		// Zero rows means the WHERE clause filtered the token out.
		return fmt.Errorf("token invalid, expired, or already used")
	}
	return nil
}
// assignWGIP allocates the next free WireGuard mesh address in the managed
// 10.0.x.x range.
//
// FIXES vs. the original:
//   - `SELECT MAX(wg_ip)` compared IPs as SQL strings, so "10.0.0.9" sorted
//     after "10.0.0.10" and the next join re-issued an already-used address
//     (collision / peer clobbering). All assigned IPs are now loaded and the
//     maximum is computed numerically.
//   - the old `for attempt` loop claimed to retry on conflict but every path
//     returned on the first iteration; it has been removed.
//
// Octets .0 and .255 are never handed out (the counter rolls to x.y.(z+1).1),
// matching the previous carry behavior.
func (h *Handler) assignWGIP(ctx context.Context) (string, error) {
	var rows []struct {
		WGIP string `db:"wg_ip"`
	}
	if err := h.rqliteClient.Query(ctx, &rows, "SELECT wg_ip FROM wireguard_peers"); err != nil {
		return "", fmt.Errorf("failed to query WG IPs: %w", err)
	}

	// 10.0.0.1 is the genesis node; seeding the maximum with it makes an
	// empty table yield 10.0.0.2, as before.
	highest := uint32(10)<<24 | 1
	for _, row := range rows {
		var a, b, c, d int
		if _, err := fmt.Sscanf(row.WGIP, "%d.%d.%d.%d", &a, &b, &c, &d); err != nil {
			continue // ignore malformed rows rather than blocking all joins
		}
		v := uint32(a)<<24 | uint32(b)<<16 | uint32(c)<<8 | uint32(d)
		if v > highest {
			highest = v
		}
	}

	next := highest + 1
	if next&0xFF > 254 { // skip .255/.0: roll over to the next /24 at .1
		next = (next &^ 0xFF) + 256 + 1
	}
	const ceiling = uint32(10)<<24 | uint32(255)<<8 | 254 // 10.0.255.254
	if next > ceiling {
		return "", fmt.Errorf("WireGuard IP space exhausted")
	}
	return fmt.Sprintf("%d.%d.%d.%d", next>>24, (next>>16)&0xFF, (next>>8)&0xFF, next&0xFF), nil
}
// addWGPeerLocally registers a peer on the running wg0 interface and appends a
// matching [Peer] section to /etc/wireguard/wg0.conf so the peer survives a
// wg-quick restart. Persistence failures are logged but never fatal: the
// runtime peer is already active and the sync loop can repair the config.
func (h *Handler) addWGPeerLocally(pubKey, publicIP, wgIP string) error {
	// Configure the live interface first; keepalive keeps NAT mappings open.
	args := []string{
		"wg", "set", "wg0",
		"peer", pubKey,
		"endpoint", fmt.Sprintf("%s:51820", publicIP),
		"allowed-ips", fmt.Sprintf("%s/32", wgIP),
		"persistent-keepalive", "25",
	}
	if output, err := exec.Command("sudo", args...).CombinedOutput(); err != nil {
		return fmt.Errorf("wg set failed: %w\n%s", err, string(output))
	}

	// Persist to wg0.conf: read, append a [Peer] section, write back via sudo.
	confPath := "/etc/wireguard/wg0.conf"
	data, err := os.ReadFile(confPath)
	if err != nil {
		h.logger.Warn("could not read wg0.conf for persistence", zap.Error(err))
		return nil // non-fatal: runtime peer is added
	}
	if strings.Contains(string(data), pubKey) {
		return nil // already persisted
	}

	peerSection := fmt.Sprintf("\n[Peer]\nPublicKey = %s\nEndpoint = %s:51820\nAllowedIPs = %s/32\nPersistentKeepalive = 25\n",
		pubKey, publicIP, wgIP)
	writeCmd := exec.Command("sudo", "tee", confPath)
	writeCmd.Stdin = strings.NewReader(string(data) + peerSection)
	if output, err := writeCmd.CombinedOutput(); err != nil {
		h.logger.Warn("could not persist peer to wg0.conf", zap.Error(err), zap.String("output", string(output)))
	}
	return nil
}
// getWGPeers returns every registered WireGuard peer except the node that is
// currently joining (identified by its public key), so the joiner never adds
// itself as its own peer. Rows with a zero port fall back to 51820.
func (h *Handler) getWGPeers(ctx context.Context, excludePubKey string) ([]WGPeerInfo, error) {
	type wgRow struct {
		WGIP      string `db:"wg_ip"`
		PublicKey string `db:"public_key"`
		PublicIP  string `db:"public_ip"`
		WGPort    int    `db:"wg_port"`
	}
	var rows []wgRow
	if err := h.rqliteClient.Query(ctx, &rows,
		"SELECT wg_ip, public_key, public_ip, wg_port FROM wireguard_peers ORDER BY wg_ip"); err != nil {
		return nil, err
	}

	var peers []WGPeerInfo
	for i := range rows {
		row := &rows[i]
		if row.PublicKey == excludePubKey {
			continue // the requesting node itself
		}
		port := row.WGPort
		if port == 0 {
			port = 51820 // default WireGuard port for legacy rows
		}
		peers = append(peers, WGPeerInfo{
			PublicKey: row.PublicKey,
			Endpoint:  fmt.Sprintf("%s:%d", row.PublicIP, port),
			AllowedIP: fmt.Sprintf("%s/32", row.WGIP),
		})
	}
	return peers, nil
}
// getMyWGIP extracts this node's WireGuard address by running
// `ip -4 addr show wg0` and scanning for the first "inet <addr>/<mask>" line.
func (h *Handler) getMyWGIP() (string, error) {
	raw, err := exec.Command("ip", "-4", "addr", "show", "wg0").CombinedOutput()
	if err != nil {
		return "", fmt.Errorf("failed to get wg0 info: %w", err)
	}

	for _, line := range strings.Split(string(raw), "\n") {
		fields := strings.Fields(strings.TrimSpace(line))
		if len(fields) >= 2 && fields[0] == "inet" {
			// fields[1] looks like "10.0.0.1/32"; keep only the address part.
			addr, _, _ := strings.Cut(fields[1], "/")
			return addr, nil
		}
	}
	return "", fmt.Errorf("could not find wg0 IP address")
}
// queryIPFSPeerInfo asks the local IPFS daemon (API on localhost:4501) for its
// peer ID and advertises it to the joiner at this node's WireGuard mesh IP.
// Returns a zero PeerInfo on any failure; callers treat an empty ID as "IPFS
// unavailable".
func (h *Handler) queryIPFSPeerInfo(myWGIP string) PeerInfo {
	client := &http.Client{Timeout: 5 * time.Second}
	resp, err := client.Post("http://localhost:4501/api/v0/id", "", nil)
	if err != nil {
		h.logger.Warn("failed to query IPFS peer info", zap.Error(err))
		return PeerInfo{}
	}
	defer resp.Body.Close()

	// FIX: previously the body was decoded regardless of HTTP status, so an
	// API error page produced a confusing decode warning or an empty ID with
	// no indication of the real problem.
	if resp.StatusCode != http.StatusOK {
		h.logger.Warn("IPFS id endpoint returned non-OK status", zap.Int("status", resp.StatusCode))
		return PeerInfo{}
	}

	var result struct {
		ID string `json:"ID"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		h.logger.Warn("failed to decode IPFS peer info", zap.Error(err))
		return PeerInfo{}
	}

	return PeerInfo{
		ID: result.ID,
		Addrs: []string{
			// IPFS swarm port 4101, reachable over the WG mesh.
			fmt.Sprintf("/ip4/%s/tcp/4101", myWGIP),
		},
	}
}
// queryIPFSClusterPeerInfo asks the local IPFS Cluster daemon (REST API on
// localhost:9094) for its peer ID and advertises it at this node's WireGuard
// mesh IP. Returns a zero PeerInfo on any failure.
func (h *Handler) queryIPFSClusterPeerInfo(myWGIP string) PeerInfo {
	client := &http.Client{Timeout: 5 * time.Second}
	resp, err := client.Get("http://localhost:9094/id")
	if err != nil {
		h.logger.Warn("failed to query IPFS Cluster peer info", zap.Error(err))
		return PeerInfo{}
	}
	defer resp.Body.Close()

	// FIX: previously the body was decoded regardless of HTTP status, so an
	// API error page produced a confusing decode warning or an empty ID.
	if resp.StatusCode != http.StatusOK {
		h.logger.Warn("IPFS Cluster id endpoint returned non-OK status", zap.Int("status", resp.StatusCode))
		return PeerInfo{}
	}

	var result struct {
		ID string `json:"id"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		h.logger.Warn("failed to decode IPFS Cluster peer info", zap.Error(err))
		return PeerInfo{}
	}

	return PeerInfo{
		ID: result.ID,
		Addrs: []string{
			// Cluster listen port 9100 over the WG mesh, p2p-suffixed.
			fmt.Sprintf("/ip4/%s/tcp/9100/p2p/%s", myWGIP, result.ID),
		},
	}
}
// buildBootstrapPeers constructs the libp2p bootstrap multiaddr list handed to
// the joining node, pointing at this node's IPFS swarm port over the WireGuard
// mesh. Returns nil when the local IPFS peer ID could not be determined.
func (h *Handler) buildBootstrapPeers(myWGIP, ipfsPeerID string) []string {
	if ipfsPeerID == "" {
		return nil
	}
	addr := fmt.Sprintf("/ip4/%s/tcp/4101/p2p/%s", myWGIP, ipfsPeerID)
	return []string{addr}
}
// readBaseDomain extracts the base_domain value from the node's YAML config
// with a line scan (no YAML dependency). Surrounding quotes are stripped;
// an unreadable file or missing key yields the empty string.
func (h *Handler) readBaseDomain() string {
	raw, err := os.ReadFile(h.oramaDir + "/configs/node.yaml")
	if err != nil {
		return ""
	}
	for _, ln := range strings.Split(string(raw), "\n") {
		ln = strings.TrimSpace(ln)
		if !strings.HasPrefix(ln, "base_domain:") {
			continue
		}
		val := strings.TrimSpace(strings.TrimPrefix(ln, "base_domain:"))
		return strings.Trim(val, `"'`)
	}
	return ""
}

View File

@ -0,0 +1,211 @@
package wireguard
import (
	"context"
	"crypto/subtle"
	"encoding/json"
	"fmt"
	"net/http"

	"github.com/DeBrosOfficial/network/pkg/rqlite"
	"go.uber.org/zap"
)
// PeerRecord represents a WireGuard peer stored in RQLite
// (wireguard_peers table).
type PeerRecord struct {
	NodeID    string `json:"node_id" db:"node_id"`
	WGIP      string `json:"wg_ip" db:"wg_ip"` // mesh address, e.g. 10.0.0.x
	PublicKey string `json:"public_key" db:"public_key"`
	PublicIP  string `json:"public_ip" db:"public_ip"` // publicly reachable endpoint address
	WGPort    int    `json:"wg_port" db:"wg_port"`     // UDP listen port (51820 when registered with 0)
}

// RegisterPeerRequest is the request body for peer registration.
// ClusterSecret authenticates the caller; the route itself is public.
type RegisterPeerRequest struct {
	NodeID        string `json:"node_id"`
	PublicKey     string `json:"public_key"`
	PublicIP      string `json:"public_ip"`
	WGPort        int    `json:"wg_port,omitempty"` // defaults to 51820 when omitted
	ClusterSecret string `json:"cluster_secret"`
}

// RegisterPeerResponse is the response for peer registration: the IP assigned
// to the caller plus the full current peer list (caller included).
type RegisterPeerResponse struct {
	AssignedWGIP string       `json:"assigned_wg_ip"`
	Peers        []PeerRecord `json:"peers"`
}
// Handler handles WireGuard peer exchange endpoints
// (/v1/internal/wg/...). These routes bypass gateway auth; the shared
// cluster secret carried in each request is the only authentication.
type Handler struct {
	logger        *zap.Logger
	rqliteClient  rqlite.Client
	clusterSecret string // expected cluster secret for auth; empty disables the check
}

// NewHandler creates a new WireGuard handler.
// An empty clusterSecret disables request authentication entirely.
func NewHandler(logger *zap.Logger, rqliteClient rqlite.Client, clusterSecret string) *Handler {
	return &Handler{
		logger:        logger,
		rqliteClient:  rqliteClient,
		clusterSecret: clusterSecret,
	}
}
// HandleRegisterPeer handles POST /v1/internal/wg/peer.
// A node calls this to register itself in the mesh and receive the full peer
// list (including itself).
//
// FIXES vs. the original:
//   - the cluster secret is compared in constant time (crypto/subtle) instead
//     of `!=`; this route is public, so a timing-dependent comparison leaked
//     information about the secret;
//   - a node that re-registers keeps its previously assigned WG IP. Before,
//     INSERT OR REPLACE paired with a fresh assignNextWGIP call silently gave
//     a returning node a brand-new IP, leaking addresses and breaking peers
//     that still routed to the old one.
func (h *Handler) HandleRegisterPeer(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var req RegisterPeerRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "invalid request body", http.StatusBadRequest)
		return
	}

	// Validate cluster secret (constant-time; skipped when unconfigured).
	if h.clusterSecret != "" &&
		subtle.ConstantTimeCompare([]byte(req.ClusterSecret), []byte(h.clusterSecret)) != 1 {
		http.Error(w, "unauthorized", http.StatusUnauthorized)
		return
	}

	if req.NodeID == "" || req.PublicKey == "" || req.PublicIP == "" {
		http.Error(w, "node_id, public_key, and public_ip are required", http.StatusBadRequest)
		return
	}
	if req.WGPort == 0 {
		req.WGPort = 51820 // default WireGuard listen port
	}

	ctx := r.Context()

	// Reuse an existing assignment on re-registration; only allocate a fresh
	// IP for nodes we have never seen before.
	var wgIP string
	var existing []struct {
		WGIP string `db:"wg_ip"`
	}
	if err := h.rqliteClient.Query(ctx, &existing,
		"SELECT wg_ip FROM wireguard_peers WHERE node_id = ?", req.NodeID); err == nil && len(existing) > 0 {
		wgIP = existing[0].WGIP
	}
	if wgIP == "" {
		assigned, err := h.assignNextWGIP(ctx)
		if err != nil {
			h.logger.Error("failed to assign WG IP", zap.Error(err))
			http.Error(w, "failed to assign WG IP", http.StatusInternalServerError)
			return
		}
		wgIP = assigned
	}

	// Upsert the peer record (refreshes public key/IP/port for known nodes).
	_, err := h.rqliteClient.Exec(ctx,
		"INSERT OR REPLACE INTO wireguard_peers (node_id, wg_ip, public_key, public_ip, wg_port) VALUES (?, ?, ?, ?, ?)",
		req.NodeID, wgIP, req.PublicKey, req.PublicIP, req.WGPort)
	if err != nil {
		h.logger.Error("failed to insert WG peer", zap.Error(err))
		http.Error(w, "failed to register peer", http.StatusInternalServerError)
		return
	}

	// Return the complete mesh view, including the peer just added.
	peers, err := h.ListPeers(ctx)
	if err != nil {
		h.logger.Error("failed to list WG peers", zap.Error(err))
		http.Error(w, "failed to list peers", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(RegisterPeerResponse{
		AssignedWGIP: wgIP,
		Peers:        peers,
	})

	h.logger.Info("registered WireGuard peer",
		zap.String("node_id", req.NodeID),
		zap.String("wg_ip", wgIP),
		zap.String("public_ip", req.PublicIP))
}
// HandleListPeers handles GET /v1/internal/wg/peers, returning every
// registered mesh peer as a JSON array (ordered by WG IP).
func (h *Handler) HandleListPeers(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}

	records, err := h.ListPeers(r.Context())
	if err != nil {
		h.logger.Error("failed to list WG peers", zap.Error(err))
		http.Error(w, "failed to list peers", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(records)
}
// HandleRemovePeer handles DELETE /v1/internal/wg/peer?node_id=xxx.
//
// SECURITY FIX: this route is exposed as a public path (gateway middleware
// does no auth for /v1/internal/wg/), yet — unlike HandleRegisterPeer — it
// previously performed no authentication at all, letting anyone evict peers
// from the mesh. When a cluster secret is configured, callers must now supply
// it in the X-Cluster-Secret header (DELETE carries no body, so a header is
// used instead of the JSON field register uses).
func (h *Handler) HandleRemovePeer(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodDelete {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}

	// Constant-time comparison avoids leaking the secret via timing.
	if h.clusterSecret != "" &&
		subtle.ConstantTimeCompare([]byte(r.Header.Get("X-Cluster-Secret")), []byte(h.clusterSecret)) != 1 {
		http.Error(w, "unauthorized", http.StatusUnauthorized)
		return
	}

	nodeID := r.URL.Query().Get("node_id")
	if nodeID == "" {
		http.Error(w, "node_id parameter required", http.StatusBadRequest)
		return
	}

	if _, err := h.rqliteClient.Exec(r.Context(),
		"DELETE FROM wireguard_peers WHERE node_id = ?", nodeID); err != nil {
		h.logger.Error("failed to remove WG peer", zap.Error(err))
		http.Error(w, "failed to remove peer", http.StatusInternalServerError)
		return
	}

	w.WriteHeader(http.StatusOK)
	h.logger.Info("removed WireGuard peer", zap.String("node_id", nodeID))
}
// ListPeers returns all registered WireGuard peers, ordered by WG IP.
func (h *Handler) ListPeers(ctx context.Context) ([]PeerRecord, error) {
	const q = "SELECT node_id, wg_ip, public_key, public_ip, wg_port FROM wireguard_peers ORDER BY wg_ip"

	var records []PeerRecord
	if err := h.rqliteClient.Query(ctx, &records, q); err != nil {
		return nil, fmt.Errorf("failed to query wireguard_peers: %w", err)
	}
	return records, nil
}
// assignNextWGIP allocates the next free WireGuard mesh address in the managed
// 10.0.x.x range.
//
// FIX vs. the original: `SELECT MAX(wg_ip)` compared IPs as SQL strings, so
// "10.0.0.9" sorted after "10.0.0.10" and the next registration re-issued an
// already-used address (collision / peer clobbering). All assigned IPs are now
// loaded and the maximum computed numerically.
//
// Octets .0 and .255 are never handed out (the counter rolls to x.y.(z+1).1),
// matching the previous carry behavior. An empty table yields 10.0.0.1.
func (h *Handler) assignNextWGIP(ctx context.Context) (string, error) {
	var rows []struct {
		WGIP string `db:"wg_ip"`
	}
	if err := h.rqliteClient.Query(ctx, &rows, "SELECT wg_ip FROM wireguard_peers"); err != nil {
		return "", fmt.Errorf("failed to query WG IPs: %w", err)
	}

	var highest uint32
	for _, row := range rows {
		var a, b, c, d int
		if _, err := fmt.Sscanf(row.WGIP, "%d.%d.%d.%d", &a, &b, &c, &d); err != nil {
			continue // ignore malformed rows rather than blocking registration
		}
		v := uint32(a)<<24 | uint32(b)<<16 | uint32(c)<<8 | uint32(d)
		if v > highest {
			highest = v
		}
	}
	if highest == 0 {
		// No (parseable) peers yet: start of the range.
		return "10.0.0.1", nil
	}

	next := highest + 1
	if next&0xFF > 254 { // skip .255/.0: roll over to the next /24 at .1
		next = (next &^ 0xFF) + 256 + 1
	}
	const ceiling = uint32(10)<<24 | uint32(255)<<8 | 254 // 10.0.255.254
	if next > ceiling {
		return "", fmt.Errorf("WireGuard IP space exhausted")
	}
	return fmt.Sprintf("%d.%d.%d.%d", next>>24, (next>>16)&0xFF, (next>>8)&0xFF, next&0xFF), nil
}

View File

@ -19,15 +19,32 @@ import (
// Note: context keys (ctxKeyAPIKey, ctxKeyJWT, CtxKeyNamespaceOverride) are now defined in context.go // Note: context keys (ctxKeyAPIKey, ctxKeyJWT, CtxKeyNamespaceOverride) are now defined in context.go
// withMiddleware adds CORS and logging middleware // withMiddleware adds CORS, security headers, rate limiting, and logging middleware
func (g *Gateway) withMiddleware(next http.Handler) http.Handler { func (g *Gateway) withMiddleware(next http.Handler) http.Handler {
// Order: logging (outermost) -> CORS -> domain routing -> auth -> handler // Order: logging -> security headers -> rate limit -> CORS -> domain routing -> auth -> handler
// Domain routing must come BEFORE auth to handle deployment domains without auth
return g.loggingMiddleware( return g.loggingMiddleware(
g.corsMiddleware( g.securityHeadersMiddleware(
g.domainRoutingMiddleware( g.rateLimitMiddleware(
g.authMiddleware( g.corsMiddleware(
g.authorizationMiddleware(next))))) g.domainRoutingMiddleware(
g.authMiddleware(
g.authorizationMiddleware(next)))))))
}
// securityHeadersMiddleware adds standard security headers to all responses
func (g *Gateway) securityHeadersMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Content-Type-Options", "nosniff")
w.Header().Set("X-Frame-Options", "DENY")
w.Header().Set("X-XSS-Protection", "0")
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
w.Header().Set("Permissions-Policy", "camera=(), microphone=(), geolocation=()")
// HSTS only when behind TLS (Caddy)
if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" {
w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
}
next.ServeHTTP(w, r)
})
} }
// loggingMiddleware logs basic request info and duration // loggingMiddleware logs basic request info and duration
@ -202,6 +219,16 @@ func isPublicPath(p string) bool {
return true return true
} }
// WireGuard peer exchange (auth handled by cluster secret in handler)
if strings.HasPrefix(p, "/v1/internal/wg/") {
return true
}
// Node join endpoint (auth handled by invite token in handler)
if p == "/v1/internal/join" {
return true
}
switch p { switch p {
case "/health", "/v1/health", "/status", "/v1/status", "/v1/auth/jwks", "/.well-known/jwks.json", "/v1/version", "/v1/auth/login", "/v1/auth/challenge", "/v1/auth/verify", "/v1/auth/register", "/v1/auth/refresh", "/v1/auth/logout", "/v1/auth/api-key", "/v1/auth/simple-key", "/v1/network/status", "/v1/network/peers", "/v1/internal/tls/check", "/v1/internal/acme/present", "/v1/internal/acme/cleanup": case "/health", "/v1/health", "/status", "/v1/status", "/v1/auth/jwks", "/.well-known/jwks.json", "/v1/version", "/v1/auth/login", "/v1/auth/challenge", "/v1/auth/verify", "/v1/auth/register", "/v1/auth/refresh", "/v1/auth/logout", "/v1/auth/api-key", "/v1/auth/simple-key", "/v1/network/status", "/v1/network/peers", "/v1/internal/tls/check", "/v1/internal/acme/present", "/v1/internal/acme/cleanup":
return true return true
@ -912,7 +939,7 @@ func (g *Gateway) proxyCrossNode(w http.ResponseWriter, r *http.Request, deploym
db := g.client.Database() db := g.client.Database()
internalCtx := client.WithInternalAuth(r.Context()) internalCtx := client.WithInternalAuth(r.Context())
query := "SELECT ip_address FROM dns_nodes WHERE id = ? LIMIT 1" query := "SELECT COALESCE(internal_ip, ip_address) FROM dns_nodes WHERE id = ? LIMIT 1"
result, err := db.Query(internalCtx, query, deployment.HomeNodeID) result, err := db.Query(internalCtx, query, deployment.HomeNodeID)
if err != nil || result == nil || len(result.Rows) == 0 { if err != nil || result == nil || len(result.Rows) == 0 {
g.logger.Warn("Failed to get home node IP", g.logger.Warn("Failed to get home node IP",

129
pkg/gateway/rate_limiter.go Normal file
View File

@ -0,0 +1,129 @@
package gateway
import (
"net"
"net/http"
"strings"
"sync"
"time"
)
// RateLimiter implements a token-bucket rate limiter per client IP.
type RateLimiter struct {
mu sync.Mutex
clients map[string]*bucket
rate float64 // tokens per second
burst int // max tokens (burst capacity)
}
type bucket struct {
tokens float64
lastCheck time.Time
}
// NewRateLimiter creates a rate limiter. ratePerMinute is the sustained rate;
// burst is the maximum number of requests that can be made in a short window.
func NewRateLimiter(ratePerMinute, burst int) *RateLimiter {
return &RateLimiter{
clients: make(map[string]*bucket),
rate: float64(ratePerMinute) / 60.0,
burst: burst,
}
}
// Allow checks if a request from the given IP should be allowed.
func (rl *RateLimiter) Allow(ip string) bool {
rl.mu.Lock()
defer rl.mu.Unlock()
now := time.Now()
b, exists := rl.clients[ip]
if !exists {
rl.clients[ip] = &bucket{tokens: float64(rl.burst) - 1, lastCheck: now}
return true
}
// Refill tokens based on elapsed time
elapsed := now.Sub(b.lastCheck).Seconds()
b.tokens += elapsed * rl.rate
if b.tokens > float64(rl.burst) {
b.tokens = float64(rl.burst)
}
b.lastCheck = now
if b.tokens >= 1 {
b.tokens--
return true
}
return false
}
// Cleanup removes stale entries older than the given duration.
func (rl *RateLimiter) Cleanup(maxAge time.Duration) {
rl.mu.Lock()
defer rl.mu.Unlock()
cutoff := time.Now().Add(-maxAge)
for ip, b := range rl.clients {
if b.lastCheck.Before(cutoff) {
delete(rl.clients, ip)
}
}
}
// StartCleanup runs periodic cleanup in a goroutine.
func (rl *RateLimiter) StartCleanup(interval, maxAge time.Duration) {
go func() {
ticker := time.NewTicker(interval)
defer ticker.Stop()
for range ticker.C {
rl.Cleanup(maxAge)
}
}()
}
// rateLimitMiddleware rejects requests with 429 once a client IP exceeds its
// token budget. Loopback and WireGuard-subnet (10.0.0.0/8) clients bypass the
// limiter entirely so intra-cluster traffic is never throttled. With no
// limiter configured the middleware is a no-op passthrough.
func (g *Gateway) rateLimitMiddleware(next http.Handler) http.Handler {
	if g.rateLimiter == nil {
		return next
	}
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		clientIP := getClientIP(r)
		// Internal traffic short-circuits before any token is spent.
		if isInternalIP(clientIP) || g.rateLimiter.Allow(clientIP) {
			next.ServeHTTP(w, r)
			return
		}
		w.Header().Set("Retry-After", "5")
		http.Error(w, "rate limit exceeded", http.StatusTooManyRequests)
	})
}
// isInternalIP reports whether the address belongs to the WireGuard mesh
// (10.0.0.0/8) or is a loopback address. Accepts bare IPs ("10.0.0.1",
// "::1") as well as host:port forms ("10.0.0.1:443", "[::1]:8080").
//
// FIX: the original parsed "10.0.0.0/8" with net.ParseCIDR on every call;
// checking the first octet of the 4-byte form is equivalent (including for
// IPv4-mapped IPv6 addresses) and avoids the per-request parse.
func isInternalIP(ipStr string) bool {
	// Strip a port if one is present; on error the input is left untouched
	// (a bare IPv6 address like "::1" makes SplitHostPort fail, which is fine).
	if host, _, err := net.SplitHostPort(ipStr); err == nil {
		ipStr = host
	}

	ip := net.ParseIP(ipStr)
	if ip == nil {
		return false
	}
	if ip.IsLoopback() {
		return true
	}

	// 10.0.0.0/8 — WireGuard mesh. To4 also normalizes IPv4-mapped IPv6.
	ip4 := ip.To4()
	return ip4 != nil && ip4[0] == 10
}

View File

@ -0,0 +1,197 @@
package gateway
import (
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
)
// TestRateLimiter_AllowsUnderLimit: a client may spend its entire burst
// allowance without being throttled.
func TestRateLimiter_AllowsUnderLimit(t *testing.T) {
	rl := NewRateLimiter(60, 10) // 1/sec, burst 10
	for i := 0; i < 10; i++ {
		if !rl.Allow("1.2.3.4") {
			t.Fatalf("request %d should be allowed (within burst)", i)
		}
	}
}

// TestRateLimiter_BlocksOverLimit: the first request past the burst capacity
// is rejected.
func TestRateLimiter_BlocksOverLimit(t *testing.T) {
	rl := NewRateLimiter(60, 5) // 1/sec, burst 5
	// Exhaust burst
	for i := 0; i < 5; i++ {
		rl.Allow("1.2.3.4")
	}
	if rl.Allow("1.2.3.4") {
		t.Fatal("request after burst should be blocked")
	}
}

// TestRateLimiter_RefillsOverTime: tokens accumulate again as wall-clock time
// passes.
// NOTE(review): relies on a real 100ms sleep. At 100 tokens/sec the window
// refills ~10 tokens, so there is ample headroom, but it can still flake on a
// badly overloaded CI machine — consider aging the bucket directly instead.
func TestRateLimiter_RefillsOverTime(t *testing.T) {
	rl := NewRateLimiter(6000, 5) // 100/sec, burst 5
	// Exhaust burst
	for i := 0; i < 5; i++ {
		rl.Allow("1.2.3.4")
	}
	if rl.Allow("1.2.3.4") {
		t.Fatal("should be blocked after burst")
	}
	// Wait for refill
	time.Sleep(100 * time.Millisecond)
	if !rl.Allow("1.2.3.4") {
		t.Fatal("should be allowed after refill")
	}
}

// TestRateLimiter_PerIPIsolation: exhausting one IP's bucket must not affect
// another IP's budget.
func TestRateLimiter_PerIPIsolation(t *testing.T) {
	rl := NewRateLimiter(60, 2)
	// Exhaust IP A
	rl.Allow("1.1.1.1")
	rl.Allow("1.1.1.1")
	if rl.Allow("1.1.1.1") {
		t.Fatal("IP A should be blocked")
	}
	// IP B should still be allowed
	if !rl.Allow("2.2.2.2") {
		t.Fatal("IP B should be allowed")
	}
}

// TestRateLimiter_Cleanup: buckets idle for longer than maxAge are evicted.
// The test reaches into the limiter's internals (same package) to age the
// entry deterministically instead of sleeping.
func TestRateLimiter_Cleanup(t *testing.T) {
	rl := NewRateLimiter(60, 10)
	rl.Allow("old-ip")
	// Force the entry to be old
	rl.mu.Lock()
	rl.clients["old-ip"].lastCheck = time.Now().Add(-20 * time.Minute)
	rl.mu.Unlock()
	rl.Cleanup(10 * time.Minute)
	rl.mu.Lock()
	_, exists := rl.clients["old-ip"]
	rl.mu.Unlock()
	if exists {
		t.Fatal("stale entry should have been cleaned up")
	}
}

// TestRateLimiter_ConcurrentAccess hammers one bucket from many goroutines.
// It asserts nothing directly — its value is under `go test -race`, which
// flags any unsynchronized access to the shared client map.
func TestRateLimiter_ConcurrentAccess(t *testing.T) {
	rl := NewRateLimiter(60000, 100) // high limit to avoid false failures
	var wg sync.WaitGroup
	for i := 0; i < 50; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for j := 0; j < 10; j++ {
				rl.Allow("concurrent-ip")
			}
		}()
	}
	wg.Wait()
}
// TestIsInternalIP table-tests the internal/external classification:
// the WireGuard 10.0.0.0/8 mesh and loopback are internal; RFC1918
// 192.168.0.0/16 and public addresses are not.
func TestIsInternalIP(t *testing.T) {
	tests := []struct {
		ip       string
		internal bool
	}{
		{"10.0.0.1", true},
		{"10.0.0.254", true},
		{"10.255.255.255", true}, // last address of the /8
		{"127.0.0.1", true},
		{"192.168.1.1", false}, // private but not part of the WG mesh
		{"8.8.8.8", false},
		{"141.227.165.168", false},
	}
	for _, tt := range tests {
		if got := isInternalIP(tt.ip); got != tt.internal {
			t.Errorf("isInternalIP(%q) = %v, want %v", tt.ip, got, tt.internal)
		}
	}
}
// TestSecurityHeaders: with X-Forwarded-Proto: https set, the middleware must
// emit the full set of security headers, including HSTS.
func TestSecurityHeaders(t *testing.T) {
	gw := &Gateway{}
	handler := gw.securityHeadersMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	req := httptest.NewRequest("GET", "/test", nil)
	req.Header.Set("X-Forwarded-Proto", "https")
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)
	expected := map[string]string{
		"X-Content-Type-Options":    "nosniff",
		"X-Frame-Options":           "DENY",
		"X-XSS-Protection":          "0",
		"Referrer-Policy":           "strict-origin-when-cross-origin",
		"Strict-Transport-Security": "max-age=31536000; includeSubDomains",
	}
	for header, want := range expected {
		if got := w.Header().Get(header); got != want {
			t.Errorf("header %s = %q, want %q", header, got, want)
		}
	}
}

// TestSecurityHeaders_NoHSTS_WithoutTLS: HSTS must be withheld for plain-HTTP
// requests (no TLS and no https X-Forwarded-Proto hint).
func TestSecurityHeaders_NoHSTS_WithoutTLS(t *testing.T) {
	gw := &Gateway{}
	handler := gw.securityHeadersMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	req := httptest.NewRequest("GET", "/test", nil)
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)
	if got := w.Header().Get("Strict-Transport-Security"); got != "" {
		t.Errorf("HSTS should not be set without TLS, got %q", got)
	}
}

// TestRateLimitMiddleware_Returns429: with burst 1, the second request from a
// public IP must be rejected with 429 and carry a Retry-After header.
func TestRateLimitMiddleware_Returns429(t *testing.T) {
	gw := &Gateway{rateLimiter: NewRateLimiter(60, 1)}
	handler := gw.rateLimitMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	// First request should pass
	req := httptest.NewRequest("GET", "/test", nil)
	req.RemoteAddr = "8.8.8.8:1234"
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("first request should be 200, got %d", w.Code)
	}
	// Second request should be rate limited
	w = httptest.NewRecorder()
	handler.ServeHTTP(w, req)
	if w.Code != http.StatusTooManyRequests {
		t.Fatalf("second request should be 429, got %d", w.Code)
	}
	if w.Header().Get("Retry-After") == "" {
		t.Fatal("should have Retry-After header")
	}
}

// TestRateLimitMiddleware_ExemptsInternalTraffic: WireGuard-subnet clients are
// never throttled, even far past the configured burst of 1.
func TestRateLimitMiddleware_ExemptsInternalTraffic(t *testing.T) {
	gw := &Gateway{rateLimiter: NewRateLimiter(60, 1)}
	handler := gw.rateLimitMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	// Internal IP should never be rate limited
	for i := 0; i < 10; i++ {
		req := httptest.NewRequest("GET", "/test", nil)
		req.RemoteAddr = "10.0.0.1:1234"
		w := httptest.NewRecorder()
		handler.ServeHTTP(w, req)
		if w.Code != http.StatusOK {
			t.Fatalf("internal request %d should be 200, got %d", i, w.Code)
		}
	}
}

View File

@ -24,6 +24,18 @@ func (g *Gateway) Routes() http.Handler {
mux.HandleFunc("/v1/internal/acme/present", g.acmePresentHandler) mux.HandleFunc("/v1/internal/acme/present", g.acmePresentHandler)
mux.HandleFunc("/v1/internal/acme/cleanup", g.acmeCleanupHandler) mux.HandleFunc("/v1/internal/acme/cleanup", g.acmeCleanupHandler)
// WireGuard peer exchange (internal, cluster-secret auth)
if g.wireguardHandler != nil {
mux.HandleFunc("/v1/internal/wg/peer", g.wireguardHandler.HandleRegisterPeer)
mux.HandleFunc("/v1/internal/wg/peers", g.wireguardHandler.HandleListPeers)
mux.HandleFunc("/v1/internal/wg/peer/remove", g.wireguardHandler.HandleRemovePeer)
}
// Node join endpoint (token-authenticated, no middleware auth needed)
if g.joinHandler != nil {
mux.HandleFunc("/v1/internal/join", g.joinHandler.HandleJoin)
}
// auth endpoints // auth endpoints
mux.HandleFunc("/v1/auth/jwks", g.authService.JWKSHandler) mux.HandleFunc("/v1/auth/jwks", g.authService.JWKSHandler)
mux.HandleFunc("/.well-known/jwks.json", g.authService.JWKSHandler) mux.HandleFunc("/.well-known/jwks.json", g.authService.JWKSHandler)

View File

@ -226,8 +226,8 @@ func (cm *ClusterManager) startRQLiteCluster(ctx context.Context, cluster *Names
NodeID: nodes[0].NodeID, NodeID: nodes[0].NodeID,
HTTPPort: portBlocks[0].RQLiteHTTPPort, HTTPPort: portBlocks[0].RQLiteHTTPPort,
RaftPort: portBlocks[0].RQLiteRaftPort, RaftPort: portBlocks[0].RQLiteRaftPort,
HTTPAdvAddress: fmt.Sprintf("%s:%d", nodes[0].IPAddress, portBlocks[0].RQLiteHTTPPort), HTTPAdvAddress: fmt.Sprintf("%s:%d", nodes[0].InternalIP, portBlocks[0].RQLiteHTTPPort),
RaftAdvAddress: fmt.Sprintf("%s:%d", nodes[0].IPAddress, portBlocks[0].RQLiteRaftPort), RaftAdvAddress: fmt.Sprintf("%s:%d", nodes[0].InternalIP, portBlocks[0].RQLiteRaftPort),
IsLeader: true, IsLeader: true,
} }
@ -254,8 +254,8 @@ func (cm *ClusterManager) startRQLiteCluster(ctx context.Context, cluster *Names
NodeID: nodes[i].NodeID, NodeID: nodes[i].NodeID,
HTTPPort: portBlocks[i].RQLiteHTTPPort, HTTPPort: portBlocks[i].RQLiteHTTPPort,
RaftPort: portBlocks[i].RQLiteRaftPort, RaftPort: portBlocks[i].RQLiteRaftPort,
HTTPAdvAddress: fmt.Sprintf("%s:%d", nodes[i].IPAddress, portBlocks[i].RQLiteHTTPPort), HTTPAdvAddress: fmt.Sprintf("%s:%d", nodes[i].InternalIP, portBlocks[i].RQLiteHTTPPort),
RaftAdvAddress: fmt.Sprintf("%s:%d", nodes[i].IPAddress, portBlocks[i].RQLiteRaftPort), RaftAdvAddress: fmt.Sprintf("%s:%d", nodes[i].InternalIP, portBlocks[i].RQLiteRaftPort),
JoinAddresses: []string{leaderRaftAddr}, JoinAddresses: []string{leaderRaftAddr},
IsLeader: false, IsLeader: false,
} }
@ -288,7 +288,7 @@ func (cm *ClusterManager) startOlricCluster(ctx context.Context, cluster *Namesp
// Build peer addresses (all nodes) // Build peer addresses (all nodes)
peerAddresses := make([]string, len(nodes)) peerAddresses := make([]string, len(nodes))
for i, node := range nodes { for i, node := range nodes {
peerAddresses[i] = fmt.Sprintf("%s:%d", node.IPAddress, portBlocks[i].OlricMemberlistPort) peerAddresses[i] = fmt.Sprintf("%s:%d", node.InternalIP, portBlocks[i].OlricMemberlistPort)
} }
// Start all Olric instances // Start all Olric instances
@ -299,7 +299,7 @@ func (cm *ClusterManager) startOlricCluster(ctx context.Context, cluster *Namesp
HTTPPort: portBlocks[i].OlricHTTPPort, HTTPPort: portBlocks[i].OlricHTTPPort,
MemberlistPort: portBlocks[i].OlricMemberlistPort, MemberlistPort: portBlocks[i].OlricMemberlistPort,
BindAddr: "0.0.0.0", BindAddr: "0.0.0.0",
AdvertiseAddr: node.IPAddress, AdvertiseAddr: node.InternalIP,
PeerAddresses: peerAddresses, PeerAddresses: peerAddresses,
} }

View File

@ -23,6 +23,7 @@ type ClusterNodeSelector struct {
type NodeCapacity struct { type NodeCapacity struct {
NodeID string `json:"node_id"` NodeID string `json:"node_id"`
IPAddress string `json:"ip_address"` IPAddress string `json:"ip_address"`
InternalIP string `json:"internal_ip"` // WireGuard IP for inter-node communication
DeploymentCount int `json:"deployment_count"` DeploymentCount int `json:"deployment_count"`
AllocatedPorts int `json:"allocated_ports"` AllocatedPorts int `json:"allocated_ports"`
AvailablePorts int `json:"available_ports"` AvailablePorts int `json:"available_ports"`
@ -59,7 +60,7 @@ func (cns *ClusterNodeSelector) SelectNodesForCluster(ctx context.Context, nodeC
// Filter nodes that have capacity for namespace instances // Filter nodes that have capacity for namespace instances
eligibleNodes := make([]NodeCapacity, 0) eligibleNodes := make([]NodeCapacity, 0)
for _, node := range activeNodes { for _, node := range activeNodes {
capacity, err := cns.getNodeCapacity(internalCtx, node.NodeID, node.IPAddress) capacity, err := cns.getNodeCapacity(internalCtx, node.NodeID, node.IPAddress, node.InternalIP)
if err != nil { if err != nil {
cns.logger.Warn("Failed to get node capacity, skipping", cns.logger.Warn("Failed to get node capacity, skipping",
zap.String("node_id", node.NodeID), zap.String("node_id", node.NodeID),
@ -117,8 +118,9 @@ func (cns *ClusterNodeSelector) SelectNodesForCluster(ctx context.Context, nodeC
// nodeInfo is used for querying active nodes // nodeInfo is used for querying active nodes
type nodeInfo struct { type nodeInfo struct {
NodeID string `db:"id"` NodeID string `db:"id"`
IPAddress string `db:"ip_address"` IPAddress string `db:"ip_address"`
InternalIP string `db:"internal_ip"`
} }
// getActiveNodes retrieves all active nodes from dns_nodes table // getActiveNodes retrieves all active nodes from dns_nodes table
@ -128,7 +130,7 @@ func (cns *ClusterNodeSelector) getActiveNodes(ctx context.Context) ([]nodeInfo,
var results []nodeInfo var results []nodeInfo
query := ` query := `
SELECT id, ip_address FROM dns_nodes SELECT id, ip_address, COALESCE(internal_ip, ip_address) as internal_ip FROM dns_nodes
WHERE status = 'active' AND last_seen > ? WHERE status = 'active' AND last_seen > ?
ORDER BY id ORDER BY id
` `
@ -148,7 +150,7 @@ func (cns *ClusterNodeSelector) getActiveNodes(ctx context.Context) ([]nodeInfo,
} }
// getNodeCapacity calculates capacity metrics for a single node // getNodeCapacity calculates capacity metrics for a single node
func (cns *ClusterNodeSelector) getNodeCapacity(ctx context.Context, nodeID, ipAddress string) (*NodeCapacity, error) { func (cns *ClusterNodeSelector) getNodeCapacity(ctx context.Context, nodeID, ipAddress, internalIP string) (*NodeCapacity, error) {
// Get deployment count // Get deployment count
deploymentCount, err := cns.getDeploymentCount(ctx, nodeID) deploymentCount, err := cns.getDeploymentCount(ctx, nodeID)
if err != nil { if err != nil {
@ -209,6 +211,7 @@ func (cns *ClusterNodeSelector) getNodeCapacity(ctx context.Context, nodeID, ipA
capacity := &NodeCapacity{ capacity := &NodeCapacity{
NodeID: nodeID, NodeID: nodeID,
IPAddress: ipAddress, IPAddress: ipAddress,
InternalIP: internalIP,
DeploymentCount: deploymentCount, DeploymentCount: deploymentCount,
AllocatedPorts: allocatedPorts, AllocatedPorts: allocatedPorts,
AvailablePorts: availablePorts, AvailablePorts: availablePorts,
@ -365,7 +368,7 @@ func (cns *ClusterNodeSelector) GetNodeByID(ctx context.Context, nodeID string)
internalCtx := client.WithInternalAuth(ctx) internalCtx := client.WithInternalAuth(ctx)
var results []nodeInfo var results []nodeInfo
query := `SELECT id, ip_address FROM dns_nodes WHERE id = ? LIMIT 1` query := `SELECT id, ip_address, COALESCE(internal_ip, ip_address) as internal_ip FROM dns_nodes WHERE id = ? LIMIT 1`
err := cns.db.Query(internalCtx, &results, query, nodeID) err := cns.db.Query(internalCtx, &results, query, nodeID)
if err != nil { if err != nil {
return nil, &ClusterError{ return nil, &ClusterError{

View File

@ -29,8 +29,11 @@ func (n *Node) registerDNSNode(ctx context.Context) error {
ipAddress = "127.0.0.1" ipAddress = "127.0.0.1"
} }
// Get internal IP (same as external for now, or could use private network IP) // Get internal IP from WireGuard interface (for cross-node communication over VPN)
internalIP := ipAddress internalIP := ipAddress
if wgIP, err := n.getWireGuardIP(); err == nil && wgIP != "" {
internalIP = wgIP
}
// Determine region (defaulting to "local" for now, could be from cloud metadata in future) // Determine region (defaulting to "local" for now, could be from cloud metadata in future)
region := "local" region := "local"
@ -297,6 +300,24 @@ func (n *Node) cleanupStaleNodeRecords(ctx context.Context) {
} }
} }
// getWireGuardIP returns the IPv4 address assigned to the wg0 interface, if any
func (n *Node) getWireGuardIP() (string, error) {
iface, err := net.InterfaceByName("wg0")
if err != nil {
return "", err
}
addrs, err := iface.Addrs()
if err != nil {
return "", err
}
for _, addr := range addrs {
if ipnet, ok := addr.(*net.IPNet); ok && ipnet.IP.To4() != nil {
return ipnet.IP.String(), nil
}
}
return "", fmt.Errorf("no IPv4 address on wg0")
}
// getNodeIPAddress attempts to determine the node's external IP address // getNodeIPAddress attempts to determine the node's external IP address
func (n *Node) getNodeIPAddress() (string, error) { func (n *Node) getNodeIPAddress() (string, error) {
// Try to detect external IP by connecting to a public server // Try to detect external IP by connecting to a public server

View File

@ -33,6 +33,15 @@ func (n *Node) startHTTPGateway(ctx context.Context) error {
return err return err
} }
// DataDir in node config is ~/.orama/data; the orama dir is the parent
oramaDir := filepath.Join(os.ExpandEnv(n.config.Node.DataDir), "..")
// Read cluster secret for WireGuard peer exchange auth
clusterSecret := ""
if secretBytes, err := os.ReadFile(filepath.Join(oramaDir, "secrets", "cluster-secret")); err == nil {
clusterSecret = string(secretBytes)
}
gwCfg := &gateway.Config{ gwCfg := &gateway.Config{
ListenAddr: n.config.HTTPGateway.ListenAddr, ListenAddr: n.config.HTTPGateway.ListenAddr,
ClientNamespace: n.config.HTTPGateway.ClientNamespace, ClientNamespace: n.config.HTTPGateway.ClientNamespace,
@ -45,6 +54,8 @@ func (n *Node) startHTTPGateway(ctx context.Context) error {
IPFSAPIURL: n.config.HTTPGateway.IPFSAPIURL, IPFSAPIURL: n.config.HTTPGateway.IPFSAPIURL,
IPFSTimeout: n.config.HTTPGateway.IPFSTimeout, IPFSTimeout: n.config.HTTPGateway.IPFSTimeout,
BaseDomain: n.config.HTTPGateway.BaseDomain, BaseDomain: n.config.HTTPGateway.BaseDomain,
DataDir: oramaDir,
ClusterSecret: clusterSecret,
} }
apiGateway, err := gateway.New(gatewayLogger, gwCfg) apiGateway, err := gateway.New(gatewayLogger, gwCfg)

View File

@ -103,6 +103,9 @@ func (n *Node) Start(ctx context.Context) error {
return fmt.Errorf("failed to start RQLite: %w", err) return fmt.Errorf("failed to start RQLite: %w", err)
} }
// Sync WireGuard peers from RQLite (if WG is active on this node)
n.startWireGuardSyncLoop(ctx)
// Register this node in dns_nodes table for deployment routing // Register this node in dns_nodes table for deployment routing
if err := n.registerDNSNode(ctx); err != nil { if err := n.registerDNSNode(ctx); err != nil {
n.logger.ComponentWarn(logging.ComponentNode, "Failed to register DNS node", zap.Error(err)) n.logger.ComponentWarn(logging.ComponentNode, "Failed to register DNS node", zap.Error(err))

223
pkg/node/wireguard_sync.go Normal file
View File

@ -0,0 +1,223 @@
package node
import (
"context"
"fmt"
"net"
"os/exec"
"strings"
"time"
"github.com/DeBrosOfficial/network/pkg/environments/production"
"github.com/DeBrosOfficial/network/pkg/logging"
"go.uber.org/zap"
)
// syncWireGuardPeers reads all peers from RQLite and reconciles the local
// WireGuard interface so it matches the cluster state. This is called on
// startup after RQLite is ready and periodically thereafter.
//
// Best-effort: if WireGuard is not installed or wg0 is not up, the sync is
// skipped without error so non-VPN nodes keep working.
func (n *Node) syncWireGuardPeers(ctx context.Context) error {
	if n.rqliteAdapter == nil {
		return fmt.Errorf("rqlite adapter not initialized")
	}

	// Check if WireGuard is installed and active
	if _, err := exec.LookPath("wg"); err != nil {
		n.logger.ComponentInfo(logging.ComponentNode, "WireGuard not installed, skipping peer sync")
		return nil
	}

	// Check if wg0 interface exists
	out, err := exec.CommandContext(ctx, "sudo", "wg", "show", "wg0").CombinedOutput()
	if err != nil {
		n.logger.ComponentInfo(logging.ComponentNode, "WireGuard interface wg0 not active, skipping peer sync")
		return nil
	}

	// Parse current peers and our own public key from the same `wg show` output
	wgOutput := string(out)
	currentPeers := parseWGShowPeers(wgOutput)
	localPubKey := parseWGShowLocalKey(wgOutput)

	// Query all peers from RQLite
	db := n.rqliteAdapter.GetSQLDB()
	rows, err := db.QueryContext(ctx,
		"SELECT node_id, wg_ip, public_key, public_ip, wg_port FROM wireguard_peers ORDER BY wg_ip")
	if err != nil {
		return fmt.Errorf("failed to query wireguard_peers: %w", err)
	}
	defer rows.Close()

	// Build desired peer set keyed by public key (excluding self)
	desiredPeers := make(map[string]production.WireGuardPeer)
	for rows.Next() {
		var nodeID, wgIP, pubKey, pubIP string
		var wgPort int
		if err := rows.Scan(&nodeID, &wgIP, &pubKey, &pubIP, &wgPort); err != nil {
			// Log instead of silently skipping so bad rows are visible.
			n.logger.ComponentWarn(logging.ComponentNode, "failed to scan wireguard_peers row", zap.Error(err))
			continue
		}
		if pubKey == localPubKey {
			continue // skip self
		}
		if wgPort == 0 {
			wgPort = 51820 // default WireGuard listen port
		}
		desiredPeers[pubKey] = production.WireGuardPeer{
			PublicKey: pubKey,
			Endpoint:  fmt.Sprintf("%s:%d", pubIP, wgPort),
			AllowedIP: wgIP + "/32",
		}
	}
	// If iteration aborted, the desired set may be incomplete; reconciling
	// against it could wrongly remove healthy peers, so bail out instead.
	if err := rows.Err(); err != nil {
		return fmt.Errorf("iterating wireguard_peers rows: %w", err)
	}

	wp := &production.WireGuardProvisioner{}

	// Add peers present in the DB but missing from the local interface
	for pubKey, peer := range desiredPeers {
		if _, exists := currentPeers[pubKey]; !exists {
			if err := wp.AddPeer(peer); err != nil {
				n.logger.ComponentWarn(logging.ComponentNode, "failed to add WG peer",
					zap.String("public_key", pubKey[:8]+"..."),
					zap.Error(err))
			} else {
				n.logger.ComponentInfo(logging.ComponentNode, "added WG peer",
					zap.String("allowed_ip", peer.AllowedIP))
			}
		}
	}

	// Remove local peers that are no longer in the desired set
	for pubKey := range currentPeers {
		if _, exists := desiredPeers[pubKey]; !exists {
			if err := wp.RemovePeer(pubKey); err != nil {
				n.logger.ComponentWarn(logging.ComponentNode, "failed to remove stale WG peer",
					zap.String("public_key", pubKey[:8]+"..."),
					zap.Error(err))
			} else {
				n.logger.ComponentInfo(logging.ComponentNode, "removed stale WG peer",
					zap.String("public_key", pubKey[:8]+"..."))
			}
		}
	}

	n.logger.ComponentInfo(logging.ComponentNode, "WireGuard peer sync completed",
		zap.Int("desired_peers", len(desiredPeers)),
		zap.Int("current_peers", len(currentPeers)))

	return nil
}
// ensureWireGuardSelfRegistered ensures this node's WireGuard info is in the
// wireguard_peers table. Without this, joining nodes get an empty peer list
// from the /v1/internal/join endpoint and can't establish WG tunnels.
//
// Best-effort: every failure exits silently except the final DB write, which
// is logged, since a node without an active wg0 simply has nothing to publish.
func (n *Node) ensureWireGuardSelfRegistered(ctx context.Context) {
	if n.rqliteAdapter == nil {
		return
	}

	// Probe wg0; a failure means WireGuard is not active on this host.
	showOut, showErr := exec.CommandContext(ctx, "sudo", "wg", "show", "wg0").CombinedOutput()
	if showErr != nil {
		return // WG not active, nothing to register
	}

	selfKey := parseWGShowLocalKey(string(showOut))
	if selfKey == "" {
		return
	}

	// Resolve the IPv4 address assigned to the wg0 interface.
	wgIface, ifErr := net.InterfaceByName("wg0")
	if ifErr != nil {
		return
	}
	ifaceAddrs, addrErr := wgIface.Addrs()
	if addrErr != nil {
		return
	}
	var tunnelIP string
	for _, a := range ifaceAddrs {
		if ipnet, ok := a.(*net.IPNet); ok && ipnet.IP.To4() != nil {
			tunnelIP = ipnet.IP.String()
			break
		}
	}
	if tunnelIP == "" {
		return
	}

	// The advertised endpoint other nodes will dial.
	publicIP, ipErr := n.getNodeIPAddress()
	if ipErr != nil {
		return
	}

	id := n.GetPeerID()
	if id == "" {
		id = fmt.Sprintf("node-%s", tunnelIP)
	}

	conn := n.rqliteAdapter.GetSQLDB()
	_, execErr := conn.ExecContext(ctx,
		"INSERT OR REPLACE INTO wireguard_peers (node_id, wg_ip, public_key, public_ip, wg_port) VALUES (?, ?, ?, ?, ?)",
		id, tunnelIP, selfKey, publicIP, 51820)
	if execErr != nil {
		n.logger.ComponentWarn(logging.ComponentNode, "Failed to self-register WG peer", zap.Error(execErr))
		return
	}
	n.logger.ComponentInfo(logging.ComponentNode, "WireGuard self-registered",
		zap.String("wg_ip", tunnelIP),
		zap.String("public_key", selfKey[:8]+"..."))
}
// startWireGuardSyncLoop runs syncWireGuardPeers periodically.
// It performs a synchronous self-registration plus one immediate sync, then
// spawns a goroutine that re-syncs on a fixed interval until ctx is cancelled.
func (n *Node) startWireGuardSyncLoop(ctx context.Context) {
	// Ensure this node is registered in wireguard_peers (critical for join flow)
	n.ensureWireGuardSelfRegistered(ctx)

	// Run an initial sync before handing off to the background loop.
	if err := n.syncWireGuardPeers(ctx); err != nil {
		n.logger.ComponentWarn(logging.ComponentNode, "initial WireGuard peer sync failed", zap.Error(err))
	}

	go func() {
		const syncInterval = 60 * time.Second
		t := time.NewTicker(syncInterval)
		defer t.Stop()

		for {
			select {
			case <-ctx.Done():
				return
			case <-t.C:
				if syncErr := n.syncWireGuardPeers(ctx); syncErr != nil {
					n.logger.ComponentWarn(logging.ComponentNode, "WireGuard peer sync failed", zap.Error(syncErr))
				}
			}
		}
	}()
}
// parseWGShowPeers extracts public keys of current peers from `wg show wg0`
// output. Each peer appears on its own "peer: <key>" line; keys are returned
// as a set for O(1) membership checks during reconciliation.
func parseWGShowPeers(output string) map[string]struct{} {
	peers := make(map[string]struct{})
	for _, raw := range strings.Split(output, "\n") {
		key, found := strings.CutPrefix(strings.TrimSpace(raw), "peer:")
		if !found {
			continue
		}
		if key = strings.TrimSpace(key); key != "" {
			peers[key] = struct{}{}
		}
	}
	return peers
}
// parseWGShowLocalKey extracts the local public key from `wg show wg0` output.
// Returns the first "public key: <key>" value found, or "" when absent.
func parseWGShowLocalKey(output string) string {
	for _, raw := range strings.Split(output, "\n") {
		if rest, ok := strings.CutPrefix(strings.TrimSpace(raw), "public key:"); ok {
			return strings.TrimSpace(rest)
		}
	}
	return ""
}

View File

@ -38,6 +38,13 @@ func (r *RQLiteManager) waitForMinClusterSizeBeforeStart(ctx context.Context, rq
} }
requiredRemotePeers := r.config.MinClusterSize - 1 requiredRemotePeers := r.config.MinClusterSize - 1
// Genesis node (single-node cluster) doesn't need to wait for peers
if requiredRemotePeers <= 0 {
r.logger.Info("Genesis node, skipping peer discovery wait")
return nil
}
_ = r.discoveryService.TriggerPeerExchange(ctx) _ = r.discoveryService.TriggerPeerExchange(ctx)
checkInterval := 2 * time.Second checkInterval := 2 * time.Second

View File

@ -17,8 +17,55 @@ import (
"go.uber.org/zap" "go.uber.org/zap"
) )
// killOrphanedRQLite kills any orphaned rqlited process still holding the port.
// This can happen when the parent node process crashes and rqlited keeps running.
func (r *RQLiteManager) killOrphanedRQLite() {
// Check if port is already in use by querying the status endpoint
url := fmt.Sprintf("http://localhost:%d/status", r.config.RQLitePort)
client := &http.Client{Timeout: 2 * time.Second}
resp, err := client.Get(url)
if err != nil {
return // Port not in use, nothing to clean up
}
resp.Body.Close()
// Port is in use — find and kill the orphaned process
r.logger.Warn("Found orphaned rqlited process on port, killing it",
zap.Int("port", r.config.RQLitePort))
// Use fuser to find and kill the process holding the port
cmd := exec.Command("fuser", "-k", fmt.Sprintf("%d/tcp", r.config.RQLitePort))
if err := cmd.Run(); err != nil {
r.logger.Warn("fuser failed, trying lsof", zap.Error(err))
// Fallback: use lsof
out, err := exec.Command("lsof", "-ti", fmt.Sprintf(":%d", r.config.RQLitePort)).Output()
if err == nil {
for _, pidStr := range strings.Split(strings.TrimSpace(string(out)), "\n") {
if pidStr != "" {
killCmd := exec.Command("kill", "-9", pidStr)
killCmd.Run()
}
}
}
}
// Wait for port to be released
for i := 0; i < 10; i++ {
time.Sleep(500 * time.Millisecond)
resp, err := client.Get(url)
if err != nil {
return // Port released
}
resp.Body.Close()
}
r.logger.Warn("Could not release port from orphaned process")
}
// launchProcess starts the RQLite process with appropriate arguments // launchProcess starts the RQLite process with appropriate arguments
func (r *RQLiteManager) launchProcess(ctx context.Context, rqliteDataDir string) error { func (r *RQLiteManager) launchProcess(ctx context.Context, rqliteDataDir string) error {
// Kill any orphaned rqlited from a previous crash
r.killOrphanedRQLite()
// Build RQLite command // Build RQLite command
args := []string{ args := []string{
"-http-addr", fmt.Sprintf("0.0.0.0:%d", r.config.RQLitePort), "-http-addr", fmt.Sprintf("0.0.0.0:%d", r.config.RQLitePort),
@ -180,21 +227,30 @@ func (r *RQLiteManager) waitForReady(ctx context.Context) error {
// waitForSQLAvailable waits until a simple query succeeds // waitForSQLAvailable waits until a simple query succeeds
func (r *RQLiteManager) waitForSQLAvailable(ctx context.Context) error { func (r *RQLiteManager) waitForSQLAvailable(ctx context.Context) error {
r.logger.Info("Waiting for SQL to become available...")
ticker := time.NewTicker(1 * time.Second) ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop() defer ticker.Stop()
attempts := 0
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
r.logger.Error("waitForSQLAvailable timed out", zap.Int("attempts", attempts))
return ctx.Err() return ctx.Err()
case <-ticker.C: case <-ticker.C:
attempts++
if r.connection == nil { if r.connection == nil {
r.logger.Warn("connection is nil in waitForSQLAvailable")
continue continue
} }
_, err := r.connection.QueryOne("SELECT 1") _, err := r.connection.QueryOne("SELECT 1")
if err == nil { if err == nil {
r.logger.Info("SQL is available", zap.Int("attempts", attempts))
return nil return nil
} }
if attempts <= 5 || attempts%10 == 0 {
r.logger.Debug("SQL not yet available", zap.Int("attempt", attempts), zap.Error(err))
}
} }
} }
} }

View File

@ -74,6 +74,9 @@ on:github.com/coredns/caddy/onevent
sign:sign sign:sign
view:view view:view
# Response Rate Limiting (DNS amplification protection)
rrl:rrl
# Custom RQLite plugin # Custom RQLite plugin
rqlite:github.com/DeBrosOfficial/network/pkg/coredns/rqlite rqlite:github.com/DeBrosOfficial/network/pkg/coredns/rqlite
EOF EOF