Mirror of https://github.com/DeBrosOfficial/network.git (synced 2026-01-30 19:03:03 +00:00)
Added WireGuard, updated the installation process, and added more tests

commit 4acea72467 (parent dcaf695fbc)
Makefile (6 lines changed)
@ -22,7 +22,7 @@ check-gateway:
 	echo "  3. Run tests: make test-e2e-local"; \
 	echo ""; \
 	echo "To run tests against production:"; \
-	echo "  ORAMA_GATEWAY_URL=http://VPS-IP:6001 make test-e2e"; \
+	echo "  ORAMA_GATEWAY_URL=https://dbrs.space make test-e2e"; \
 	exit 1; \
 	fi
 	@echo "✅ Gateway is running"

@ -36,7 +36,7 @@ test-e2e-local: check-gateway
 test-e2e-prod:
 	@if [ -z "$$ORAMA_GATEWAY_URL" ]; then \
 		echo "❌ ORAMA_GATEWAY_URL not set"; \
-		echo "Usage: ORAMA_GATEWAY_URL=http://VPS-IP:6001 make test-e2e-prod"; \
+		echo "Usage: ORAMA_GATEWAY_URL=https://dbrs.space make test-e2e-prod"; \
 		exit 1; \
 	fi
 	@echo "Running E2E tests (including production-only) against $$ORAMA_GATEWAY_URL..."

@ -182,7 +182,7 @@ help:
 	@echo "  make test-e2e      - Generic E2E tests (auto-discovers config)"
 	@echo ""
 	@echo "  Example production test:"
-	@echo "  ORAMA_GATEWAY_URL=http://141.227.165.168:6001 make test-e2e-prod"
+	@echo "  ORAMA_GATEWAY_URL=https://dbrs.space make test-e2e-prod"
 	@echo ""
 	@echo "Development Management (via orama):"
 	@echo "  ./bin/orama dev status  - Show status of all dev services"
@ -53,6 +53,8 @@ func main() {
 		cli.HandleProdCommand(args)

 	// Direct production commands (new simplified interface)
+	case "invite":
+		cli.HandleProdCommand(append([]string{"invite"}, args...))
 	case "install":
 		cli.HandleProdCommand(append([]string{"install"}, args...))
 	case "upgrade":
@ -353,12 +353,22 @@ Function Invocation:
 - Refresh token support
 - Claims-based authorization

+### Network Security (WireGuard Mesh)
+
+All inter-node communication is encrypted via a WireGuard VPN mesh:
+
+- **WireGuard IPs:** Each node gets a private IP (10.0.0.x) used for all cluster traffic
+- **UFW Firewall:** Only public ports are exposed: 22 (SSH), 53 (DNS, nameservers only), 80/443 (HTTP/HTTPS), 51820 (WireGuard UDP)
+- **Internal services** (RQLite 5001/7001, IPFS 4001/4501, Olric 3320/3322, Gateway 6001) are only accessible via WireGuard or localhost
+- **Invite tokens:** Single-use, time-limited tokens for secure node joining. No shared secrets on the CLI
+- **Join flow:** New nodes authenticate via HTTPS (443), establish a WireGuard tunnel, then join all services over the encrypted mesh
+
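The UFW posture in the bullets above is produced programmatically by the `FirewallProvisioner` this commit adds in `pkg/environments/production/firewall.go`. A minimal sketch of inspecting the generated rule set for a nameserver node (import path taken from the commit's own imports):

```go
package main

import (
	"fmt"

	"github.com/DeBrosOfficial/network/pkg/environments/production"
)

func main() {
	// Nameserver nodes additionally open 53/tcp and 53/udp.
	fp := production.NewFirewallProvisioner(production.FirewallConfig{
		IsNameserver: true,
	})
	// GenerateRules returns the ufw commands without applying them,
	// e.g. "ufw allow 51820/udp" and "ufw allow from 10.0.0.0/8".
	for _, rule := range fp.GenerateRules() {
		fmt.Println(rule)
	}
}
```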
 ### TLS/HTTPS

-- Automatic ACME (Let's Encrypt) certificate management
+- Automatic ACME (Let's Encrypt) certificate management via Caddy
 - TLS 1.3 support
 - HTTP/2 enabled
 - Certificate caching
 - On-demand TLS for deployment custom domains

 ### Middleware Stack
@ -441,17 +451,25 @@ make test-e2e # Run E2E tests
 ### Production

 ```bash
-# First node (creates cluster)
-sudo orama install --vps-ip <IP> --domain node1.example.com
+# First node (genesis — creates cluster)
+sudo orama install --vps-ip <IP> --domain node1.example.com --nameserver

-# Additional nodes (join cluster)
-sudo orama install --vps-ip <IP> --domain node2.example.com \
-  --peers /dns4/node1.example.com/tcp/4001/p2p/<PEER_ID> \
-  --join <node1-ip>:7002 \
-  --cluster-secret <secret> \
-  --swarm-key <key>
+# On the genesis node, generate an invite for a new node
+orama invite
+# Outputs: sudo orama install --join https://node1.example.com --token <TOKEN> --vps-ip <NEW_IP>
+
+# Additional nodes (join via invite token over HTTPS)
+sudo orama install --join https://node1.example.com --token <TOKEN> \
+  --vps-ip <IP> --nameserver
 ```

+**Security:** Nodes join via single-use invite tokens over HTTPS. A WireGuard VPN tunnel
+is established before any cluster services start. All inter-node traffic (RQLite, IPFS,
+Olric, LibP2P) flows over the encrypted WireGuard mesh — no cluster ports are exposed
+publicly. **Never use `http://<ip>:6001`** for joining — port 6001 is internal-only and
+blocked by UFW. Use the domain (`https://node1.example.com`) or, if DNS is not yet
+configured, use the IP over HTTP port 80 (`http://<ip>`), which goes through Caddy.
+
 ### Docker (Future)

 Planned containerization with Docker Compose and Kubernetes support.
@ -95,14 +95,74 @@ To deploy to all nodes, repeat steps 3-5 (dev) or 3-4 (production) for each VPS

### CLI Flags Reference

#### `orama install`

| Flag | Description |
|------|-------------|
| `--branch <branch>` | Git branch to pull from (production deployment) |
| `--no-pull` | Skip git pull, use existing `/home/debros/src` (dev deployment) |
| `--vps-ip <ip>` | VPS public IP address (required) |
| `--domain <domain>` | Domain for HTTPS certificates |
| `--base-domain <domain>` | Base domain for deployment routing (e.g., dbrs.space) |
| `--nameserver` | Configure this node as a nameserver (CoreDNS + Caddy) |
| `--join <url>` | Join existing cluster via HTTPS URL (e.g., `https://node1.dbrs.space`) |
| `--token <token>` | Invite token for joining (from `orama invite` on existing node) |
| `--branch <branch>` | Git branch to use (default: main) |
| `--no-pull` | Skip git clone/pull, use existing `/home/debros/src` |
| `--force` | Force reconfiguration even if already installed |
| `--skip-firewall` | Skip UFW firewall setup |
| `--skip-checks` | Skip minimum resource checks (RAM/CPU) |

#### `orama invite`

| Flag | Description |
|------|-------------|
| `--expiry <duration>` | Token expiry duration (default: 1h) |

#### `orama upgrade`

| Flag | Description |
|------|-------------|
| `--branch <branch>` | Git branch to pull from |
| `--no-pull` | Skip git pull, use existing source |
| `--restart` | Restart all services after upgrade |
| `--nameserver` | Configure this node as a nameserver (install only) |
| `--domain <domain>` | Domain for HTTPS certificates (install only) |
| `--vps-ip <ip>` | VPS public IP address (install only) |

### Node Join Flow

```bash
# 1. Genesis node (first node, creates cluster)
sudo orama install --vps-ip 1.2.3.4 --domain node1.dbrs.space \
  --base-domain dbrs.space --nameserver

# 2. On genesis node, generate an invite
orama invite
# Output: sudo orama install --join https://node1.dbrs.space --token <TOKEN> --vps-ip <IP>

# 3. On the new node, run the printed command
sudo orama install --join https://node1.dbrs.space --token abc123... \
  --vps-ip 5.6.7.8 --nameserver
```

The join flow establishes a WireGuard VPN tunnel before starting cluster services.
All inter-node communication (RQLite, IPFS, Olric) uses WireGuard IPs (10.0.0.x).
No cluster ports are ever exposed publicly.
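For reference, the request and response exchanged on `/v1/internal/join` are defined in `pkg/gateway/handlers/join` (the server-side handler is not part of this diff). A condensed sketch of the shapes, with field names taken from the orchestrator code later in this commit; anything not shown there is an assumption:

```go
// Sketch of the join exchange; the authoritative types live in
// pkg/gateway/handlers/join.
type JoinRequest struct {
	Token       string // single-use invite token from `orama invite`
	WGPublicKey string // joining node's WireGuard public key
	PublicIP    string // joining node's public endpoint for the mesh
}

type WGPeerInfo struct {
	PublicKey string // peer's WireGuard public key
	Endpoint  string // "public-ip:51820"
	AllowedIP string // peer's mesh IP, e.g. "10.0.0.1/32"
}

type JoinResponse struct {
	WGIP              string       // WG IP assigned to the new node
	WGPeers           []WGPeerInfo // existing mesh peers to configure
	BootstrapPeers    []string     // LibP2P bootstrap multiaddrs
	RQLiteJoinAddress string       // RQLite join target over the mesh
	BaseDomain        string
	ClusterSecret     string // written to disk by the installer
	SwarmKey          string // private IPFS swarm key
	// IPFSPeer / IPFSClusterPeer (ID + Addrs) are also returned
}
```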
#### DNS Prerequisite

The `--join` URL should use the HTTPS domain of the genesis node (e.g., `https://node1.dbrs.space`).
For this to work, the domain registrar for `dbrs.space` must have NS records pointing to the genesis
node's IP so that `node1.dbrs.space` resolves publicly.

**If DNS is not yet configured**, you can use the genesis node's public IP with HTTP as a fallback:

```bash
sudo orama install --join http://1.2.3.4 --vps-ip 5.6.7.8 --token abc123... --nameserver
```

This works because Caddy's `:80` block proxies all HTTP traffic to the gateway. However, once DNS
is properly configured, always use the HTTPS domain URL.

**Important:** Never use `http://<ip>:6001` — port 6001 is the internal gateway and is blocked by
UFW from external access. The join request goes through Caddy on port 80 (HTTP) or 443 (HTTPS),
which proxies to the gateway internally.

## Debugging Production Issues
@ -170,9 +170,9 @@ func TestNamespaceCluster_OlricHealth(t *testing.T) {
 func TestNamespaceCluster_GatewayHealth(t *testing.T) {
 	// Check if gateway binary exists
 	gatewayBinaryPaths := []string{
-		"./bin/gateway",
-		"../bin/gateway",
-		"/usr/local/bin/orama-gateway",
+		"./bin/orama",
+		"../bin/orama",
+		"/usr/local/bin/orama",
 	}

 	var gatewayBinaryExists bool
migrations/013_wireguard_peers.sql (new file, 9 lines)
@ -0,0 +1,9 @@
-- WireGuard mesh peer tracking
CREATE TABLE IF NOT EXISTS wireguard_peers (
    node_id    TEXT PRIMARY KEY,
    wg_ip      TEXT NOT NULL UNIQUE,
    public_key TEXT NOT NULL UNIQUE,
    public_ip  TEXT NOT NULL,
    wg_port    INTEGER DEFAULT 51820,
    created_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
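Each row tracks one mesh member. A hypothetical helper (not part of this commit) showing how such rows map onto the `WireGuardPeer` struct added in `pkg/environments/production/wireguard.go`:

```go
package example

import (
	"fmt"

	"github.com/DeBrosOfficial/network/pkg/environments/production"
)

// peerRow mirrors the wireguard_peers columns used for peer config.
type peerRow struct {
	WGIP      string // wg_ip
	PublicKey string // public_key
	PublicIP  string // public_ip
	WGPort    int    // wg_port, default 51820
}

// toWGPeers converts table rows into peer entries for wg0.conf.
func toWGPeers(rows []peerRow) []production.WireGuardPeer {
	peers := make([]production.WireGuardPeer, 0, len(rows))
	for _, r := range rows {
		peers = append(peers, production.WireGuardPeer{
			PublicKey: r.PublicKey,
			Endpoint:  fmt.Sprintf("%s:%d", r.PublicIP, r.WGPort),
			AllowedIP: r.WGIP + "/32", // one /32 per mesh peer
		})
	}
	return peers
}
```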
migrations/014_invite_tokens.sql (new file, 8 lines)
@ -0,0 +1,8 @@
CREATE TABLE IF NOT EXISTS invite_tokens (
    token      TEXT PRIMARY KEY,
    created_by TEXT NOT NULL,
    created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
    expires_at DATETIME NOT NULL,
    used_at    DATETIME,
    used_by_ip TEXT
);
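The `expires_at`/`used_at` columns make tokens time-limited and single-use. The join handler that consumes them is not in this diff; a minimal sketch of an atomic check-and-burn, assuming a standard SQL connection (a conditional UPDATE succeeds only for a valid token):

```go
package join

import (
	"database/sql"
	"fmt"
)

// consumeToken is a sketch (assumption: the real handler is not shown in
// this diff). The conditional UPDATE burns the token atomically: zero
// affected rows means the token is missing, expired, or already used.
func consumeToken(db *sql.DB, token, callerIP string) error {
	res, err := db.Exec(
		`UPDATE invite_tokens
		 SET used_at = CURRENT_TIMESTAMP, used_by_ip = ?
		 WHERE token = ? AND used_at IS NULL
		   AND expires_at > CURRENT_TIMESTAMP`,
		callerIP, token,
	)
	if err != nil {
		return err
	}
	n, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if n == 0 {
		return fmt.Errorf("invite token invalid, expired, or already used")
	}
	return nil
}
```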
@ -7,42 +7,32 @@ import (
 )

 // TestProdCommandFlagParsing verifies that prod command flags are parsed correctly
-// Note: The installer now uses --vps-ip presence to determine if it's a first node (no --bootstrap flag)
-// First node: has --vps-ip but no --peers or --join
-// Joining node: has --vps-ip, --peers, and --cluster-secret
+// Genesis node: has --vps-ip but no --join or --token
+// Joining node: has --vps-ip, --join (HTTPS URL), and --token (invite token)
 func TestProdCommandFlagParsing(t *testing.T) {
 	tests := []struct {
-		name         string
-		args         []string
-		expectVPSIP  string
-		expectDomain string
-		expectPeers  string
-		expectJoin   string
-		expectSecret string
-		expectBranch string
-		isFirstNode  bool // first node = no peers and no join address
+		name         string
+		args         []string
+		expectVPSIP  string
+		expectDomain string
+		expectJoin   string
+		expectToken  string
+		expectBranch string
+		isFirstNode  bool // genesis node = no --join and no --token
 	}{
 		{
-			name:        "first node (creates new cluster)",
-			args:        []string{"install", "--vps-ip", "10.0.0.1", "--domain", "node-1.example.com"},
-			expectVPSIP: "10.0.0.1",
+			name:         "genesis node (creates new cluster)",
+			args:         []string{"install", "--vps-ip", "10.0.0.1", "--domain", "node-1.example.com"},
+			expectVPSIP:  "10.0.0.1",
 			expectDomain: "node-1.example.com",
-			isFirstNode: true,
+			isFirstNode:  true,
 		},
 		{
-			name:         "joining node with peers",
-			args:         []string{"install", "--vps-ip", "10.0.0.2", "--peers", "/ip4/10.0.0.1/tcp/4001/p2p/Qm123", "--cluster-secret", "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"},
-			expectVPSIP:  "10.0.0.2",
-			expectPeers:  "/ip4/10.0.0.1/tcp/4001/p2p/Qm123",
-			expectSecret: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
-			isFirstNode:  false,
-		},
-		{
-			name:         "joining node with join address",
-			args:         []string{"install", "--vps-ip", "10.0.0.3", "--join", "10.0.0.1:7001", "--cluster-secret", "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"},
-			expectVPSIP:  "10.0.0.3",
-			expectJoin:   "10.0.0.1:7001",
-			expectSecret: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
+			name:        "joining node with invite token",
+			args:        []string{"install", "--vps-ip", "10.0.0.2", "--join", "https://node1.dbrs.space", "--token", "abc123def456"},
+			expectVPSIP: "10.0.0.2",
+			expectJoin:  "https://node1.dbrs.space",
+			expectToken: "abc123def456",
 			isFirstNode: false,
 		},
 		{

@ -56,8 +46,7 @@ func TestProdCommandFlagParsing(t *testing.T) {

 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
-			// Extract flags manually to verify parsing logic
-			var vpsIP, domain, peersStr, joinAddr, clusterSecret, branch string
+			var vpsIP, domain, joinAddr, token, branch string

 			for i, arg := range tt.args {
 				switch arg {

@ -69,17 +58,13 @@ func TestProdCommandFlagParsing(t *testing.T) {
 				if i+1 < len(tt.args) {
 					domain = tt.args[i+1]
 				}
-			case "--peers":
-				if i+1 < len(tt.args) {
-					peersStr = tt.args[i+1]
-				}
 			case "--join":
 				if i+1 < len(tt.args) {
 					joinAddr = tt.args[i+1]
 				}
-			case "--cluster-secret":
+			case "--token":
 				if i+1 < len(tt.args) {
-					clusterSecret = tt.args[i+1]
+					token = tt.args[i+1]
 				}
 			case "--branch":
 				if i+1 < len(tt.args) {

@ -88,8 +73,8 @@ func TestProdCommandFlagParsing(t *testing.T) {
 				}
 			}

-			// First node detection: no peers and no join address
-			isFirstNode := peersStr == "" && joinAddr == ""
+			// Genesis node detection: no --join and no --token
+			isFirstNode := joinAddr == "" && token == ""

 			if vpsIP != tt.expectVPSIP {
 				t.Errorf("expected vpsIP=%q, got %q", tt.expectVPSIP, vpsIP)

@ -97,14 +82,11 @@ func TestProdCommandFlagParsing(t *testing.T) {
 			if domain != tt.expectDomain {
 				t.Errorf("expected domain=%q, got %q", tt.expectDomain, domain)
 			}
-			if peersStr != tt.expectPeers {
-				t.Errorf("expected peers=%q, got %q", tt.expectPeers, peersStr)
-			}
 			if joinAddr != tt.expectJoin {
 				t.Errorf("expected join=%q, got %q", tt.expectJoin, joinAddr)
 			}
-			if clusterSecret != tt.expectSecret {
-				t.Errorf("expected clusterSecret=%q, got %q", tt.expectSecret, clusterSecret)
+			if token != tt.expectToken {
+				t.Errorf("expected token=%q, got %q", tt.expectToken, token)
 			}
 			if branch != tt.expectBranch {
 				t.Errorf("expected branch=%q, got %q", tt.expectBranch, branch)
@ -5,6 +5,7 @@ import (
 	"os"

 	"github.com/DeBrosOfficial/network/pkg/cli/production/install"
+	"github.com/DeBrosOfficial/network/pkg/cli/production/invite"
 	"github.com/DeBrosOfficial/network/pkg/cli/production/lifecycle"
 	"github.com/DeBrosOfficial/network/pkg/cli/production/logs"
 	"github.com/DeBrosOfficial/network/pkg/cli/production/migrate"

@ -24,6 +25,8 @@ func HandleCommand(args []string) {
 	subargs := args[1:]

 	switch subcommand {
+	case "invite":
+		invite.Handle(subargs)
 	case "install":
 		install.Handle(subargs)
 	case "upgrade":
@ -17,10 +17,11 @@ type Flags struct {
 	DryRun     bool
 	SkipChecks bool
 	Nameserver bool // Make this node a nameserver (runs CoreDNS + Caddy)
-	JoinAddress   string
-	ClusterSecret string
-	SwarmKey      string
-	PeersStr      string
+	JoinAddress   string // HTTPS URL of existing node (e.g., https://node1.dbrs.space)
+	Token         string // Invite token for joining (from orama invite)
+	ClusterSecret string // Deprecated: use --token instead
+	SwarmKey      string // Deprecated: use --token instead
+	PeersStr      string // Deprecated: use --token instead

 	// IPFS/Cluster specific info for Peering configuration
 	IPFSPeerID string

@ -28,6 +29,9 @@ type Flags struct {
 	IPFSClusterPeerID string
 	IPFSClusterAddrs  string

+	// Security flags
+	SkipFirewall bool // Skip UFW firewall setup (for users who manage their own firewall)
+
 	// Anyone relay operator flags
 	AnyoneRelay bool // Run as relay operator instead of client
 	AnyoneExit  bool // Run as exit relay (legal implications)

@ -57,9 +61,10 @@ func ParseFlags(args []string) (*Flags, error) {
 	fs.BoolVar(&flags.Nameserver, "nameserver", false, "Make this node a nameserver (runs CoreDNS + Caddy)")

 	// Cluster join flags
-	fs.StringVar(&flags.JoinAddress, "join", "", "Join an existing cluster (e.g. 1.2.3.4:7001)")
-	fs.StringVar(&flags.ClusterSecret, "cluster-secret", "", "Cluster secret for IPFS Cluster (required if joining)")
-	fs.StringVar(&flags.SwarmKey, "swarm-key", "", "IPFS Swarm key hex (64 chars, last line of swarm.key)")
+	fs.StringVar(&flags.JoinAddress, "join", "", "Join existing cluster via HTTPS URL (e.g. https://node1.dbrs.space)")
+	fs.StringVar(&flags.Token, "token", "", "Invite token for joining (from orama invite on existing node)")
+	fs.StringVar(&flags.ClusterSecret, "cluster-secret", "", "Deprecated: use --token instead")
+	fs.StringVar(&flags.SwarmKey, "swarm-key", "", "Deprecated: use --token instead")
 	fs.StringVar(&flags.PeersStr, "peers", "", "Comma-separated list of bootstrap peer multiaddrs")

 	// IPFS/Cluster specific info for Peering configuration

@ -68,6 +73,9 @@ func ParseFlags(args []string) (*Flags, error) {
 	fs.StringVar(&flags.IPFSClusterPeerID, "ipfs-cluster-peer", "", "Peer ID of existing IPFS Cluster node")
 	fs.StringVar(&flags.IPFSClusterAddrs, "ipfs-cluster-addrs", "", "Comma-separated multiaddrs of existing IPFS Cluster node")

+	// Security flags
+	fs.BoolVar(&flags.SkipFirewall, "skip-firewall", false, "Skip UFW firewall setup (for users who manage their own firewall)")
+
 	// Anyone relay operator flags
 	fs.BoolVar(&flags.AnyoneRelay, "anyone-relay", false, "Run as Anyone relay operator (earn rewards)")
 	fs.BoolVar(&flags.AnyoneExit, "anyone-exit", false, "Run as exit relay (requires --anyone-relay, legal implications)")
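A quick usage sketch of the new flags for a joining node (assumption: `ParseFlags` lives in the `install` package alongside the orchestrator below):

```go
package main

import (
	"fmt"
	"log"

	"github.com/DeBrosOfficial/network/pkg/cli/production/install"
)

func main() {
	flags, err := install.ParseFlags([]string{
		"--vps-ip", "5.6.7.8",
		"--join", "https://node1.dbrs.space",
		"--token", "abc123def456",
	})
	if err != nil {
		log.Fatal(err)
	}
	// A node is "joining" exactly when both --join and --token are set,
	// mirroring isJoiningNode() in the install orchestrator.
	fmt.Println(flags.JoinAddress != "" && flags.Token != "") // true
}
```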
@ -2,14 +2,20 @@ package install

 import (
 	"bufio"
+	"crypto/tls"
+	"encoding/json"
 	"fmt"
+	"io"
+	"net/http"
 	"os"
 	"os/exec"
+	"path/filepath"
 	"strings"
 	"time"

 	"github.com/DeBrosOfficial/network/pkg/cli/utils"
 	"github.com/DeBrosOfficial/network/pkg/environments/production"
+	joinhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/join"
 )

 // Orchestrator manages the install process

@ -97,9 +103,11 @@ func (o *Orchestrator) Execute() error {
 		return nil
 	}

-	// Save secrets before installation
-	if err := o.validator.SaveSecrets(); err != nil {
-		return err
+	// Save secrets before installation (only for genesis; join flow gets secrets from response)
+	if !o.isJoiningNode() {
+		if err := o.validator.SaveSecrets(); err != nil {
+			return err
+		}
 	}

 	// Save preferences for future upgrades (branch + nameserver)

@ -132,33 +140,56 @@ func (o *Orchestrator) Execute() error {
 		return fmt.Errorf("binary installation failed: %w", err)
 	}

-	// Phase 3: Generate secrets FIRST (before service initialization)
+	// Branch: genesis node vs joining node
+	if o.isJoiningNode() {
+		return o.executeJoinFlow()
+	}
+	return o.executeGenesisFlow()
+}
+
+// isJoiningNode returns true if --join and --token are both set
+func (o *Orchestrator) isJoiningNode() bool {
+	return o.flags.JoinAddress != "" && o.flags.Token != ""
+}
+
+// executeGenesisFlow runs the install for the first node in a new cluster
+func (o *Orchestrator) executeGenesisFlow() error {
+	// Phase 3: Generate secrets locally
 	fmt.Printf("\n🔐 Phase 3: Generating secrets...\n")
 	if err := o.setup.Phase3GenerateSecrets(); err != nil {
 		return fmt.Errorf("secret generation failed: %w", err)
 	}

-	// Phase 4: Generate configs (BEFORE service initialization)
+	// Phase 6a: WireGuard — self-assign 10.0.0.1
+	fmt.Printf("\n🔒 Phase 6a: Setting up WireGuard mesh VPN...\n")
+	if _, _, err := o.setup.Phase6SetupWireGuard(true); err != nil {
+		fmt.Fprintf(os.Stderr, "   ⚠️  Warning: WireGuard setup failed: %v\n", err)
+	} else {
+		fmt.Printf("   ✓ WireGuard configured (10.0.0.1)\n")
+	}
+
+	// Phase 6b: UFW firewall
+	fmt.Printf("\n🛡️  Phase 6b: Setting up UFW firewall...\n")
+	if err := o.setup.Phase6bSetupFirewall(o.flags.SkipFirewall); err != nil {
+		fmt.Fprintf(os.Stderr, "   ⚠️  Warning: Firewall setup failed: %v\n", err)
+	}
+
+	// Phase 4: Generate configs using WG IP (10.0.0.1) as advertise address
+	// All inter-node communication uses WireGuard IPs, not public IPs
 	fmt.Printf("\n⚙️  Phase 4: Generating configurations...\n")
 	// Internal gateway always runs HTTP on port 6001
 	// When using Caddy (nameserver mode), Caddy handles external HTTPS and proxies to internal gateway
 	// When not using Caddy, the gateway runs HTTP-only (use a reverse proxy for HTTPS)
 	enableHTTPS := false
-	if err := o.setup.Phase4GenerateConfigs(o.peers, o.flags.VpsIP, enableHTTPS, o.flags.Domain, o.flags.BaseDomain, o.flags.JoinAddress); err != nil {
+	genesisWGIP := "10.0.0.1"
+	if err := o.setup.Phase4GenerateConfigs(o.peers, genesisWGIP, enableHTTPS, o.flags.Domain, o.flags.BaseDomain, ""); err != nil {
 		return fmt.Errorf("configuration generation failed: %w", err)
 	}

 	// Validate generated configuration
 	if err := o.validator.ValidateGeneratedConfig(); err != nil {
 		return err
 	}

-	// Phase 2c: Initialize services (after config is in place)
+	// Phase 2c: Initialize services (use WG IP for IPFS Cluster peer discovery)
 	fmt.Printf("\nPhase 2c: Initializing services...\n")
-	ipfsPeerInfo := o.buildIPFSPeerInfo()
-	ipfsClusterPeerInfo := o.buildIPFSClusterPeerInfo()
-
-	if err := o.setup.Phase2cInitializeServices(o.peers, o.flags.VpsIP, ipfsPeerInfo, ipfsClusterPeerInfo); err != nil {
+	if err := o.setup.Phase2cInitializeServices(o.peers, genesisWGIP, nil, nil); err != nil {
 		return fmt.Errorf("service initialization failed: %w", err)
 	}

@ -168,9 +199,9 @@ func (o *Orchestrator) Execute() error {
 		return fmt.Errorf("service creation failed: %w", err)
 	}

-	// Seed DNS records after services are running (RQLite must be up)
+	// Phase 7: Seed DNS records
 	if o.flags.Nameserver && o.flags.BaseDomain != "" {
-		fmt.Printf("\n🌐 Phase 6: Seeding DNS records...\n")
+		fmt.Printf("\n🌐 Phase 7: Seeding DNS records...\n")
 		fmt.Printf("   Waiting for RQLite to start (10s)...\n")
 		time.Sleep(10 * time.Second)
 		if err := o.setup.SeedDNSRecords(o.flags.BaseDomain, o.flags.VpsIP, o.peers); err != nil {

@ -180,18 +211,206 @@ func (o *Orchestrator) Execute() error {
 		}
 	}

 	// Log completion with actual peer ID
 	o.setup.LogSetupComplete(o.setup.NodePeerID)
 	fmt.Printf("✅ Production installation complete!\n\n")
-	// For first node, print important secrets and identifiers
-	if o.validator.IsFirstNode() {
-		o.printFirstNodeSecrets()
-	}
+	o.printFirstNodeSecrets()
 	return nil
 }

+// executeJoinFlow runs the install for a node joining an existing cluster via invite token
+func (o *Orchestrator) executeJoinFlow() error {
+	// Step 1: Generate WG keypair
+	fmt.Printf("\n🔑 Generating WireGuard keypair...\n")
+	privKey, pubKey, err := production.GenerateKeyPair()
+	if err != nil {
+		return fmt.Errorf("failed to generate WG keypair: %w", err)
+	}
+	fmt.Printf("   ✓ WireGuard keypair generated\n")
+
+	// Step 2: Call join endpoint on existing node
+	fmt.Printf("\n🤝 Requesting cluster join from %s...\n", o.flags.JoinAddress)
+	joinResp, err := o.callJoinEndpoint(pubKey)
+	if err != nil {
+		return fmt.Errorf("join request failed: %w", err)
+	}
+	fmt.Printf("   ✓ Join approved — assigned WG IP: %s\n", joinResp.WGIP)
+	fmt.Printf("   ✓ Received %d WG peers\n", len(joinResp.WGPeers))
+
+	// Step 3: Configure WireGuard with assigned IP and peers
+	fmt.Printf("\n🔒 Configuring WireGuard tunnel...\n")
+	var wgPeers []production.WireGuardPeer
+	for _, p := range joinResp.WGPeers {
+		wgPeers = append(wgPeers, production.WireGuardPeer{
+			PublicKey: p.PublicKey,
+			Endpoint:  p.Endpoint,
+			AllowedIP: p.AllowedIP,
+		})
+	}
+	// Install WG package first
+	wp := production.NewWireGuardProvisioner(production.WireGuardConfig{})
+	if err := wp.Install(); err != nil {
+		return fmt.Errorf("failed to install wireguard: %w", err)
+	}
+	if err := o.setup.EnableWireGuardWithPeers(privKey, joinResp.WGIP, wgPeers); err != nil {
+		return fmt.Errorf("failed to enable WireGuard: %w", err)
+	}
+
+	// Step 4: Verify WG tunnel
+	fmt.Printf("\n🔍 Verifying WireGuard tunnel...\n")
+	if err := o.verifyWGTunnel(joinResp.WGPeers); err != nil {
+		return fmt.Errorf("WireGuard tunnel verification failed: %w", err)
+	}
+	fmt.Printf("   ✓ WireGuard tunnel established\n")
+
+	// Step 5: UFW firewall
+	fmt.Printf("\n🛡️  Setting up UFW firewall...\n")
+	if err := o.setup.Phase6bSetupFirewall(o.flags.SkipFirewall); err != nil {
+		fmt.Fprintf(os.Stderr, "   ⚠️  Warning: Firewall setup failed: %v\n", err)
+	}
+
+	// Step 6: Save secrets from join response
+	fmt.Printf("\n🔐 Saving cluster secrets...\n")
+	if err := o.saveSecretsFromJoinResponse(joinResp); err != nil {
+		return fmt.Errorf("failed to save secrets: %w", err)
+	}
+	fmt.Printf("   ✓ Secrets saved\n")
+
+	// Step 7: Generate configs using WG IP as advertise address
+	// All inter-node communication uses WireGuard IPs, not public IPs
+	fmt.Printf("\n⚙️  Generating configurations...\n")
+	enableHTTPS := false
+	rqliteJoin := joinResp.RQLiteJoinAddress
+	if err := o.setup.Phase4GenerateConfigs(joinResp.BootstrapPeers, joinResp.WGIP, enableHTTPS, o.flags.Domain, joinResp.BaseDomain, rqliteJoin); err != nil {
+		return fmt.Errorf("configuration generation failed: %w", err)
+	}
+
+	if err := o.validator.ValidateGeneratedConfig(); err != nil {
+		return err
+	}
+
+	// Step 8: Initialize services with IPFS peer info from join response
+	fmt.Printf("\nInitializing services...\n")
+	var ipfsPeerInfo *production.IPFSPeerInfo
+	if joinResp.IPFSPeer.ID != "" {
+		ipfsPeerInfo = &production.IPFSPeerInfo{
+			PeerID: joinResp.IPFSPeer.ID,
+			Addrs:  joinResp.IPFSPeer.Addrs,
+		}
+	}
+	var ipfsClusterPeerInfo *production.IPFSClusterPeerInfo
+	if joinResp.IPFSClusterPeer.ID != "" {
+		ipfsClusterPeerInfo = &production.IPFSClusterPeerInfo{
+			PeerID: joinResp.IPFSClusterPeer.ID,
+			Addrs:  joinResp.IPFSClusterPeer.Addrs,
+		}
+	}
+
+	if err := o.setup.Phase2cInitializeServices(joinResp.BootstrapPeers, joinResp.WGIP, ipfsPeerInfo, ipfsClusterPeerInfo); err != nil {
+		return fmt.Errorf("service initialization failed: %w", err)
+	}
+
+	// Step 9: Create systemd services
+	fmt.Printf("\n🔧 Creating systemd services...\n")
+	if err := o.setup.Phase5CreateSystemdServices(enableHTTPS); err != nil {
+		return fmt.Errorf("service creation failed: %w", err)
+	}
+
+	o.setup.LogSetupComplete(o.setup.NodePeerID)
+	fmt.Printf("✅ Production installation complete! Joined cluster via %s\n\n", o.flags.JoinAddress)
+	return nil
+}
+
+// callJoinEndpoint sends the join request to the existing node's HTTPS endpoint
+func (o *Orchestrator) callJoinEndpoint(wgPubKey string) (*joinhandlers.JoinResponse, error) {
+	reqBody := joinhandlers.JoinRequest{
+		Token:       o.flags.Token,
+		WGPublicKey: wgPubKey,
+		PublicIP:    o.flags.VpsIP,
+	}
+	bodyBytes, err := json.Marshal(reqBody)
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal request: %w", err)
+	}
+
+	url := strings.TrimRight(o.flags.JoinAddress, "/") + "/v1/internal/join"
+	client := &http.Client{
+		Timeout: 30 * time.Second,
+		Transport: &http.Transport{
+			TLSClientConfig: &tls.Config{
+				InsecureSkipVerify: true, // Self-signed certs during initial setup
+			},
+		},
+	}
+
+	resp, err := client.Post(url, "application/json", strings.NewReader(string(bodyBytes)))
+	if err != nil {
+		return nil, fmt.Errorf("failed to contact %s: %w", url, err)
+	}
+	defer resp.Body.Close()
+
+	respBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, fmt.Errorf("failed to read response: %w", err)
+	}
+
+	if resp.StatusCode != http.StatusOK {
+		return nil, fmt.Errorf("join rejected (HTTP %d): %s", resp.StatusCode, string(respBody))
+	}
+
+	var joinResp joinhandlers.JoinResponse
+	if err := json.Unmarshal(respBody, &joinResp); err != nil {
+		return nil, fmt.Errorf("failed to parse join response: %w", err)
+	}
+
+	return &joinResp, nil
+}
+
+// saveSecretsFromJoinResponse writes cluster secrets received from the join endpoint to disk
+func (o *Orchestrator) saveSecretsFromJoinResponse(resp *joinhandlers.JoinResponse) error {
+	secretsDir := filepath.Join(o.oramaDir, "secrets")
+	if err := os.MkdirAll(secretsDir, 0700); err != nil {
+		return fmt.Errorf("failed to create secrets dir: %w", err)
+	}
+
+	// Write cluster secret
+	if resp.ClusterSecret != "" {
+		if err := os.WriteFile(filepath.Join(secretsDir, "cluster-secret"), []byte(resp.ClusterSecret), 0600); err != nil {
+			return fmt.Errorf("failed to write cluster-secret: %w", err)
+		}
+	}

+	// Write swarm key
+	if resp.SwarmKey != "" {
+		if err := os.WriteFile(filepath.Join(secretsDir, "swarm.key"), []byte(resp.SwarmKey), 0600); err != nil {
+			return fmt.Errorf("failed to write swarm.key: %w", err)
+		}
+	}
+
+	return nil
+}
+
+// verifyWGTunnel pings a WG peer to verify the tunnel is working
+func (o *Orchestrator) verifyWGTunnel(peers []joinhandlers.WGPeerInfo) error {
+	if len(peers) == 0 {
+		return fmt.Errorf("no WG peers to verify")
+	}
+
+	// Extract the IP from the first peer's AllowedIP (e.g. "10.0.0.1/32" -> "10.0.0.1")
+	targetIP := strings.TrimSuffix(peers[0].AllowedIP, "/32")
+
+	// Retry ping for up to 30 seconds
+	deadline := time.Now().Add(30 * time.Second)
+	for time.Now().Before(deadline) {
+		cmd := exec.Command("ping", "-c", "1", "-W", "2", targetIP)
+		if err := cmd.Run(); err == nil {
+			return nil
+		}
+		time.Sleep(2 * time.Second)
+	}
+
+	return fmt.Errorf("could not reach %s via WireGuard after 30s", targetIP)
+}

 func (o *Orchestrator) buildIPFSPeerInfo() *production.IPFSPeerInfo {
 	if o.flags.IPFSPeerID != "" {
 		var addrs []string
pkg/cli/production/invite/command.go (new file, 115 lines)
@ -0,0 +1,115 @@
package invite

import (
	"crypto/rand"
	"encoding/hex"
	"fmt"
	"net/http"
	"os"
	"strings"
	"time"

	"gopkg.in/yaml.v3"
)

// Handle processes the invite command
func Handle(args []string) {
	// Must run on a cluster node with RQLite running locally
	domain, err := readNodeDomain()
	if err != nil {
		fmt.Fprintf(os.Stderr, "Error: could not read node config: %v\n", err)
		fmt.Fprintf(os.Stderr, "Make sure you're running this on an installed node.\n")
		os.Exit(1)
	}

	// Generate random token
	tokenBytes := make([]byte, 32)
	if _, err := rand.Read(tokenBytes); err != nil {
		fmt.Fprintf(os.Stderr, "Error generating token: %v\n", err)
		os.Exit(1)
	}
	token := hex.EncodeToString(tokenBytes)

	// Determine expiry (default 1 hour, --expiry flag for override)
	expiry := time.Hour
	for i, arg := range args {
		if arg == "--expiry" && i+1 < len(args) {
			d, err := time.ParseDuration(args[i+1])
			if err != nil {
				fmt.Fprintf(os.Stderr, "Invalid expiry duration: %v\n", err)
				os.Exit(1)
			}
			expiry = d
		}
	}

	expiresAt := time.Now().UTC().Add(expiry).Format("2006-01-02 15:04:05")

	// Get node ID for created_by
	nodeID := "unknown"
	if hostname, err := os.Hostname(); err == nil {
		nodeID = hostname
	}

	// Insert token into RQLite via HTTP API
	if err := insertToken(token, nodeID, expiresAt); err != nil {
		fmt.Fprintf(os.Stderr, "Error storing invite token: %v\n", err)
		fmt.Fprintf(os.Stderr, "Make sure RQLite is running on this node.\n")
		os.Exit(1)
	}

	// Print the invite command
	fmt.Printf("\nInvite token created (expires in %s)\n\n", expiry)
	fmt.Printf("Run this on the new node:\n\n")
	fmt.Printf("  sudo orama install --join https://%s --token %s --vps-ip <NEW_NODE_IP> --nameserver\n\n", domain, token)
	fmt.Printf("Replace <NEW_NODE_IP> with the new node's public IP address.\n")
}

// readNodeDomain reads the domain from the node config file
func readNodeDomain() (string, error) {
	configPath := "/home/debros/.orama/configs/node.yaml"
	data, err := os.ReadFile(configPath)
	if err != nil {
		return "", fmt.Errorf("read config: %w", err)
	}

	var config struct {
		Node struct {
			Domain string `yaml:"domain"`
		} `yaml:"node"`
	}
	if err := yaml.Unmarshal(data, &config); err != nil {
		return "", fmt.Errorf("parse config: %w", err)
	}

	if config.Node.Domain == "" {
		return "", fmt.Errorf("node domain not set in config")
	}

	return config.Node.Domain, nil
}

// insertToken inserts an invite token into RQLite via HTTP API
func insertToken(token, createdBy, expiresAt string) error {
	body := fmt.Sprintf(`[["INSERT INTO invite_tokens (token, created_by, expires_at) VALUES ('%s', '%s', '%s')"]]`,
		token, createdBy, expiresAt)

	req, err := http.NewRequest("POST", "http://localhost:5001/db/execute", strings.NewReader(body))
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/json")

	client := &http.Client{Timeout: 5 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return fmt.Errorf("failed to connect to RQLite: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("RQLite returned status %d", resp.StatusCode)
	}

	return nil
}
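One hedged note on `insertToken` above: rqlite's `/db/execute` endpoint also accepts parameterized statements as `[[sql, params...]]`, which avoids interpolating values into the SQL string. A minimal sketch of the same insert (extra `bytes` and `encoding/json` imports assumed):

```go
// Sketch only: the parameterized form of the insert above.
func insertTokenParameterized(token, createdBy, expiresAt string) error {
	stmt := []interface{}{
		"INSERT INTO invite_tokens (token, created_by, expires_at) VALUES (?, ?, ?)",
		token, createdBy, expiresAt,
	}
	body, err := json.Marshal([]interface{}{stmt})
	if err != nil {
		return err
	}
	resp, err := http.Post("http://localhost:5001/db/execute",
		"application/json", bytes.NewReader(body))
	if err != nil {
		return fmt.Errorf("failed to connect to RQLite: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("RQLite returned status %d", resp.StatusCode)
	}
	return nil
}
```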
@ -147,6 +147,12 @@ func (o *Orchestrator) Execute() error {
 		fmt.Fprintf(os.Stderr, "⚠️  Service update warning: %v\n", err)
 	}

+	// Re-apply UFW firewall rules (idempotent)
+	fmt.Printf("\n🛡️  Re-applying firewall rules...\n")
+	if err := o.setup.Phase6bSetupFirewall(false); err != nil {
+		fmt.Fprintf(os.Stderr, "   ⚠️  Warning: Firewall re-apply failed: %v\n", err)
+	}
+
 	fmt.Printf("\n✅ Upgrade complete!\n")

 	// Restart services if requested

@ -317,12 +323,24 @@ func (o *Orchestrator) extractGatewayConfig() (enableHTTPS bool, domain string,
 		}
 	}

-	// Also check node.yaml for base_domain
+	// Also check node.yaml for domain and base_domain
 	nodeConfigPath := filepath.Join(o.oramaDir, "configs", "node.yaml")
 	if data, err := os.ReadFile(nodeConfigPath); err == nil {
 		configStr := string(data)
 		for _, line := range strings.Split(configStr, "\n") {
 			trimmed := strings.TrimSpace(line)
+			// Extract domain from node.yaml (under node: section) if not already found
+			if domain == "" && strings.HasPrefix(trimmed, "domain:") && !strings.HasPrefix(trimmed, "domain_") {
+				parts := strings.SplitN(trimmed, ":", 2)
+				if len(parts) > 1 {
+					d := strings.TrimSpace(parts[1])
+					d = strings.Trim(d, "\"'")
+					if d != "" && d != "null" {
+						domain = d
+						enableHTTPS = true
+					}
+				}
+			}
 			if strings.HasPrefix(trimmed, "base_domain:") {
 				parts := strings.SplitN(trimmed, ":", 2)
 				if len(parts) > 1 {

@ -332,7 +350,6 @@ func (o *Orchestrator) extractGatewayConfig() (enableHTTPS bool, domain string,
 					baseDomain = ""
 				}
 			}
-			break
 		}
 	}
 }
@ -259,7 +259,7 @@ func (rm *ReplicaManager) GetNodeIP(ctx context.Context, nodeID string) (string,
 	}

 	var rows []nodeRow
-	query := `SELECT ip_address FROM dns_nodes WHERE id = ? LIMIT 1`
+	query := `SELECT COALESCE(internal_ip, ip_address) AS ip_address FROM dns_nodes WHERE id = ? LIMIT 1`
 	err := rm.db.Query(internalCtx, &rows, query, nodeID)
 	if err != nil {
 		return "", err
pkg/environments/production/firewall.go (new file, 133 lines)
@ -0,0 +1,133 @@
package production

import (
	"fmt"
	"os/exec"
	"strings"
)

// FirewallConfig holds the configuration for UFW firewall rules
type FirewallConfig struct {
	SSHPort       int  // default 22
	IsNameserver  bool // enables port 53 TCP+UDP
	AnyoneORPort  int  // 0 = disabled, typically 9001
	WireGuardPort int  // default 51820
}

// FirewallProvisioner manages UFW firewall setup
type FirewallProvisioner struct {
	config FirewallConfig
}

// NewFirewallProvisioner creates a new firewall provisioner
func NewFirewallProvisioner(config FirewallConfig) *FirewallProvisioner {
	if config.SSHPort == 0 {
		config.SSHPort = 22
	}
	if config.WireGuardPort == 0 {
		config.WireGuardPort = 51820
	}
	return &FirewallProvisioner{
		config: config,
	}
}

// IsInstalled checks if UFW is available
func (fp *FirewallProvisioner) IsInstalled() bool {
	_, err := exec.LookPath("ufw")
	return err == nil
}

// Install installs UFW if not present
func (fp *FirewallProvisioner) Install() error {
	if fp.IsInstalled() {
		return nil
	}

	cmd := exec.Command("apt-get", "install", "-y", "ufw")
	if output, err := cmd.CombinedOutput(); err != nil {
		return fmt.Errorf("failed to install ufw: %w\n%s", err, string(output))
	}

	return nil
}

// GenerateRules returns the list of UFW commands to apply
func (fp *FirewallProvisioner) GenerateRules() []string {
	rules := []string{
		// Reset to clean state
		"ufw --force reset",

		// Default policies
		"ufw default deny incoming",
		"ufw default allow outgoing",

		// SSH (always required)
		fmt.Sprintf("ufw allow %d/tcp", fp.config.SSHPort),

		// WireGuard (always required for mesh)
		fmt.Sprintf("ufw allow %d/udp", fp.config.WireGuardPort),

		// Public web services
		"ufw allow 80/tcp",  // ACME / HTTP redirect
		"ufw allow 443/tcp", // HTTPS (Caddy → Gateway)
	}

	// DNS (only for nameserver nodes)
	if fp.config.IsNameserver {
		rules = append(rules, "ufw allow 53/tcp")
		rules = append(rules, "ufw allow 53/udp")
	}

	// Anyone relay ORPort
	if fp.config.AnyoneORPort > 0 {
		rules = append(rules, fmt.Sprintf("ufw allow %d/tcp", fp.config.AnyoneORPort))
	}

	// Allow all traffic from WireGuard subnet (inter-node encrypted traffic)
	rules = append(rules, "ufw allow from 10.0.0.0/8")

	// Enable firewall
	rules = append(rules, "ufw --force enable")

	return rules
}

// Setup applies all firewall rules. Idempotent — safe to call multiple times.
func (fp *FirewallProvisioner) Setup() error {
	if err := fp.Install(); err != nil {
		return err
	}

	rules := fp.GenerateRules()

	for _, rule := range rules {
		parts := strings.Fields(rule)
		cmd := exec.Command(parts[0], parts[1:]...)
		if output, err := cmd.CombinedOutput(); err != nil {
			return fmt.Errorf("failed to apply firewall rule '%s': %w\n%s", rule, err, string(output))
		}
	}

	return nil
}

// IsActive checks if UFW is active
func (fp *FirewallProvisioner) IsActive() bool {
	cmd := exec.Command("ufw", "status")
	output, err := cmd.CombinedOutput()
	if err != nil {
		return false
	}
	return strings.Contains(string(output), "Status: active")
}

// GetStatus returns the current UFW status
func (fp *FirewallProvisioner) GetStatus() (string, error) {
	cmd := exec.Command("ufw", "status", "verbose")
	output, err := cmd.CombinedOutput()
	if err != nil {
		return "", fmt.Errorf("failed to get ufw status: %w\n%s", err, string(output))
	}
	return string(output), nil
}
pkg/environments/production/firewall_test.go (new file, 117 lines)
@ -0,0 +1,117 @@
package production

import (
	"strings"
	"testing"
)

func TestFirewallProvisioner_GenerateRules_StandardNode(t *testing.T) {
	fp := NewFirewallProvisioner(FirewallConfig{})

	rules := fp.GenerateRules()

	// Should contain defaults
	assertContainsRule(t, rules, "ufw --force reset")
	assertContainsRule(t, rules, "ufw default deny incoming")
	assertContainsRule(t, rules, "ufw default allow outgoing")
	assertContainsRule(t, rules, "ufw allow 22/tcp")
	assertContainsRule(t, rules, "ufw allow 51820/udp")
	assertContainsRule(t, rules, "ufw allow 80/tcp")
	assertContainsRule(t, rules, "ufw allow 443/tcp")
	assertContainsRule(t, rules, "ufw allow from 10.0.0.0/8")
	assertContainsRule(t, rules, "ufw --force enable")

	// Should NOT contain DNS or Anyone relay
	for _, rule := range rules {
		if strings.Contains(rule, "53/") {
			t.Errorf("standard node should not have DNS rule: %s", rule)
		}
		if strings.Contains(rule, "9001") {
			t.Errorf("standard node should not have Anyone relay rule: %s", rule)
		}
	}
}

func TestFirewallProvisioner_GenerateRules_Nameserver(t *testing.T) {
	fp := NewFirewallProvisioner(FirewallConfig{
		IsNameserver: true,
	})

	rules := fp.GenerateRules()

	assertContainsRule(t, rules, "ufw allow 53/tcp")
	assertContainsRule(t, rules, "ufw allow 53/udp")
}

func TestFirewallProvisioner_GenerateRules_WithAnyoneRelay(t *testing.T) {
	fp := NewFirewallProvisioner(FirewallConfig{
		AnyoneORPort: 9001,
	})

	rules := fp.GenerateRules()

	assertContainsRule(t, rules, "ufw allow 9001/tcp")
}

func TestFirewallProvisioner_GenerateRules_CustomSSHPort(t *testing.T) {
	fp := NewFirewallProvisioner(FirewallConfig{
		SSHPort: 2222,
	})

	rules := fp.GenerateRules()

	assertContainsRule(t, rules, "ufw allow 2222/tcp")

	// Should NOT have default port 22
	for _, rule := range rules {
		if rule == "ufw allow 22/tcp" {
			t.Error("should not have default SSH port 22 when custom port is set")
		}
	}
}

func TestFirewallProvisioner_GenerateRules_WireGuardSubnetAllowed(t *testing.T) {
	fp := NewFirewallProvisioner(FirewallConfig{})

	rules := fp.GenerateRules()

	assertContainsRule(t, rules, "ufw allow from 10.0.0.0/8")
}

func TestFirewallProvisioner_GenerateRules_FullConfig(t *testing.T) {
	fp := NewFirewallProvisioner(FirewallConfig{
		SSHPort:       2222,
		IsNameserver:  true,
		AnyoneORPort:  9001,
		WireGuardPort: 51821,
	})

	rules := fp.GenerateRules()

	assertContainsRule(t, rules, "ufw allow 2222/tcp")
	assertContainsRule(t, rules, "ufw allow 51821/udp")
	assertContainsRule(t, rules, "ufw allow 53/tcp")
	assertContainsRule(t, rules, "ufw allow 53/udp")
	assertContainsRule(t, rules, "ufw allow 9001/tcp")
}

func TestFirewallProvisioner_DefaultPorts(t *testing.T) {
	fp := NewFirewallProvisioner(FirewallConfig{})

	if fp.config.SSHPort != 22 {
		t.Errorf("default SSHPort = %d, want 22", fp.config.SSHPort)
	}
	if fp.config.WireGuardPort != 51820 {
		t.Errorf("default WireGuardPort = %d, want 51820", fp.config.WireGuardPort)
	}
}

func assertContainsRule(t *testing.T, rules []string, expected string) {
	t.Helper()
	for _, rule := range rules {
		if rule == expected {
			return
		}
	}
	t.Errorf("rules should contain '%s', got: %v", expected, rules)
}
@ -254,6 +254,13 @@ func (ps *ProductionSetup) Phase2ProvisionEnvironment() error {
 		ps.logf("   ✓ Deployment sudoers configured")
 	}

+	// Set up WireGuard sudoers (allows debros user to manage WG peers)
+	if err := ps.userProvisioner.SetupWireGuardSudoers(); err != nil {
+		ps.logf("   ⚠️  Failed to setup wireguard sudoers: %v", err)
+	} else {
+		ps.logf("   ✓ WireGuard sudoers configured")
+	}
+
 	// Create directory structure (unified structure)
 	if err := ps.fsProvisioner.EnsureDirectoryStructure(); err != nil {
 		return fmt.Errorf("failed to create directory structure: %w", err)
@ -724,6 +731,25 @@ func (ps *ProductionSetup) Phase5CreateSystemdServices(enableHTTPS bool) error {
 		ps.logf("   - debros-node.service started (with embedded gateway)")
 	}

+	// Start CoreDNS and Caddy (nameserver nodes only)
+	// Caddy depends on debros-node.service (gateway on :6001), so start after node
+	if ps.isNameserver {
+		if _, err := os.Stat("/usr/local/bin/coredns"); err == nil {
+			if err := ps.serviceController.StartService("coredns.service"); err != nil {
+				ps.logf("   ⚠️  Failed to start coredns.service: %v", err)
+			} else {
+				ps.logf("   - coredns.service started")
+			}
+		}
+		if _, err := os.Stat("/usr/bin/caddy"); err == nil {
+			if err := ps.serviceController.StartService("caddy.service"); err != nil {
+				ps.logf("   ⚠️  Failed to start caddy.service: %v", err)
+			} else {
+				ps.logf("   - caddy.service started")
+			}
+		}
+	}
+
 	ps.logf("   ✓ All services started")
 	return nil
 }
@ -775,6 +801,96 @@ func (ps *ProductionSetup) SeedDNSRecords(baseDomain, vpsIP string, peerAddresse
 	return nil
 }

+// Phase6SetupWireGuard installs WireGuard and generates keys for this node.
+// For the first node, it self-assigns 10.0.0.1. For joining nodes, the peer
+// exchange happens via HTTPS in the install CLI orchestrator.
+func (ps *ProductionSetup) Phase6SetupWireGuard(isFirstNode bool) (privateKey, publicKey string, err error) {
+	ps.logf("Phase 6a: Setting up WireGuard...")
+
+	wp := NewWireGuardProvisioner(WireGuardConfig{})
+
+	// Install WireGuard package
+	if err := wp.Install(); err != nil {
+		return "", "", fmt.Errorf("failed to install wireguard: %w", err)
+	}
+	ps.logf("   ✓ WireGuard installed")
+
+	// Generate keypair
+	privKey, pubKey, err := GenerateKeyPair()
+	if err != nil {
+		return "", "", fmt.Errorf("failed to generate WG keys: %w", err)
+	}
+	ps.logf("   ✓ WireGuard keypair generated")
+
+	if isFirstNode {
+		// First node: self-assign 10.0.0.1, no peers yet
+		wp.config = WireGuardConfig{
+			PrivateKey: privKey,
+			PrivateIP:  "10.0.0.1",
+			ListenPort: 51820,
+		}
+		if err := wp.WriteConfig(); err != nil {
+			return "", "", fmt.Errorf("failed to write WG config: %w", err)
+		}
+		if err := wp.Enable(); err != nil {
+			return "", "", fmt.Errorf("failed to enable WG: %w", err)
+		}
+		ps.logf("   ✓ WireGuard enabled (first node: 10.0.0.1)")
+	}
+
+	return privKey, pubKey, nil
+}
+
+// Phase6bSetupFirewall sets up UFW firewall rules
+func (ps *ProductionSetup) Phase6bSetupFirewall(skipFirewall bool) error {
+	if skipFirewall {
+		ps.logf("Phase 6b: Skipping firewall setup (--skip-firewall)")
+		return nil
+	}
+
+	ps.logf("Phase 6b: Setting up UFW firewall...")
+
+	anyoneORPort := 0
+	if ps.IsAnyoneRelay() && ps.anyoneRelayConfig != nil {
+		anyoneORPort = ps.anyoneRelayConfig.ORPort
+	}
+
+	fp := NewFirewallProvisioner(FirewallConfig{
+		SSHPort:       22,
+		IsNameserver:  ps.isNameserver,
+		AnyoneORPort:  anyoneORPort,
+		WireGuardPort: 51820,
+	})
+
+	if err := fp.Setup(); err != nil {
+		return fmt.Errorf("firewall setup failed: %w", err)
+	}
+
+	ps.logf("   ✓ UFW firewall configured and enabled")
+	return nil
+}
+
+// EnableWireGuardWithPeers writes WG config with assigned IP and peers, then enables it.
+// Called by joining nodes after peer exchange.
+func (ps *ProductionSetup) EnableWireGuardWithPeers(privateKey, assignedIP string, peers []WireGuardPeer) error {
+	wp := NewWireGuardProvisioner(WireGuardConfig{
+		PrivateKey: privateKey,
+		PrivateIP:  assignedIP,
+		ListenPort: 51820,
+		Peers:      peers,
+	})
+
+	if err := wp.WriteConfig(); err != nil {
+		return fmt.Errorf("failed to write WG config: %w", err)
+	}
+	if err := wp.Enable(); err != nil {
+		return fmt.Errorf("failed to enable WG: %w", err)
+	}
+
+	ps.logf("   ✓ WireGuard enabled (IP: %s, peers: %d)", assignedIP, len(peers))
+	return nil
+}
+
 // LogSetupComplete logs completion information
 func (ps *ProductionSetup) LogSetupComplete(peerID string) {
 	ps.logf("\n" + strings.Repeat("=", 70))
@ -224,6 +224,34 @@ debros ALL=(ALL) NOPASSWD: /bin/rm -f /etc/systemd/system/orama-deploy-*.service
 	return nil
 }

+// SetupWireGuardSudoers configures the debros user with permissions to manage WireGuard
+func (up *UserProvisioner) SetupWireGuardSudoers() error {
+	sudoersFile := "/etc/sudoers.d/debros-wireguard"
+
+	sudoersContent := `# DeBros Network - WireGuard Management Permissions
+# Allows debros user to manage WireGuard peers
+
+debros ALL=(ALL) NOPASSWD: /usr/bin/wg set wg0 *
+debros ALL=(ALL) NOPASSWD: /usr/bin/wg show wg0
+debros ALL=(ALL) NOPASSWD: /usr/bin/wg showconf wg0
+debros ALL=(ALL) NOPASSWD: /usr/bin/tee /etc/wireguard/wg0.conf
+`
+
+	// Write sudoers rule (always overwrite to ensure latest)
+	if err := os.WriteFile(sudoersFile, []byte(sudoersContent), 0440); err != nil {
+		return fmt.Errorf("failed to create wireguard sudoers rule: %w", err)
+	}
+
+	// Validate sudoers file
+	cmd := exec.Command("visudo", "-c", "-f", sudoersFile)
+	if err := cmd.Run(); err != nil {
+		os.Remove(sudoersFile)
+		return fmt.Errorf("wireguard sudoers rule validation failed: %w", err)
+	}
+
+	return nil
+}
+
 // StateDetector checks for existing production state
 type StateDetector struct {
 	oramaDir string
pkg/environments/production/wireguard.go (new file, 228 lines)
@ -0,0 +1,228 @@
package production

import (
	"crypto/rand"
	"encoding/base64"
	"fmt"
	"os"
	"os/exec"
	"path/filepath"
	"strings"

	"golang.org/x/crypto/curve25519"
)

// WireGuardPeer represents a WireGuard mesh peer
type WireGuardPeer struct {
	PublicKey string // Base64-encoded public key
	Endpoint  string // e.g., "141.227.165.154:51820"
	AllowedIP string // e.g., "10.0.0.2/32"
}

// WireGuardConfig holds the configuration for a WireGuard interface
type WireGuardConfig struct {
	PrivateIP  string          // e.g., "10.0.0.1"
	ListenPort int             // default 51820
	PrivateKey string          // Base64-encoded private key
	Peers      []WireGuardPeer // Known peers
}

// WireGuardProvisioner manages WireGuard VPN setup
type WireGuardProvisioner struct {
	configDir string // /etc/wireguard
	config    WireGuardConfig
}

// NewWireGuardProvisioner creates a new WireGuard provisioner
func NewWireGuardProvisioner(config WireGuardConfig) *WireGuardProvisioner {
	if config.ListenPort == 0 {
		config.ListenPort = 51820
	}
	return &WireGuardProvisioner{
		configDir: "/etc/wireguard",
		config:    config,
	}
}

// IsInstalled checks if WireGuard tools are available
func (wp *WireGuardProvisioner) IsInstalled() bool {
	_, err := exec.LookPath("wg")
	return err == nil
}

// Install installs the WireGuard package
func (wp *WireGuardProvisioner) Install() error {
	if wp.IsInstalled() {
		return nil
	}

	cmd := exec.Command("apt-get", "install", "-y", "wireguard", "wireguard-tools")
	if output, err := cmd.CombinedOutput(); err != nil {
		return fmt.Errorf("failed to install wireguard: %w\n%s", err, string(output))
	}

	return nil
}

// GenerateKeyPair generates a new WireGuard private/public key pair
func GenerateKeyPair() (privateKey, publicKey string, err error) {
	// Generate 32 random bytes for private key
	var privBytes [32]byte
	if _, err := rand.Read(privBytes[:]); err != nil {
		return "", "", fmt.Errorf("failed to generate random bytes: %w", err)
	}

	// Clamp private key per Curve25519 spec
	privBytes[0] &= 248
	privBytes[31] &= 127
	privBytes[31] |= 64

	// Derive public key
	var pubBytes [32]byte
	curve25519.ScalarBaseMult(&pubBytes, &privBytes)

	privateKey = base64.StdEncoding.EncodeToString(privBytes[:])
	publicKey = base64.StdEncoding.EncodeToString(pubBytes[:])
	return privateKey, publicKey, nil
}

// PublicKeyFromPrivate derives the public key from a private key
func PublicKeyFromPrivate(privateKey string) (string, error) {
	privBytes, err := base64.StdEncoding.DecodeString(privateKey)
	if err != nil {
		return "", fmt.Errorf("failed to decode private key: %w", err)
	}
	if len(privBytes) != 32 {
		return "", fmt.Errorf("invalid private key length: %d", len(privBytes))
	}

	var priv, pub [32]byte
	copy(priv[:], privBytes)
	curve25519.ScalarBaseMult(&pub, &priv)

	return base64.StdEncoding.EncodeToString(pub[:]), nil
}
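A small round-trip check tying the two key helpers above together; this is a test-style sketch using only functions defined in this file:

```go
package production

import "testing"

// TestKeyPairRoundTrip verifies that PublicKeyFromPrivate derives the
// same public key that GenerateKeyPair returned for the private key.
func TestKeyPairRoundTrip(t *testing.T) {
	priv, pub, err := GenerateKeyPair()
	if err != nil {
		t.Fatalf("GenerateKeyPair: %v", err)
	}
	derived, err := PublicKeyFromPrivate(priv)
	if err != nil {
		t.Fatalf("PublicKeyFromPrivate: %v", err)
	}
	if derived != pub {
		t.Errorf("derived public key %q != generated %q", derived, pub)
	}
}
```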
// GenerateConfig returns the wg0.conf file content
func (wp *WireGuardProvisioner) GenerateConfig() string {
	var sb strings.Builder

	sb.WriteString("# WireGuard mesh configuration (managed by Orama Network)\n")
	sb.WriteString("# Do not edit manually — use orama CLI to manage peers\n\n")
	sb.WriteString("[Interface]\n")
	sb.WriteString(fmt.Sprintf("PrivateKey = %s\n", wp.config.PrivateKey))
	sb.WriteString(fmt.Sprintf("Address = %s/24\n", wp.config.PrivateIP))
	sb.WriteString(fmt.Sprintf("ListenPort = %d\n", wp.config.ListenPort))

	for _, peer := range wp.config.Peers {
		sb.WriteString("\n[Peer]\n")
		sb.WriteString(fmt.Sprintf("PublicKey = %s\n", peer.PublicKey))
		if peer.Endpoint != "" {
			sb.WriteString(fmt.Sprintf("Endpoint = %s\n", peer.Endpoint))
		}
		sb.WriteString(fmt.Sprintf("AllowedIPs = %s\n", peer.AllowedIP))
		sb.WriteString("PersistentKeepalive = 25\n")
	}

	return sb.String()
}
|
||||
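For a single-peer setup, the generated wg0.conf looks roughly like this (keys abbreviated; the endpoint is a placeholder):

```
# WireGuard mesh configuration (managed by Orama Network)
# Do not edit manually — use orama CLI to manage peers

[Interface]
PrivateKey = <base64-private-key>
Address = 10.0.0.1/24
ListenPort = 51820

[Peer]
PublicKey = <base64-peer-public-key>
Endpoint = 1.2.3.4:51820
AllowedIPs = 10.0.0.2/32
PersistentKeepalive = 25
```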
// WriteConfig writes the WireGuard config to /etc/wireguard/wg0.conf
func (wp *WireGuardProvisioner) WriteConfig() error {
    confPath := filepath.Join(wp.configDir, "wg0.conf")
    content := wp.GenerateConfig()

    // Try a direct write first (works when running as root)
    if err := os.MkdirAll(wp.configDir, 0700); err == nil {
        if err := os.WriteFile(confPath, []byte(content), 0600); err == nil {
            return nil
        }
    }

    // Fall back to sudo tee (for non-root, e.g. the debros user)
    cmd := exec.Command("sudo", "tee", confPath)
    cmd.Stdin = strings.NewReader(content)
    if output, err := cmd.CombinedOutput(); err != nil {
        return fmt.Errorf("failed to write wg0.conf via sudo: %w\n%s", err, string(output))
    }

    return nil
}

// Enable starts and enables the WireGuard interface
func (wp *WireGuardProvisioner) Enable() error {
    // Enable on boot
    cmd := exec.Command("systemctl", "enable", "wg-quick@wg0")
    if output, err := cmd.CombinedOutput(); err != nil {
        return fmt.Errorf("failed to enable wg-quick@wg0: %w\n%s", err, string(output))
    }

    // Start now
    cmd = exec.Command("systemctl", "start", "wg-quick@wg0")
    if output, err := cmd.CombinedOutput(); err != nil {
        return fmt.Errorf("failed to start wg-quick@wg0: %w\n%s", err, string(output))
    }

    return nil
}
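Taken together, provisioning a fresh node is roughly Install, generate keys, WriteConfig, Enable. A minimal sketch (error handling abbreviated; the genesis IP is an assumption taken from the join flow below):

```go
priv, _, err := GenerateKeyPair()
if err != nil {
    return err
}
wp := NewWireGuardProvisioner(WireGuardConfig{
    PrivateIP:  "10.0.0.1", // genesis node, per the join flow's IP assignment
    PrivateKey: priv,
})
if err := wp.Install(); err != nil { // apt-get install wireguard + tools
    return err
}
if err := wp.WriteConfig(); err != nil { // writes /etc/wireguard/wg0.conf
    return err
}
return wp.Enable() // systemctl enable + start wg-quick@wg0
```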
// Restart restarts the WireGuard interface
func (wp *WireGuardProvisioner) Restart() error {
    cmd := exec.Command("systemctl", "restart", "wg-quick@wg0")
    if output, err := cmd.CombinedOutput(); err != nil {
        return fmt.Errorf("failed to restart wg-quick@wg0: %w\n%s", err, string(output))
    }
    return nil
}

// IsActive checks if the WireGuard interface is up
func (wp *WireGuardProvisioner) IsActive() bool {
    cmd := exec.Command("systemctl", "is-active", "--quiet", "wg-quick@wg0")
    return cmd.Run() == nil
}

// AddPeer adds a peer to the running WireGuard interface without a restart
func (wp *WireGuardProvisioner) AddPeer(peer WireGuardPeer) error {
    // Add peer to the running interface
    args := []string{"wg", "set", "wg0", "peer", peer.PublicKey, "allowed-ips", peer.AllowedIP, "persistent-keepalive", "25"}
    if peer.Endpoint != "" {
        args = append(args, "endpoint", peer.Endpoint)
    }

    cmd := exec.Command("sudo", args...)
    if output, err := cmd.CombinedOutput(); err != nil {
        return fmt.Errorf("failed to add peer %s: %w\n%s", peer.AllowedIP, err, string(output))
    }

    // Also update the config file so the peer persists across restarts
    wp.config.Peers = append(wp.config.Peers, peer)
    return wp.WriteConfig()
}

// RemovePeer removes a peer from the running WireGuard interface
func (wp *WireGuardProvisioner) RemovePeer(publicKey string) error {
    cmd := exec.Command("sudo", "wg", "set", "wg0", "peer", publicKey, "remove")
    if output, err := cmd.CombinedOutput(); err != nil {
        return fmt.Errorf("failed to remove peer: %w\n%s", err, string(output))
    }

    // Remove from config
    filtered := make([]WireGuardPeer, 0, len(wp.config.Peers))
    for _, p := range wp.config.Peers {
        if p.PublicKey != publicKey {
            filtered = append(filtered, p)
        }
    }
    wp.config.Peers = filtered
    return wp.WriteConfig()
}

// GetStatus returns the current WireGuard interface status
func (wp *WireGuardProvisioner) GetStatus() (string, error) {
    cmd := exec.Command("wg", "show", "wg0")
    output, err := cmd.CombinedOutput()
    if err != nil {
        return "", fmt.Errorf("failed to get wg status: %w\n%s", err, string(output))
    }
    return string(output), nil
}
167 pkg/environments/production/wireguard_test.go (new file)
@ -0,0 +1,167 @@
package production

import (
    "encoding/base64"
    "strings"
    "testing"
)

func TestGenerateKeyPair(t *testing.T) {
    priv, pub, err := GenerateKeyPair()
    if err != nil {
        t.Fatalf("GenerateKeyPair failed: %v", err)
    }

    // Keys should be base64, 44 chars (32 bytes + padding)
    if len(priv) != 44 {
        t.Errorf("private key length = %d, want 44", len(priv))
    }
    if len(pub) != 44 {
        t.Errorf("public key length = %d, want 44", len(pub))
    }

    // Should be valid base64
    if _, err := base64.StdEncoding.DecodeString(priv); err != nil {
        t.Errorf("private key is not valid base64: %v", err)
    }
    if _, err := base64.StdEncoding.DecodeString(pub); err != nil {
        t.Errorf("public key is not valid base64: %v", err)
    }

    // Private and public should differ
    if priv == pub {
        t.Error("private and public keys should differ")
    }
}

func TestGenerateKeyPair_Unique(t *testing.T) {
    priv1, _, _ := GenerateKeyPair()
    priv2, _, _ := GenerateKeyPair()

    if priv1 == priv2 {
        t.Error("two generated key pairs should be unique")
    }
}

func TestPublicKeyFromPrivate(t *testing.T) {
    priv, expectedPub, err := GenerateKeyPair()
    if err != nil {
        t.Fatalf("GenerateKeyPair failed: %v", err)
    }

    pub, err := PublicKeyFromPrivate(priv)
    if err != nil {
        t.Fatalf("PublicKeyFromPrivate failed: %v", err)
    }

    if pub != expectedPub {
        t.Errorf("PublicKeyFromPrivate = %s, want %s", pub, expectedPub)
    }
}

func TestPublicKeyFromPrivate_InvalidKey(t *testing.T) {
    _, err := PublicKeyFromPrivate("not-valid-base64!!!")
    if err == nil {
        t.Error("expected error for invalid base64")
    }

    _, err = PublicKeyFromPrivate(base64.StdEncoding.EncodeToString([]byte("short")))
    if err == nil {
        t.Error("expected error for short key")
    }
}

func TestWireGuardProvisioner_GenerateConfig_NoPeers(t *testing.T) {
    wp := NewWireGuardProvisioner(WireGuardConfig{
        PrivateIP:  "10.0.0.1",
        ListenPort: 51820,
        PrivateKey: "dGVzdHByaXZhdGVrZXl0ZXN0cHJpdmF0ZWtleXM=",
    })

    config := wp.GenerateConfig()

    if !strings.Contains(config, "[Interface]") {
        t.Error("config should contain [Interface] section")
    }
    if !strings.Contains(config, "Address = 10.0.0.1/24") {
        t.Error("config should contain correct Address")
    }
    if !strings.Contains(config, "ListenPort = 51820") {
        t.Error("config should contain ListenPort")
    }
    if !strings.Contains(config, "PrivateKey = dGVzdHByaXZhdGVrZXl0ZXN0cHJpdmF0ZWtleXM=") {
        t.Error("config should contain PrivateKey")
    }
    if strings.Contains(config, "[Peer]") {
        t.Error("config should NOT contain [Peer] section with no peers")
    }
}

func TestWireGuardProvisioner_GenerateConfig_WithPeers(t *testing.T) {
    wp := NewWireGuardProvisioner(WireGuardConfig{
        PrivateIP:  "10.0.0.1",
        ListenPort: 51820,
        PrivateKey: "dGVzdHByaXZhdGVrZXl0ZXN0cHJpdmF0ZWtleXM=",
        Peers: []WireGuardPeer{
            {
                PublicKey: "cGVlcjFwdWJsaWNrZXlwZWVyMXB1YmxpY2tleXM=",
                Endpoint:  "1.2.3.4:51820",
                AllowedIP: "10.0.0.2/32",
            },
            {
                PublicKey: "cGVlcjJwdWJsaWNrZXlwZWVyMnB1YmxpY2tleXM=",
                Endpoint:  "5.6.7.8:51820",
                AllowedIP: "10.0.0.3/32",
            },
        },
    })

    config := wp.GenerateConfig()

    if strings.Count(config, "[Peer]") != 2 {
        t.Errorf("expected 2 [Peer] sections, got %d", strings.Count(config, "[Peer]"))
    }
    if !strings.Contains(config, "Endpoint = 1.2.3.4:51820") {
        t.Error("config should contain first peer endpoint")
    }
    if !strings.Contains(config, "AllowedIPs = 10.0.0.2/32") {
        t.Error("config should contain first peer AllowedIPs")
    }
    if !strings.Contains(config, "PersistentKeepalive = 25") {
        t.Error("config should contain PersistentKeepalive")
    }
    if !strings.Contains(config, "Endpoint = 5.6.7.8:51820") {
        t.Error("config should contain second peer endpoint")
    }
}

func TestWireGuardProvisioner_GenerateConfig_PeerWithoutEndpoint(t *testing.T) {
    wp := NewWireGuardProvisioner(WireGuardConfig{
        PrivateIP:  "10.0.0.1",
        ListenPort: 51820,
        PrivateKey: "dGVzdHByaXZhdGVrZXl0ZXN0cHJpdmF0ZWtleXM=",
        Peers: []WireGuardPeer{
            {
                PublicKey: "cGVlcjFwdWJsaWNrZXlwZWVyMXB1YmxpY2tleXM=",
                AllowedIP: "10.0.0.2/32",
            },
        },
    })

    config := wp.GenerateConfig()

    if strings.Contains(config, "Endpoint") {
        t.Error("config should NOT contain Endpoint when peer has none")
    }
}

func TestWireGuardProvisioner_DefaultPort(t *testing.T) {
    wp := NewWireGuardProvisioner(WireGuardConfig{
        PrivateIP:  "10.0.0.1",
        PrivateKey: "dGVzdHByaXZhdGVrZXl0ZXN0cHJpdmF0ZWtleXM=",
    })

    if wp.config.ListenPort != 51820 {
        t.Errorf("default ListenPort = %d, want 51820", wp.config.ListenPort)
    }
}
@ -34,4 +34,7 @@ type Config struct {
    IPFSTimeout           time.Duration // Timeout for IPFS operations (default: 60s)
    IPFSReplicationFactor int           // Replication factor for pins (default: 3)
    IPFSEnableEncryption  bool          // Enable client-side encryption before upload (default: true, discovered from node configs)

    // WireGuard mesh configuration
    ClusterSecret string // Cluster secret for authenticating internal WireGuard peer exchange
}
@ -26,6 +26,8 @@ import (
    deploymentshandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/deployments"
    pubsubhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/pubsub"
    serverlesshandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/serverless"
    joinhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/join"
    wireguardhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/wireguard"
    sqlitehandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/sqlite"
    "github.com/DeBrosOfficial/network/pkg/gateway/handlers/storage"
    "github.com/DeBrosOfficial/network/pkg/ipfs"
@ -98,6 +100,15 @@ type Gateway struct {
    processManager *process.Manager
    healthChecker  *health.HealthChecker

    // Rate limiter
    rateLimiter *RateLimiter

    // WireGuard peer exchange
    wireguardHandler *wireguardhandlers.Handler

    // Node join handler
    joinHandler *joinhandlers.Handler

    // Cluster provisioning for namespace clusters
    clusterProvisioner authhandlers.ClusterProvisioner
}
@ -246,6 +257,16 @@ func New(logger *logging.ColoredLogger, cfg *Config) (*Gateway, error) {
        )
    }

    // Initialize rate limiter (300 req/min, burst 50)
    gw.rateLimiter = NewRateLimiter(300, 50)
    gw.rateLimiter.StartCleanup(5*time.Minute, 10*time.Minute)

    // Initialize WireGuard peer exchange handler
    if deps.ORMClient != nil {
        gw.wireguardHandler = wireguardhandlers.NewHandler(logger.Logger, deps.ORMClient, cfg.ClusterSecret)
        gw.joinHandler = joinhandlers.NewHandler(logger.Logger, deps.ORMClient, cfg.DataDir)
    }

    // Initialize deployment system
    if deps.ORMClient != nil && deps.IPFSClient != nil {
        // Convert rqlite.Client to database.Database interface for health checker
@ -643,8 +643,8 @@ func (s *DeploymentService) getNodeIP(ctx context.Context, nodeID string) (strin

    var rows []nodeRow

-   // Try full node ID first
-   query := `SELECT ip_address FROM dns_nodes WHERE id = ? LIMIT 1`
+   // Try full node ID first (prefer internal/WG IP for cross-node communication)
+   query := `SELECT COALESCE(internal_ip, ip_address) AS ip_address FROM dns_nodes WHERE id = ? LIMIT 1`
    err := s.db.Query(ctx, &rows, query, nodeID)
    if err != nil {
        return "", err
424 pkg/gateway/handlers/join/handler.go (new file)
@ -0,0 +1,424 @@
package join

import (
    "context"
    "encoding/json"
    "fmt"
    "net/http"
    "os"
    "os/exec"
    "strings"
    "time"

    "github.com/DeBrosOfficial/network/pkg/rqlite"
    "go.uber.org/zap"
)

// JoinRequest is the request body for node join
type JoinRequest struct {
    Token       string `json:"token"`
    WGPublicKey string `json:"wg_public_key"`
    PublicIP    string `json:"public_ip"`
}

// JoinResponse contains everything a joining node needs
type JoinResponse struct {
    // WireGuard
    WGIP    string       `json:"wg_ip"`
    WGPeers []WGPeerInfo `json:"wg_peers"`

    // Secrets
    ClusterSecret string `json:"cluster_secret"`
    SwarmKey      string `json:"swarm_key"`

    // Cluster join info (all using WG IPs)
    RQLiteJoinAddress string   `json:"rqlite_join_address"`
    IPFSPeer          PeerInfo `json:"ipfs_peer"`
    IPFSClusterPeer   PeerInfo `json:"ipfs_cluster_peer"`
    BootstrapPeers    []string `json:"bootstrap_peers"`

    // Domain
    BaseDomain string `json:"base_domain"`
}

// WGPeerInfo represents a WireGuard peer
type WGPeerInfo struct {
    PublicKey string `json:"public_key"`
    Endpoint  string `json:"endpoint"`
    AllowedIP string `json:"allowed_ip"`
}

// PeerInfo represents an IPFS/Cluster peer
type PeerInfo struct {
    ID    string   `json:"id"`
    Addrs []string `json:"addrs"`
}

// Handler handles the node join endpoint
type Handler struct {
    logger       *zap.Logger
    rqliteClient rqlite.Client
    oramaDir     string // e.g., /home/debros/.orama
}

// NewHandler creates a new join handler
func NewHandler(logger *zap.Logger, rqliteClient rqlite.Client, oramaDir string) *Handler {
    return &Handler{
        logger:       logger,
        rqliteClient: rqliteClient,
        oramaDir:     oramaDir,
    }
}

// HandleJoin handles POST /v1/internal/join
func (h *Handler) HandleJoin(w http.ResponseWriter, r *http.Request) {
    if r.Method != http.MethodPost {
        http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
        return
    }

    var req JoinRequest
    if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
        http.Error(w, "invalid request body", http.StatusBadRequest)
        return
    }

    if req.Token == "" || req.WGPublicKey == "" || req.PublicIP == "" {
        http.Error(w, "token, wg_public_key, and public_ip are required", http.StatusBadRequest)
        return
    }

    ctx := r.Context()

    // 1. Validate and consume the invite token (atomic single-use)
    if err := h.consumeToken(ctx, req.Token, req.PublicIP); err != nil {
        h.logger.Warn("join token validation failed", zap.Error(err))
        http.Error(w, "unauthorized: invalid or expired token", http.StatusUnauthorized)
        return
    }

    // 2. Assign WG IP with retry on conflict
    wgIP, err := h.assignWGIP(ctx)
    if err != nil {
        h.logger.Error("failed to assign WG IP", zap.Error(err))
        http.Error(w, "failed to assign WG IP", http.StatusInternalServerError)
        return
    }

    // 3. Register WG peer in database
    nodeID := fmt.Sprintf("node-%s", wgIP) // temporary ID based on WG IP
    _, err = h.rqliteClient.Exec(ctx,
        "INSERT OR REPLACE INTO wireguard_peers (node_id, wg_ip, public_key, public_ip, wg_port) VALUES (?, ?, ?, ?, ?)",
        nodeID, wgIP, req.WGPublicKey, req.PublicIP, 51820)
    if err != nil {
        h.logger.Error("failed to register WG peer", zap.Error(err))
        http.Error(w, "failed to register peer", http.StatusInternalServerError)
        return
    }

    // 4. Add peer to local WireGuard interface immediately
    if err := h.addWGPeerLocally(req.WGPublicKey, req.PublicIP, wgIP); err != nil {
        h.logger.Warn("failed to add WG peer to local interface", zap.Error(err))
        // Non-fatal: the sync loop will pick it up
    }

    // 5. Read secrets from disk
    clusterSecret, err := os.ReadFile(h.oramaDir + "/secrets/cluster-secret")
    if err != nil {
        h.logger.Error("failed to read cluster secret", zap.Error(err))
        http.Error(w, "internal error reading secrets", http.StatusInternalServerError)
        return
    }

    swarmKey, err := os.ReadFile(h.oramaDir + "/secrets/swarm.key")
    if err != nil {
        h.logger.Error("failed to read swarm key", zap.Error(err))
        http.Error(w, "internal error reading secrets", http.StatusInternalServerError)
        return
    }

    // 6. Get all WG peers
    wgPeers, err := h.getWGPeers(ctx, req.WGPublicKey)
    if err != nil {
        h.logger.Error("failed to list WG peers", zap.Error(err))
        http.Error(w, "failed to list peers", http.StatusInternalServerError)
        return
    }

    // 7. Get this node's WG IP
    myWGIP, err := h.getMyWGIP()
    if err != nil {
        h.logger.Error("failed to get local WG IP", zap.Error(err))
        http.Error(w, "internal error", http.StatusInternalServerError)
        return
    }

    // 8. Query IPFS and IPFS Cluster peer info
    ipfsPeer := h.queryIPFSPeerInfo(myWGIP)
    ipfsClusterPeer := h.queryIPFSClusterPeerInfo(myWGIP)

    // 9. Get this node's libp2p peer ID for bootstrap peers
    bootstrapPeers := h.buildBootstrapPeers(myWGIP, ipfsPeer.ID)

    // 10. Read base domain from config
    baseDomain := h.readBaseDomain()

    resp := JoinResponse{
        WGIP:              wgIP,
        WGPeers:           wgPeers,
        ClusterSecret:     strings.TrimSpace(string(clusterSecret)),
        SwarmKey:          strings.TrimSpace(string(swarmKey)),
        RQLiteJoinAddress: fmt.Sprintf("%s:7001", myWGIP),
        IPFSPeer:          ipfsPeer,
        IPFSClusterPeer:   ipfsClusterPeer,
        BootstrapPeers:    bootstrapPeers,
        BaseDomain:        baseDomain,
    }

    w.Header().Set("Content-Type", "application/json")
    json.NewEncoder(w).Encode(resp)

    h.logger.Info("node joined cluster",
        zap.String("wg_ip", wgIP),
        zap.String("public_ip", req.PublicIP))
}
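From the joining node's side, the exchange is a single HTTPS POST. A hedged client sketch (the gateway URL and token are placeholders; `pub` is the key from GenerateKeyPair):

```go
body, _ := json.Marshal(join.JoinRequest{
    Token:       "<invite-token>",
    WGPublicKey: pub,
    PublicIP:    "203.0.113.7",
})
resp, err := http.Post("https://node1.example.com/v1/internal/join",
    "application/json", bytes.NewReader(body))
if err != nil {
    return err
}
defer resp.Body.Close()

var jr join.JoinResponse
if err := json.NewDecoder(resp.Body).Decode(&jr); err != nil {
    return err
}
// jr.WGIP is this node's mesh address; jr.WGPeers seeds wg0.conf, and
// jr.RQLiteJoinAddress / jr.BootstrapPeers point at the mesh's WG IPs.
```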
// consumeToken validates and marks an invite token as used (atomic single-use)
func (h *Handler) consumeToken(ctx context.Context, token, usedByIP string) error {
    // Atomically mark as used — only succeeds if the token exists, is unused, and has not expired
    result, err := h.rqliteClient.Exec(ctx,
        "UPDATE invite_tokens SET used_at = datetime('now'), used_by_ip = ? WHERE token = ? AND used_at IS NULL AND expires_at > datetime('now')",
        usedByIP, token)
    if err != nil {
        return fmt.Errorf("database error: %w", err)
    }

    rowsAffected, err := result.RowsAffected()
    if err != nil {
        return fmt.Errorf("failed to check result: %w", err)
    }

    if rowsAffected == 0 {
        return fmt.Errorf("token invalid, expired, or already used")
    }

    return nil
}

// assignWGIP finds the next available 10.0.0.x IP, retrying on UNIQUE constraint violation
func (h *Handler) assignWGIP(ctx context.Context) (string, error) {
    for attempt := 0; attempt < 3; attempt++ {
        var result []struct {
            MaxIP string `db:"max_ip"`
        }

        err := h.rqliteClient.Query(ctx, &result,
            "SELECT MAX(wg_ip) as max_ip FROM wireguard_peers")
        if err != nil {
            return "", fmt.Errorf("failed to query max WG IP: %w", err)
        }

        if len(result) == 0 || result[0].MaxIP == "" {
            return "10.0.0.2", nil // 10.0.0.1 is the genesis node
        }

        maxIP := result[0].MaxIP
        var a, b, c, d int
        if _, err := fmt.Sscanf(maxIP, "%d.%d.%d.%d", &a, &b, &c, &d); err != nil {
            return "", fmt.Errorf("failed to parse max WG IP %s: %w", maxIP, err)
        }

        d++
        if d > 254 {
            c++
            d = 1
            if c > 255 {
                return "", fmt.Errorf("WireGuard IP space exhausted")
            }
        }

        nextIP := fmt.Sprintf("%d.%d.%d.%d", a, b, c, d)
        return nextIP, nil
    }

    return "", fmt.Errorf("failed to assign WG IP after retries")
}
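The increment logic in a worked example (comment-only sketch of the rollover arithmetic):

```go
// 10.0.0.41    -> 10.0.0.42   (normal increment of the fourth octet)
// 10.0.0.254   -> 10.0.1.1    (fourth octet wraps, third octet bumps)
// 10.0.255.254 -> error       (c would exceed 255: IP space exhausted)
```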
// addWGPeerLocally adds a peer to the local wg0 interface and persists it to the config
func (h *Handler) addWGPeerLocally(pubKey, publicIP, wgIP string) error {
    // Add to the running interface with persistent-keepalive
    cmd := exec.Command("sudo", "wg", "set", "wg0",
        "peer", pubKey,
        "endpoint", fmt.Sprintf("%s:51820", publicIP),
        "allowed-ips", fmt.Sprintf("%s/32", wgIP),
        "persistent-keepalive", "25")
    if output, err := cmd.CombinedOutput(); err != nil {
        return fmt.Errorf("wg set failed: %w\n%s", err, string(output))
    }

    // Persist to wg0.conf so the peer survives a wg-quick restart:
    // read the current config, append the peer section, write back.
    confPath := "/etc/wireguard/wg0.conf"
    data, err := os.ReadFile(confPath)
    if err != nil {
        h.logger.Warn("could not read wg0.conf for persistence", zap.Error(err))
        return nil // non-fatal: the runtime peer is added
    }

    // Check if the peer is already in the config
    if strings.Contains(string(data), pubKey) {
        return nil // already persisted
    }

    peerSection := fmt.Sprintf("\n[Peer]\nPublicKey = %s\nEndpoint = %s:51820\nAllowedIPs = %s/32\nPersistentKeepalive = 25\n",
        pubKey, publicIP, wgIP)

    newConf := string(data) + peerSection
    writeCmd := exec.Command("sudo", "tee", confPath)
    writeCmd.Stdin = strings.NewReader(newConf)
    if output, err := writeCmd.CombinedOutput(); err != nil {
        h.logger.Warn("could not persist peer to wg0.conf", zap.Error(err), zap.String("output", string(output)))
    }

    return nil
}

// getWGPeers returns all WG peers except the requesting node
func (h *Handler) getWGPeers(ctx context.Context, excludePubKey string) ([]WGPeerInfo, error) {
    type peerRow struct {
        WGIP      string `db:"wg_ip"`
        PublicKey string `db:"public_key"`
        PublicIP  string `db:"public_ip"`
        WGPort    int    `db:"wg_port"`
    }

    var rows []peerRow
    err := h.rqliteClient.Query(ctx, &rows,
        "SELECT wg_ip, public_key, public_ip, wg_port FROM wireguard_peers ORDER BY wg_ip")
    if err != nil {
        return nil, err
    }

    var peers []WGPeerInfo
    for _, row := range rows {
        if row.PublicKey == excludePubKey {
            continue // don't include the requesting node itself
        }
        port := row.WGPort
        if port == 0 {
            port = 51820
        }
        peers = append(peers, WGPeerInfo{
            PublicKey: row.PublicKey,
            Endpoint:  fmt.Sprintf("%s:%d", row.PublicIP, port),
            AllowedIP: fmt.Sprintf("%s/32", row.WGIP),
        })
    }

    return peers, nil
}

// getMyWGIP gets this node's WireGuard IP from the wg0 interface
func (h *Handler) getMyWGIP() (string, error) {
    out, err := exec.Command("ip", "-4", "addr", "show", "wg0").CombinedOutput()
    if err != nil {
        return "", fmt.Errorf("failed to get wg0 info: %w", err)
    }

    // Parse "inet 10.0.0.1/32" from the output
    for _, line := range strings.Split(string(out), "\n") {
        line = strings.TrimSpace(line)
        if strings.HasPrefix(line, "inet ") {
            parts := strings.Fields(line)
            if len(parts) >= 2 {
                ip := strings.Split(parts[1], "/")[0]
                return ip, nil
            }
        }
    }

    return "", fmt.Errorf("could not find wg0 IP address")
}

// queryIPFSPeerInfo gets the local IPFS node's peer ID and builds addrs with the WG IP
func (h *Handler) queryIPFSPeerInfo(myWGIP string) PeerInfo {
    client := &http.Client{Timeout: 5 * time.Second}
    resp, err := client.Post("http://localhost:4501/api/v0/id", "", nil)
    if err != nil {
        h.logger.Warn("failed to query IPFS peer info", zap.Error(err))
        return PeerInfo{}
    }
    defer resp.Body.Close()

    var result struct {
        ID string `json:"ID"`
    }
    if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
        h.logger.Warn("failed to decode IPFS peer info", zap.Error(err))
        return PeerInfo{}
    }

    return PeerInfo{
        ID: result.ID,
        Addrs: []string{
            fmt.Sprintf("/ip4/%s/tcp/4101", myWGIP),
        },
    }
}

// queryIPFSClusterPeerInfo gets the local IPFS Cluster peer ID and builds addrs with the WG IP
func (h *Handler) queryIPFSClusterPeerInfo(myWGIP string) PeerInfo {
    client := &http.Client{Timeout: 5 * time.Second}
    resp, err := client.Get("http://localhost:9094/id")
    if err != nil {
        h.logger.Warn("failed to query IPFS Cluster peer info", zap.Error(err))
        return PeerInfo{}
    }
    defer resp.Body.Close()

    var result struct {
        ID string `json:"id"`
    }
    if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
        h.logger.Warn("failed to decode IPFS Cluster peer info", zap.Error(err))
        return PeerInfo{}
    }

    return PeerInfo{
        ID: result.ID,
        Addrs: []string{
            fmt.Sprintf("/ip4/%s/tcp/9100/p2p/%s", myWGIP, result.ID),
        },
    }
}

// buildBootstrapPeers constructs bootstrap peer multiaddrs using WG IPs
func (h *Handler) buildBootstrapPeers(myWGIP, ipfsPeerID string) []string {
    if ipfsPeerID == "" {
        return nil
    }
    return []string{
        fmt.Sprintf("/ip4/%s/tcp/4101/p2p/%s", myWGIP, ipfsPeerID),
    }
}

// readBaseDomain reads the base domain from the node config
func (h *Handler) readBaseDomain() string {
    data, err := os.ReadFile(h.oramaDir + "/configs/node.yaml")
    if err != nil {
        return ""
    }

    // Simple parse — look for the base_domain field
    for _, line := range strings.Split(string(data), "\n") {
        line = strings.TrimSpace(line)
        if strings.HasPrefix(line, "base_domain:") {
            val := strings.TrimPrefix(line, "base_domain:")
            val = strings.TrimSpace(val)
            val = strings.Trim(val, `"'`)
            return val
        }
    }

    return ""
}
211 pkg/gateway/handlers/wireguard/handler.go (new file)
@ -0,0 +1,211 @@
package wireguard

import (
    "context"
    "encoding/json"
    "fmt"
    "net/http"

    "github.com/DeBrosOfficial/network/pkg/rqlite"
    "go.uber.org/zap"
)

// PeerRecord represents a WireGuard peer stored in RQLite
type PeerRecord struct {
    NodeID    string `json:"node_id" db:"node_id"`
    WGIP      string `json:"wg_ip" db:"wg_ip"`
    PublicKey string `json:"public_key" db:"public_key"`
    PublicIP  string `json:"public_ip" db:"public_ip"`
    WGPort    int    `json:"wg_port" db:"wg_port"`
}

// RegisterPeerRequest is the request body for peer registration
type RegisterPeerRequest struct {
    NodeID        string `json:"node_id"`
    PublicKey     string `json:"public_key"`
    PublicIP      string `json:"public_ip"`
    WGPort        int    `json:"wg_port,omitempty"`
    ClusterSecret string `json:"cluster_secret"`
}

// RegisterPeerResponse is the response for peer registration
type RegisterPeerResponse struct {
    AssignedWGIP string       `json:"assigned_wg_ip"`
    Peers        []PeerRecord `json:"peers"`
}

// Handler handles WireGuard peer exchange endpoints
type Handler struct {
    logger        *zap.Logger
    rqliteClient  rqlite.Client
    clusterSecret string // expected cluster secret for auth
}

// NewHandler creates a new WireGuard handler
func NewHandler(logger *zap.Logger, rqliteClient rqlite.Client, clusterSecret string) *Handler {
    return &Handler{
        logger:        logger,
        rqliteClient:  rqliteClient,
        clusterSecret: clusterSecret,
    }
}

// HandleRegisterPeer handles POST /v1/internal/wg/peer.
// A new node calls this to register itself and get all existing peers.
func (h *Handler) HandleRegisterPeer(w http.ResponseWriter, r *http.Request) {
    if r.Method != http.MethodPost {
        http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
        return
    }

    var req RegisterPeerRequest
    if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
        http.Error(w, "invalid request body", http.StatusBadRequest)
        return
    }

    // Validate cluster secret
    if h.clusterSecret != "" && req.ClusterSecret != h.clusterSecret {
        http.Error(w, "unauthorized", http.StatusUnauthorized)
        return
    }

    if req.NodeID == "" || req.PublicKey == "" || req.PublicIP == "" {
        http.Error(w, "node_id, public_key, and public_ip are required", http.StatusBadRequest)
        return
    }

    if req.WGPort == 0 {
        req.WGPort = 51820
    }

    ctx := r.Context()

    // Assign the next available WG IP
    wgIP, err := h.assignNextWGIP(ctx)
    if err != nil {
        h.logger.Error("failed to assign WG IP", zap.Error(err))
        http.Error(w, "failed to assign WG IP", http.StatusInternalServerError)
        return
    }

    // Insert peer record
    _, err = h.rqliteClient.Exec(ctx,
        "INSERT OR REPLACE INTO wireguard_peers (node_id, wg_ip, public_key, public_ip, wg_port) VALUES (?, ?, ?, ?, ?)",
        req.NodeID, wgIP, req.PublicKey, req.PublicIP, req.WGPort)
    if err != nil {
        h.logger.Error("failed to insert WG peer", zap.Error(err))
        http.Error(w, "failed to register peer", http.StatusInternalServerError)
        return
    }

    // Get all peers (including the one just added)
    peers, err := h.ListPeers(ctx)
    if err != nil {
        h.logger.Error("failed to list WG peers", zap.Error(err))
        http.Error(w, "failed to list peers", http.StatusInternalServerError)
        return
    }

    resp := RegisterPeerResponse{
        AssignedWGIP: wgIP,
        Peers:        peers,
    }

    w.Header().Set("Content-Type", "application/json")
    json.NewEncoder(w).Encode(resp)

    h.logger.Info("registered WireGuard peer",
        zap.String("node_id", req.NodeID),
        zap.String("wg_ip", wgIP),
        zap.String("public_ip", req.PublicIP))
}

// HandleListPeers handles GET /v1/internal/wg/peers
func (h *Handler) HandleListPeers(w http.ResponseWriter, r *http.Request) {
    if r.Method != http.MethodGet {
        http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
        return
    }

    peers, err := h.ListPeers(r.Context())
    if err != nil {
        h.logger.Error("failed to list WG peers", zap.Error(err))
        http.Error(w, "failed to list peers", http.StatusInternalServerError)
        return
    }

    w.Header().Set("Content-Type", "application/json")
    json.NewEncoder(w).Encode(peers)
}

// HandleRemovePeer handles DELETE /v1/internal/wg/peer?node_id=xxx
func (h *Handler) HandleRemovePeer(w http.ResponseWriter, r *http.Request) {
    if r.Method != http.MethodDelete {
        http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
        return
    }

    nodeID := r.URL.Query().Get("node_id")
    if nodeID == "" {
        http.Error(w, "node_id parameter required", http.StatusBadRequest)
        return
    }

    _, err := h.rqliteClient.Exec(r.Context(),
        "DELETE FROM wireguard_peers WHERE node_id = ?", nodeID)
    if err != nil {
        h.logger.Error("failed to remove WG peer", zap.Error(err))
        http.Error(w, "failed to remove peer", http.StatusInternalServerError)
        return
    }

    w.WriteHeader(http.StatusOK)
    h.logger.Info("removed WireGuard peer", zap.String("node_id", nodeID))
}

// ListPeers returns all registered WireGuard peers
func (h *Handler) ListPeers(ctx context.Context) ([]PeerRecord, error) {
    var peers []PeerRecord
    err := h.rqliteClient.Query(ctx, &peers,
        "SELECT node_id, wg_ip, public_key, public_ip, wg_port FROM wireguard_peers ORDER BY wg_ip")
    if err != nil {
        return nil, fmt.Errorf("failed to query wireguard_peers: %w", err)
    }
    return peers, nil
}

// assignNextWGIP finds the next available 10.0.0.x IP
func (h *Handler) assignNextWGIP(ctx context.Context) (string, error) {
    var result []struct {
        MaxIP string `db:"max_ip"`
    }

    err := h.rqliteClient.Query(ctx, &result,
        "SELECT MAX(wg_ip) as max_ip FROM wireguard_peers")
    if err != nil {
        return "", fmt.Errorf("failed to query max WG IP: %w", err)
    }

    if len(result) == 0 || result[0].MaxIP == "" {
        return "10.0.0.1", nil
    }

    // Parse the last octet and increment
    maxIP := result[0].MaxIP
    var a, b, c, d int
    if _, err := fmt.Sscanf(maxIP, "%d.%d.%d.%d", &a, &b, &c, &d); err != nil {
        return "", fmt.Errorf("failed to parse max WG IP %s: %w", maxIP, err)
    }

    d++
    if d > 254 {
        c++
        d = 1
        if c > 255 {
            return "", fmt.Errorf("WireGuard IP space exhausted")
        }
    }

    return fmt.Sprintf("%d.%d.%d.%d", a, b, c, d), nil
}
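For reference, a peer-exchange round trip against this handler looks roughly like the following (illustrative; the gateway address and secret are placeholders, with 6001 being the gateway's internal port):

```go
reqBody, _ := json.Marshal(wireguard.RegisterPeerRequest{
    NodeID:        "node-10.0.0.3",
    PublicKey:     pub,
    PublicIP:      "198.51.100.9",
    ClusterSecret: secret, // contents of ~/.orama/secrets/cluster-secret
})
resp, err := http.Post("http://10.0.0.1:6001/v1/internal/wg/peer",
    "application/json", bytes.NewReader(reqBody))
if err != nil {
    return err
}
defer resp.Body.Close()

var rr wireguard.RegisterPeerResponse
if err := json.NewDecoder(resp.Body).Decode(&rr); err != nil {
    return err
}
// rr.AssignedWGIP is this node's mesh IP; each rr.Peers entry can be
// fed to WireGuardProvisioner.AddPeer to build the local interface.
```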
@ -19,15 +19,32 @@ import (

// Note: context keys (ctxKeyAPIKey, ctxKeyJWT, CtxKeyNamespaceOverride) are now defined in context.go

-// withMiddleware adds CORS and logging middleware
+// withMiddleware adds CORS, security headers, rate limiting, and logging middleware
func (g *Gateway) withMiddleware(next http.Handler) http.Handler {
-   // Order: logging (outermost) -> CORS -> domain routing -> auth -> handler
-   // Domain routing must come BEFORE auth to handle deployment domains without auth
+   // Order: logging -> security headers -> rate limit -> CORS -> domain routing -> auth -> handler
    return g.loggingMiddleware(
-       g.corsMiddleware(
-           g.domainRoutingMiddleware(
-               g.authMiddleware(
-                   g.authorizationMiddleware(next)))))
+       g.securityHeadersMiddleware(
+           g.rateLimitMiddleware(
+               g.corsMiddleware(
+                   g.domainRoutingMiddleware(
+                       g.authMiddleware(
+                           g.authorizationMiddleware(next)))))))
}

// securityHeadersMiddleware adds standard security headers to all responses
func (g *Gateway) securityHeadersMiddleware(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        w.Header().Set("X-Content-Type-Options", "nosniff")
        w.Header().Set("X-Frame-Options", "DENY")
        w.Header().Set("X-XSS-Protection", "0")
        w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
        w.Header().Set("Permissions-Policy", "camera=(), microphone=(), geolocation=()")
        // HSTS only when behind TLS (Caddy)
        if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" {
            w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
        }
        next.ServeHTTP(w, r)
    })
}

// loggingMiddleware logs basic request info and duration
@ -202,6 +219,16 @@ func isPublicPath(p string) bool {
        return true
    }

    // WireGuard peer exchange (auth handled by cluster secret in handler)
    if strings.HasPrefix(p, "/v1/internal/wg/") {
        return true
    }

    // Node join endpoint (auth handled by invite token in handler)
    if p == "/v1/internal/join" {
        return true
    }

    switch p {
    case "/health", "/v1/health", "/status", "/v1/status", "/v1/auth/jwks", "/.well-known/jwks.json", "/v1/version", "/v1/auth/login", "/v1/auth/challenge", "/v1/auth/verify", "/v1/auth/register", "/v1/auth/refresh", "/v1/auth/logout", "/v1/auth/api-key", "/v1/auth/simple-key", "/v1/network/status", "/v1/network/peers", "/v1/internal/tls/check", "/v1/internal/acme/present", "/v1/internal/acme/cleanup":
        return true
@ -912,7 +939,7 @@ func (g *Gateway) proxyCrossNode(w http.ResponseWriter, r *http.Request, deploym
    db := g.client.Database()
    internalCtx := client.WithInternalAuth(r.Context())

-   query := "SELECT ip_address FROM dns_nodes WHERE id = ? LIMIT 1"
+   query := "SELECT COALESCE(internal_ip, ip_address) FROM dns_nodes WHERE id = ? LIMIT 1"
    result, err := db.Query(internalCtx, query, deployment.HomeNodeID)
    if err != nil || result == nil || len(result.Rows) == 0 {
        g.logger.Warn("Failed to get home node IP",
129 pkg/gateway/rate_limiter.go (new file)
@ -0,0 +1,129 @@
package gateway

import (
    "net"
    "net/http"
    "strings"
    "sync"
    "time"
)

// RateLimiter implements a token-bucket rate limiter per client IP.
type RateLimiter struct {
    mu      sync.Mutex
    clients map[string]*bucket
    rate    float64 // tokens per second
    burst   int     // max tokens (burst capacity)
}

type bucket struct {
    tokens    float64
    lastCheck time.Time
}

// NewRateLimiter creates a rate limiter. ratePerMinute is the sustained rate;
// burst is the maximum number of requests that can be made in a short window.
func NewRateLimiter(ratePerMinute, burst int) *RateLimiter {
    return &RateLimiter{
        clients: make(map[string]*bucket),
        rate:    float64(ratePerMinute) / 60.0,
        burst:   burst,
    }
}

// Allow checks if a request from the given IP should be allowed.
func (rl *RateLimiter) Allow(ip string) bool {
    rl.mu.Lock()
    defer rl.mu.Unlock()

    now := time.Now()
    b, exists := rl.clients[ip]
    if !exists {
        rl.clients[ip] = &bucket{tokens: float64(rl.burst) - 1, lastCheck: now}
        return true
    }

    // Refill tokens based on elapsed time
    elapsed := now.Sub(b.lastCheck).Seconds()
    b.tokens += elapsed * rl.rate
    if b.tokens > float64(rl.burst) {
        b.tokens = float64(rl.burst)
    }
    b.lastCheck = now

    if b.tokens >= 1 {
        b.tokens--
        return true
    }
    return false
}

// Cleanup removes stale entries older than the given duration.
func (rl *RateLimiter) Cleanup(maxAge time.Duration) {
    rl.mu.Lock()
    defer rl.mu.Unlock()

    cutoff := time.Now().Add(-maxAge)
    for ip, b := range rl.clients {
        if b.lastCheck.Before(cutoff) {
            delete(rl.clients, ip)
        }
    }
}

// StartCleanup runs periodic cleanup in a goroutine.
func (rl *RateLimiter) StartCleanup(interval, maxAge time.Duration) {
    go func() {
        ticker := time.NewTicker(interval)
        defer ticker.Stop()
        for range ticker.C {
            rl.Cleanup(maxAge)
        }
    }()
}

// rateLimitMiddleware returns 429 when a client exceeds the rate limit.
// Internal traffic from the WireGuard subnet (10.0.0.0/8) is exempt.
func (g *Gateway) rateLimitMiddleware(next http.Handler) http.Handler {
    if g.rateLimiter == nil {
        return next
    }
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        ip := getClientIP(r)

        // Exempt internal cluster traffic (WireGuard subnet)
        if isInternalIP(ip) {
            next.ServeHTTP(w, r)
            return
        }

        if !g.rateLimiter.Allow(ip) {
            w.Header().Set("Retry-After", "5")
            http.Error(w, "rate limit exceeded", http.StatusTooManyRequests)
            return
        }
        next.ServeHTTP(w, r)
    })
}

// isInternalIP returns true if the IP is in the WireGuard 10.0.0.0/8 subnet
// or is a loopback address.
func isInternalIP(ipStr string) bool {
    // Strip port if present
    if strings.Contains(ipStr, ":") {
        host, _, err := net.SplitHostPort(ipStr)
        if err == nil {
            ipStr = host
        }
    }
    ip := net.ParseIP(ipStr)
    if ip == nil {
        return false
    }
    if ip.IsLoopback() {
        return true
    }
    // 10.0.0.0/8 — WireGuard mesh
    _, wgNet, _ := net.ParseCIDR("10.0.0.0/8")
    return wgNet.Contains(ip)
}
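With the defaults wired up in the gateway (300 req/min, burst 50), the bucket refills at 5 tokens per second, so a client that drains its burst regains one request roughly every 200ms. An illustrative check:

```go
rl := NewRateLimiter(300, 50) // 5 tokens/sec, burst 50
for i := 0; i < 50; i++ {
    rl.Allow("203.0.113.7") // drain the burst
}
fmt.Println(rl.Allow("203.0.113.7")) // false (bucket empty)
time.Sleep(200 * time.Millisecond)
fmt.Println(rl.Allow("203.0.113.7")) // true (about one token refilled)
```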
197 pkg/gateway/rate_limiter_test.go (new file)
@ -0,0 +1,197 @@
package gateway

import (
    "net/http"
    "net/http/httptest"
    "sync"
    "testing"
    "time"
)

func TestRateLimiter_AllowsUnderLimit(t *testing.T) {
    rl := NewRateLimiter(60, 10) // 1/sec, burst 10
    for i := 0; i < 10; i++ {
        if !rl.Allow("1.2.3.4") {
            t.Fatalf("request %d should be allowed (within burst)", i)
        }
    }
}

func TestRateLimiter_BlocksOverLimit(t *testing.T) {
    rl := NewRateLimiter(60, 5) // 1/sec, burst 5
    // Exhaust burst
    for i := 0; i < 5; i++ {
        rl.Allow("1.2.3.4")
    }
    if rl.Allow("1.2.3.4") {
        t.Fatal("request after burst should be blocked")
    }
}

func TestRateLimiter_RefillsOverTime(t *testing.T) {
    rl := NewRateLimiter(6000, 5) // 100/sec, burst 5
    // Exhaust burst
    for i := 0; i < 5; i++ {
        rl.Allow("1.2.3.4")
    }
    if rl.Allow("1.2.3.4") {
        t.Fatal("should be blocked after burst")
    }
    // Wait for refill
    time.Sleep(100 * time.Millisecond)
    if !rl.Allow("1.2.3.4") {
        t.Fatal("should be allowed after refill")
    }
}

func TestRateLimiter_PerIPIsolation(t *testing.T) {
    rl := NewRateLimiter(60, 2)
    // Exhaust IP A
    rl.Allow("1.1.1.1")
    rl.Allow("1.1.1.1")
    if rl.Allow("1.1.1.1") {
        t.Fatal("IP A should be blocked")
    }
    // IP B should still be allowed
    if !rl.Allow("2.2.2.2") {
        t.Fatal("IP B should be allowed")
    }
}

func TestRateLimiter_Cleanup(t *testing.T) {
    rl := NewRateLimiter(60, 10)
    rl.Allow("old-ip")
    // Force the entry to be old
    rl.mu.Lock()
    rl.clients["old-ip"].lastCheck = time.Now().Add(-20 * time.Minute)
    rl.mu.Unlock()

    rl.Cleanup(10 * time.Minute)

    rl.mu.Lock()
    _, exists := rl.clients["old-ip"]
    rl.mu.Unlock()
    if exists {
        t.Fatal("stale entry should have been cleaned up")
    }
}

func TestRateLimiter_ConcurrentAccess(t *testing.T) {
    rl := NewRateLimiter(60000, 100) // high limit to avoid false failures
    var wg sync.WaitGroup
    for i := 0; i < 50; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            for j := 0; j < 10; j++ {
                rl.Allow("concurrent-ip")
            }
        }()
    }
    wg.Wait()
}

func TestIsInternalIP(t *testing.T) {
    tests := []struct {
        ip       string
        internal bool
    }{
        {"10.0.0.1", true},
        {"10.0.0.254", true},
        {"10.255.255.255", true},
        {"127.0.0.1", true},
        {"192.168.1.1", false},
        {"8.8.8.8", false},
        {"141.227.165.168", false},
    }
    for _, tt := range tests {
        if got := isInternalIP(tt.ip); got != tt.internal {
            t.Errorf("isInternalIP(%q) = %v, want %v", tt.ip, got, tt.internal)
        }
    }
}

func TestSecurityHeaders(t *testing.T) {
    gw := &Gateway{}
    handler := gw.securityHeadersMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        w.WriteHeader(http.StatusOK)
    }))

    req := httptest.NewRequest("GET", "/test", nil)
    req.Header.Set("X-Forwarded-Proto", "https")
    w := httptest.NewRecorder()
    handler.ServeHTTP(w, req)

    expected := map[string]string{
        "X-Content-Type-Options":    "nosniff",
        "X-Frame-Options":           "DENY",
        "X-XSS-Protection":          "0",
        "Referrer-Policy":           "strict-origin-when-cross-origin",
        "Strict-Transport-Security": "max-age=31536000; includeSubDomains",
    }

    for header, want := range expected {
        if got := w.Header().Get(header); got != want {
            t.Errorf("header %s = %q, want %q", header, got, want)
        }
    }
}

func TestSecurityHeaders_NoHSTS_WithoutTLS(t *testing.T) {
    gw := &Gateway{}
    handler := gw.securityHeadersMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        w.WriteHeader(http.StatusOK)
    }))

    req := httptest.NewRequest("GET", "/test", nil)
    w := httptest.NewRecorder()
    handler.ServeHTTP(w, req)

    if got := w.Header().Get("Strict-Transport-Security"); got != "" {
        t.Errorf("HSTS should not be set without TLS, got %q", got)
    }
}

func TestRateLimitMiddleware_Returns429(t *testing.T) {
    gw := &Gateway{rateLimiter: NewRateLimiter(60, 1)}
    handler := gw.rateLimitMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        w.WriteHeader(http.StatusOK)
    }))

    // First request should pass
    req := httptest.NewRequest("GET", "/test", nil)
    req.RemoteAddr = "8.8.8.8:1234"
    w := httptest.NewRecorder()
    handler.ServeHTTP(w, req)
    if w.Code != http.StatusOK {
        t.Fatalf("first request should be 200, got %d", w.Code)
    }

    // Second request should be rate limited
    w = httptest.NewRecorder()
    handler.ServeHTTP(w, req)
    if w.Code != http.StatusTooManyRequests {
        t.Fatalf("second request should be 429, got %d", w.Code)
    }
    if w.Header().Get("Retry-After") == "" {
        t.Fatal("should have Retry-After header")
    }
}

func TestRateLimitMiddleware_ExemptsInternalTraffic(t *testing.T) {
    gw := &Gateway{rateLimiter: NewRateLimiter(60, 1)}
    handler := gw.rateLimitMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        w.WriteHeader(http.StatusOK)
    }))

    // Internal IP should never be rate limited
    for i := 0; i < 10; i++ {
        req := httptest.NewRequest("GET", "/test", nil)
        req.RemoteAddr = "10.0.0.1:1234"
        w := httptest.NewRecorder()
        handler.ServeHTTP(w, req)
        if w.Code != http.StatusOK {
            t.Fatalf("internal request %d should be 200, got %d", i, w.Code)
        }
    }
}
@ -24,6 +24,18 @@ func (g *Gateway) Routes() http.Handler {
    mux.HandleFunc("/v1/internal/acme/present", g.acmePresentHandler)
    mux.HandleFunc("/v1/internal/acme/cleanup", g.acmeCleanupHandler)

    // WireGuard peer exchange (internal, cluster-secret auth)
    if g.wireguardHandler != nil {
        mux.HandleFunc("/v1/internal/wg/peer", g.wireguardHandler.HandleRegisterPeer)
        mux.HandleFunc("/v1/internal/wg/peers", g.wireguardHandler.HandleListPeers)
        mux.HandleFunc("/v1/internal/wg/peer/remove", g.wireguardHandler.HandleRemovePeer)
    }

    // Node join endpoint (token-authenticated, no middleware auth needed)
    if g.joinHandler != nil {
        mux.HandleFunc("/v1/internal/join", g.joinHandler.HandleJoin)
    }

    // auth endpoints
    mux.HandleFunc("/v1/auth/jwks", g.authService.JWKSHandler)
    mux.HandleFunc("/.well-known/jwks.json", g.authService.JWKSHandler)
@ -226,8 +226,8 @@ func (cm *ClusterManager) startRQLiteCluster(ctx context.Context, cluster *Names
        NodeID:   nodes[0].NodeID,
        HTTPPort: portBlocks[0].RQLiteHTTPPort,
        RaftPort: portBlocks[0].RQLiteRaftPort,
-       HTTPAdvAddress: fmt.Sprintf("%s:%d", nodes[0].IPAddress, portBlocks[0].RQLiteHTTPPort),
-       RaftAdvAddress: fmt.Sprintf("%s:%d", nodes[0].IPAddress, portBlocks[0].RQLiteRaftPort),
+       HTTPAdvAddress: fmt.Sprintf("%s:%d", nodes[0].InternalIP, portBlocks[0].RQLiteHTTPPort),
+       RaftAdvAddress: fmt.Sprintf("%s:%d", nodes[0].InternalIP, portBlocks[0].RQLiteRaftPort),
        IsLeader: true,
    }

@ -254,8 +254,8 @@ func (cm *ClusterManager) startRQLiteCluster(ctx context.Context, cluster *Names
        NodeID:   nodes[i].NodeID,
        HTTPPort: portBlocks[i].RQLiteHTTPPort,
        RaftPort: portBlocks[i].RQLiteRaftPort,
-       HTTPAdvAddress: fmt.Sprintf("%s:%d", nodes[i].IPAddress, portBlocks[i].RQLiteHTTPPort),
-       RaftAdvAddress: fmt.Sprintf("%s:%d", nodes[i].IPAddress, portBlocks[i].RQLiteRaftPort),
+       HTTPAdvAddress: fmt.Sprintf("%s:%d", nodes[i].InternalIP, portBlocks[i].RQLiteHTTPPort),
+       RaftAdvAddress: fmt.Sprintf("%s:%d", nodes[i].InternalIP, portBlocks[i].RQLiteRaftPort),
        JoinAddresses: []string{leaderRaftAddr},
        IsLeader:      false,
    }
@ -288,7 +288,7 @@ func (cm *ClusterManager) startOlricCluster(ctx context.Context, cluster *Namesp
    // Build peer addresses (all nodes)
    peerAddresses := make([]string, len(nodes))
    for i, node := range nodes {
-       peerAddresses[i] = fmt.Sprintf("%s:%d", node.IPAddress, portBlocks[i].OlricMemberlistPort)
+       peerAddresses[i] = fmt.Sprintf("%s:%d", node.InternalIP, portBlocks[i].OlricMemberlistPort)
    }

    // Start all Olric instances
@ -299,7 +299,7 @@ func (cm *ClusterManager) startOlricCluster(ctx context.Context, cluster *Namesp
        HTTPPort:       portBlocks[i].OlricHTTPPort,
        MemberlistPort: portBlocks[i].OlricMemberlistPort,
        BindAddr:       "0.0.0.0",
-       AdvertiseAddr:  node.IPAddress,
+       AdvertiseAddr:  node.InternalIP,
        PeerAddresses:  peerAddresses,
    }
@ -23,6 +23,7 @@ type ClusterNodeSelector struct {
type NodeCapacity struct {
    NodeID          string `json:"node_id"`
    IPAddress       string `json:"ip_address"`
    InternalIP      string `json:"internal_ip"` // WireGuard IP for inter-node communication
    DeploymentCount int    `json:"deployment_count"`
    AllocatedPorts  int    `json:"allocated_ports"`
    AvailablePorts  int    `json:"available_ports"`
@ -59,7 +60,7 @@ func (cns *ClusterNodeSelector) SelectNodesForCluster(ctx context.Context, nodeC
    // Filter nodes that have capacity for namespace instances
    eligibleNodes := make([]NodeCapacity, 0)
    for _, node := range activeNodes {
-       capacity, err := cns.getNodeCapacity(internalCtx, node.NodeID, node.IPAddress)
+       capacity, err := cns.getNodeCapacity(internalCtx, node.NodeID, node.IPAddress, node.InternalIP)
        if err != nil {
            cns.logger.Warn("Failed to get node capacity, skipping",
                zap.String("node_id", node.NodeID),
@ -117,8 +118,9 @@ func (cns *ClusterNodeSelector) SelectNodesForCluster(ctx context.Context, nodeC

// nodeInfo is used for querying active nodes
type nodeInfo struct {
-   NodeID    string `db:"id"`
-   IPAddress string `db:"ip_address"`
+   NodeID     string `db:"id"`
+   IPAddress  string `db:"ip_address"`
+   InternalIP string `db:"internal_ip"`
}

// getActiveNodes retrieves all active nodes from dns_nodes table
@ -128,7 +130,7 @@ func (cns *ClusterNodeSelector) getActiveNodes(ctx context.Context) ([]nodeInfo,

    var results []nodeInfo
    query := `
-       SELECT id, ip_address FROM dns_nodes
+       SELECT id, ip_address, COALESCE(internal_ip, ip_address) as internal_ip FROM dns_nodes
        WHERE status = 'active' AND last_seen > ?
        ORDER BY id
    `
@ -148,7 +150,7 @@ func (cns *ClusterNodeSelector) getActiveNodes(ctx context.Context) ([]nodeInfo,
}

// getNodeCapacity calculates capacity metrics for a single node
-func (cns *ClusterNodeSelector) getNodeCapacity(ctx context.Context, nodeID, ipAddress string) (*NodeCapacity, error) {
+func (cns *ClusterNodeSelector) getNodeCapacity(ctx context.Context, nodeID, ipAddress, internalIP string) (*NodeCapacity, error) {
    // Get deployment count
    deploymentCount, err := cns.getDeploymentCount(ctx, nodeID)
    if err != nil {
@ -209,6 +211,7 @@ func (cns *ClusterNodeSelector) getNodeCapacity(ctx context.Context, nodeID, ipA
    capacity := &NodeCapacity{
        NodeID:          nodeID,
        IPAddress:       ipAddress,
        InternalIP:      internalIP,
        DeploymentCount: deploymentCount,
        AllocatedPorts:  allocatedPorts,
        AvailablePorts:  availablePorts,
@ -365,7 +368,7 @@ func (cns *ClusterNodeSelector) GetNodeByID(ctx context.Context, nodeID string)
    internalCtx := client.WithInternalAuth(ctx)

    var results []nodeInfo
-   query := `SELECT id, ip_address FROM dns_nodes WHERE id = ? LIMIT 1`
+   query := `SELECT id, ip_address, COALESCE(internal_ip, ip_address) as internal_ip FROM dns_nodes WHERE id = ? LIMIT 1`
    err := cns.db.Query(internalCtx, &results, query, nodeID)
    if err != nil {
        return nil, &ClusterError{
@ -29,8 +29,11 @@ func (n *Node) registerDNSNode(ctx context.Context) error {
        ipAddress = "127.0.0.1"
    }

-   // Get internal IP (same as external for now, or could use private network IP)
+   // Get internal IP from the WireGuard interface (for cross-node communication over the VPN)
    internalIP := ipAddress
+   if wgIP, err := n.getWireGuardIP(); err == nil && wgIP != "" {
+       internalIP = wgIP
+   }

    // Determine region (defaulting to "local" for now; could come from cloud metadata in the future)
    region := "local"
@ -297,6 +300,24 @@ func (n *Node) cleanupStaleNodeRecords(ctx context.Context) {
    }
}

// getWireGuardIP returns the IPv4 address assigned to the wg0 interface, if any
func (n *Node) getWireGuardIP() (string, error) {
    iface, err := net.InterfaceByName("wg0")
    if err != nil {
        return "", err
    }
    addrs, err := iface.Addrs()
    if err != nil {
        return "", err
    }
    for _, addr := range addrs {
        if ipnet, ok := addr.(*net.IPNet); ok && ipnet.IP.To4() != nil {
            return ipnet.IP.String(), nil
        }
    }
    return "", fmt.Errorf("no IPv4 address on wg0")
}

// getNodeIPAddress attempts to determine the node's external IP address
func (n *Node) getNodeIPAddress() (string, error) {
    // Try to detect the external IP by connecting to a public server
@ -33,6 +33,15 @@ func (n *Node) startHTTPGateway(ctx context.Context) error {
        return err
    }

    // DataDir in node config is ~/.orama/data; the orama dir is the parent
    oramaDir := filepath.Join(os.ExpandEnv(n.config.Node.DataDir), "..")

    // Read cluster secret for WireGuard peer exchange auth
    clusterSecret := ""
    if secretBytes, err := os.ReadFile(filepath.Join(oramaDir, "secrets", "cluster-secret")); err == nil {
        clusterSecret = string(secretBytes)
    }

    gwCfg := &gateway.Config{
        ListenAddr:      n.config.HTTPGateway.ListenAddr,
        ClientNamespace: n.config.HTTPGateway.ClientNamespace,
@ -45,6 +54,8 @@ func (n *Node) startHTTPGateway(ctx context.Context) error {
        IPFSAPIURL:    n.config.HTTPGateway.IPFSAPIURL,
        IPFSTimeout:   n.config.HTTPGateway.IPFSTimeout,
        BaseDomain:    n.config.HTTPGateway.BaseDomain,
        DataDir:       oramaDir,
        ClusterSecret: clusterSecret,
    }

    apiGateway, err := gateway.New(gatewayLogger, gwCfg)
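One hedged aside on the secret read above: os.ReadFile returns the file contents verbatim, so a trailing newline in secrets/cluster-secret would become part of the secret string and break equality checks against a clean value. A defensive variant (a sketch, not what this commit ships; assumes a strings import):

```go
// Hypothetical hardening: trim surrounding whitespace/newlines so the
// secret compares equal regardless of how the file was written.
if secretBytes, err := os.ReadFile(filepath.Join(oramaDir, "secrets", "cluster-secret")); err == nil {
    clusterSecret = strings.TrimSpace(string(secretBytes))
}
```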
@ -103,6 +103,9 @@ func (n *Node) Start(ctx context.Context) error {
        return fmt.Errorf("failed to start RQLite: %w", err)
    }

    // Sync WireGuard peers from RQLite (if WG is active on this node)
    n.startWireGuardSyncLoop(ctx)

    // Register this node in dns_nodes table for deployment routing
    if err := n.registerDNSNode(ctx); err != nil {
        n.logger.ComponentWarn(logging.ComponentNode, "Failed to register DNS node", zap.Error(err))
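The sync loop started here lives in the new file below and leans on a small provisioner API in pkg/environments/production that this diff does not include. A minimal sketch of what that API plausibly looks like, inferred from the call sites (everything beyond the WireGuardPeer fields and the AddPeer/RemovePeer names is an assumption; `wg set` is the standard tool-level way to mutate peers at runtime, but the real implementation may differ):

```go
package production

import (
    "fmt"
    "os/exec"
)

// WireGuardPeer describes one remote node in the mesh. Field meanings
// are inferred from how wireguard_sync.go populates them.
type WireGuardPeer struct {
    PublicKey string // peer's WireGuard public key (base64)
    Endpoint  string // "public_ip:wg_port"
    AllowedIP string // "wg_ip/32"
}

// WireGuardProvisioner reconciles peers on the local wg0 interface.
type WireGuardProvisioner struct{}

// AddPeer installs (or updates) a peer on wg0.
func (wp *WireGuardProvisioner) AddPeer(p WireGuardPeer) error {
    out, err := exec.Command("sudo", "wg", "set", "wg0",
        "peer", p.PublicKey,
        "endpoint", p.Endpoint,
        "allowed-ips", p.AllowedIP).CombinedOutput()
    if err != nil {
        return fmt.Errorf("wg set failed: %v: %s", err, out)
    }
    return nil
}

// RemovePeer drops a peer from wg0 by public key.
func (wp *WireGuardProvisioner) RemovePeer(publicKey string) error {
    out, err := exec.Command("sudo", "wg", "set", "wg0",
        "peer", publicKey, "remove").CombinedOutput()
    if err != nil {
        return fmt.Errorf("wg set remove failed: %v: %s", err, out)
    }
    return nil
}
```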
pkg/node/wireguard_sync.go (new file, 223 lines)
@ -0,0 +1,223 @@
package node

import (
    "context"
    "fmt"
    "net"
    "os/exec"
    "strings"
    "time"

    "github.com/DeBrosOfficial/network/pkg/environments/production"
    "github.com/DeBrosOfficial/network/pkg/logging"
    "go.uber.org/zap"
)

// syncWireGuardPeers reads all peers from RQLite and reconciles the local
// WireGuard interface so it matches the cluster state. This is called on
// startup after RQLite is ready and periodically thereafter.
func (n *Node) syncWireGuardPeers(ctx context.Context) error {
    if n.rqliteAdapter == nil {
        return fmt.Errorf("rqlite adapter not initialized")
    }

    // Check if WireGuard is installed and active
    if _, err := exec.LookPath("wg"); err != nil {
        n.logger.ComponentInfo(logging.ComponentNode, "WireGuard not installed, skipping peer sync")
        return nil
    }

    // Check if wg0 interface exists
    out, err := exec.CommandContext(ctx, "sudo", "wg", "show", "wg0").CombinedOutput()
    if err != nil {
        n.logger.ComponentInfo(logging.ComponentNode, "WireGuard interface wg0 not active, skipping peer sync")
        return nil
    }

    // Parse current peers from wg show output
    currentPeers := parseWGShowPeers(string(out))
    localPubKey := parseWGShowLocalKey(string(out))

    // Query all peers from RQLite
    db := n.rqliteAdapter.GetSQLDB()
    rows, err := db.QueryContext(ctx,
        "SELECT node_id, wg_ip, public_key, public_ip, wg_port FROM wireguard_peers ORDER BY wg_ip")
    if err != nil {
        return fmt.Errorf("failed to query wireguard_peers: %w", err)
    }
    defer rows.Close()

    // Build desired peer set (excluding self)
    desiredPeers := make(map[string]production.WireGuardPeer)
    for rows.Next() {
        var nodeID, wgIP, pubKey, pubIP string
        var wgPort int
        if err := rows.Scan(&nodeID, &wgIP, &pubKey, &pubIP, &wgPort); err != nil {
            continue
        }
        if pubKey == localPubKey {
            continue // skip self
        }
        if wgPort == 0 {
            wgPort = 51820
        }
        desiredPeers[pubKey] = production.WireGuardPeer{
            PublicKey: pubKey,
            Endpoint:  fmt.Sprintf("%s:%d", pubIP, wgPort),
            AllowedIP: wgIP + "/32",
        }
    }

    wp := &production.WireGuardProvisioner{}

    // Add missing peers
    for pubKey, peer := range desiredPeers {
        if _, exists := currentPeers[pubKey]; !exists {
            if err := wp.AddPeer(peer); err != nil {
                n.logger.ComponentWarn(logging.ComponentNode, "failed to add WG peer",
                    zap.String("public_key", pubKey[:8]+"..."),
                    zap.Error(err))
            } else {
                n.logger.ComponentInfo(logging.ComponentNode, "added WG peer",
                    zap.String("allowed_ip", peer.AllowedIP))
            }
        }
    }

    // Remove peers not in the desired set
    for pubKey := range currentPeers {
        if _, exists := desiredPeers[pubKey]; !exists {
            if err := wp.RemovePeer(pubKey); err != nil {
                n.logger.ComponentWarn(logging.ComponentNode, "failed to remove stale WG peer",
                    zap.String("public_key", pubKey[:8]+"..."),
                    zap.Error(err))
            } else {
                n.logger.ComponentInfo(logging.ComponentNode, "removed stale WG peer",
                    zap.String("public_key", pubKey[:8]+"..."))
            }
        }
    }

    n.logger.ComponentInfo(logging.ComponentNode, "WireGuard peer sync completed",
        zap.Int("desired_peers", len(desiredPeers)),
        zap.Int("current_peers", len(currentPeers)))

    return nil
}

// ensureWireGuardSelfRegistered ensures this node's WireGuard info is in the
// wireguard_peers table. Without this, joining nodes get an empty peer list
// from the /v1/internal/join endpoint and can't establish WG tunnels.
func (n *Node) ensureWireGuardSelfRegistered(ctx context.Context) {
    if n.rqliteAdapter == nil {
        return
    }

    // Check if wg0 is active
    out, err := exec.CommandContext(ctx, "sudo", "wg", "show", "wg0").CombinedOutput()
    if err != nil {
        return // WG not active, nothing to register
    }

    // Get local public key
    localPubKey := parseWGShowLocalKey(string(out))
    if localPubKey == "" {
        return
    }

    // Get WG IP from interface
    wgIP := ""
    iface, err := net.InterfaceByName("wg0")
    if err != nil {
        return
    }
    addrs, err := iface.Addrs()
    if err != nil {
        return
    }
    for _, addr := range addrs {
        if ipnet, ok := addr.(*net.IPNet); ok && ipnet.IP.To4() != nil {
            wgIP = ipnet.IP.String()
            break
        }
    }
    if wgIP == "" {
        return
    }

    // Get public IP
    publicIP, err := n.getNodeIPAddress()
    if err != nil {
        return
    }

    nodeID := n.GetPeerID()
    if nodeID == "" {
        nodeID = fmt.Sprintf("node-%s", wgIP)
    }

    db := n.rqliteAdapter.GetSQLDB()
    _, err = db.ExecContext(ctx,
        "INSERT OR REPLACE INTO wireguard_peers (node_id, wg_ip, public_key, public_ip, wg_port) VALUES (?, ?, ?, ?, ?)",
        nodeID, wgIP, localPubKey, publicIP, 51820)
    if err != nil {
        n.logger.ComponentWarn(logging.ComponentNode, "Failed to self-register WG peer", zap.Error(err))
    } else {
        n.logger.ComponentInfo(logging.ComponentNode, "WireGuard self-registered",
            zap.String("wg_ip", wgIP),
            zap.String("public_key", localPubKey[:8]+"..."))
    }
}

// startWireGuardSyncLoop runs syncWireGuardPeers periodically
func (n *Node) startWireGuardSyncLoop(ctx context.Context) {
    // Ensure this node is registered in wireguard_peers (critical for join flow)
    n.ensureWireGuardSelfRegistered(ctx)

    // Run initial sync
    if err := n.syncWireGuardPeers(ctx); err != nil {
        n.logger.ComponentWarn(logging.ComponentNode, "initial WireGuard peer sync failed", zap.Error(err))
    }

    // Periodic sync every 60 seconds
    go func() {
        ticker := time.NewTicker(60 * time.Second)
        defer ticker.Stop()
        for {
            select {
            case <-ctx.Done():
                return
            case <-ticker.C:
                if err := n.syncWireGuardPeers(ctx); err != nil {
                    n.logger.ComponentWarn(logging.ComponentNode, "WireGuard peer sync failed", zap.Error(err))
                }
            }
        }
    }()
}

// parseWGShowPeers extracts public keys of current peers from `wg show wg0` output
func parseWGShowPeers(output string) map[string]struct{} {
    peers := make(map[string]struct{})
    for _, line := range strings.Split(output, "\n") {
        line = strings.TrimSpace(line)
        if strings.HasPrefix(line, "peer:") {
            key := strings.TrimSpace(strings.TrimPrefix(line, "peer:"))
            if key != "" {
                peers[key] = struct{}{}
            }
        }
    }
    return peers
}

// parseWGShowLocalKey extracts the local public key from `wg show wg0` output
func parseWGShowLocalKey(output string) string {
    for _, line := range strings.Split(output, "\n") {
        line = strings.TrimSpace(line)
        if strings.HasPrefix(line, "public key:") {
            return strings.TrimSpace(strings.TrimPrefix(line, "public key:"))
        }
    }
    return ""
}
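For reference, the two parsers above key off the standard `wg show wg0` layout: an `interface:` block whose `public key:` line is the local key, followed by one `peer:` line per remote. A hypothetical test sketch against representative output (keys shortened for readability; real keys are 44 base64 characters):

```go
package node

import "testing"

func TestParseWGShow(t *testing.T) {
    // Representative `wg show wg0` output.
    out := `interface: wg0
  public key: LOCALKEYAAAA=
  private key: (hidden)
  listening port: 51820

peer: PEERKEYBBBB=
  endpoint: 203.0.113.7:51820
  allowed ips: 10.0.0.2/32
`
    if got := parseWGShowLocalKey(out); got != "LOCALKEYAAAA=" {
        t.Fatalf("local key = %q", got)
    }
    peers := parseWGShowPeers(out)
    if _, ok := peers["PEERKEYBBBB="]; !ok || len(peers) != 1 {
        t.Fatalf("peers = %v", peers)
    }
}
```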
@ -38,6 +38,13 @@ func (r *RQLiteManager) waitForMinClusterSizeBeforeStart(ctx context.Context, rq
    }

    requiredRemotePeers := r.config.MinClusterSize - 1

    // Genesis node (single-node cluster) doesn't need to wait for peers
    if requiredRemotePeers <= 0 {
        r.logger.Info("Genesis node, skipping peer discovery wait")
        return nil
    }

    _ = r.discoveryService.TriggerPeerExchange(ctx)

    checkInterval := 2 * time.Second
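Worked through: with MinClusterSize set to 3, a joining node blocks here until it has discovered 2 remote peers before launching rqlited; a genesis configuration of 1 (or 0) makes requiredRemotePeers non-positive, so the wait is skipped entirely and the node bootstraps alone.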
@ -17,8 +17,55 @@ import (
    "go.uber.org/zap"
)

// killOrphanedRQLite kills any orphaned rqlited process still holding the port.
// This can happen when the parent node process crashes and rqlited keeps running.
func (r *RQLiteManager) killOrphanedRQLite() {
    // Check if port is already in use by querying the status endpoint
    url := fmt.Sprintf("http://localhost:%d/status", r.config.RQLitePort)
    client := &http.Client{Timeout: 2 * time.Second}
    resp, err := client.Get(url)
    if err != nil {
        return // Port not in use, nothing to clean up
    }
    resp.Body.Close()

    // Port is in use — find and kill the orphaned process
    r.logger.Warn("Found orphaned rqlited process on port, killing it",
        zap.Int("port", r.config.RQLitePort))

    // Use fuser to find and kill the process holding the port
    cmd := exec.Command("fuser", "-k", fmt.Sprintf("%d/tcp", r.config.RQLitePort))
    if err := cmd.Run(); err != nil {
        r.logger.Warn("fuser failed, trying lsof", zap.Error(err))
        // Fallback: use lsof
        out, err := exec.Command("lsof", "-ti", fmt.Sprintf(":%d", r.config.RQLitePort)).Output()
        if err == nil {
            for _, pidStr := range strings.Split(strings.TrimSpace(string(out)), "\n") {
                if pidStr != "" {
                    killCmd := exec.Command("kill", "-9", pidStr)
                    killCmd.Run()
                }
            }
        }
    }

    // Wait for port to be released
    for i := 0; i < 10; i++ {
        time.Sleep(500 * time.Millisecond)
        resp, err := client.Get(url)
        if err != nil {
            return // Port released
        }
        resp.Body.Close()
    }
    r.logger.Warn("Could not release port from orphaned process")
}

// launchProcess starts the RQLite process with appropriate arguments
func (r *RQLiteManager) launchProcess(ctx context.Context, rqliteDataDir string) error {
    // Kill any orphaned rqlited from a previous crash
    r.killOrphanedRQLite()

    // Build RQLite command
    args := []string{
        "-http-addr", fmt.Sprintf("0.0.0.0:%d", r.config.RQLitePort),
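fuser and lsof are external binaries that may be absent on a minimal image. As a hedged aside, the "is the port busy" half of the check can be done with the standard library alone; a sketch (not what this commit ships; assumes net, fmt, and time imports):

```go
// portInUse reports whether something is listening on the given local
// TCP port. Purely illustrative; actually killing the holder still
// requires a PID source such as fuser, lsof, or /proc scanning.
func portInUse(port int) bool {
    conn, err := net.DialTimeout("tcp", fmt.Sprintf("localhost:%d", port), 2*time.Second)
    if err != nil {
        return false
    }
    conn.Close()
    return true
}
```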
@ -180,21 +227,30 @@ func (r *RQLiteManager) waitForReady(ctx context.Context) error {

// waitForSQLAvailable waits until a simple query succeeds
func (r *RQLiteManager) waitForSQLAvailable(ctx context.Context) error {
    r.logger.Info("Waiting for SQL to become available...")
    ticker := time.NewTicker(1 * time.Second)
    defer ticker.Stop()

    attempts := 0
    for {
        select {
        case <-ctx.Done():
            r.logger.Error("waitForSQLAvailable timed out", zap.Int("attempts", attempts))
            return ctx.Err()
        case <-ticker.C:
            attempts++
            if r.connection == nil {
                r.logger.Warn("connection is nil in waitForSQLAvailable")
                continue
            }
            _, err := r.connection.QueryOne("SELECT 1")
            if err == nil {
                r.logger.Info("SQL is available", zap.Int("attempts", attempts))
                return nil
            }
            if attempts <= 5 || attempts%10 == 0 {
                r.logger.Debug("SQL not yet available", zap.Int("attempt", attempts), zap.Error(err))
            }
        }
    }
}
@ -74,6 +74,9 @@ on:github.com/coredns/caddy/onevent
sign:sign
view:view

# Response Rate Limiting (DNS amplification protection)
rrl:rrl

# Custom RQLite plugin
rqlite:github.com/DeBrosOfficial/network/pkg/coredns/rqlite
EOF
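In CoreDNS, the order of entries in plugin.cfg fixes the order in which compiled-in plugins run, so registering rrl ahead of rqlite means rate limiting is applied before the database-backed lookup. A minimal Corefile sketch enabling the new directive (any directive arguments are an assumption; the diff only shows the plugin being registered):

```
. {
    rrl
    rqlite
}
```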