From f26676db2c7a0a771e77781fe6f8c199b99993da Mon Sep 17 00:00:00 2001 From: anonpenguin23 Date: Fri, 27 Feb 2026 15:22:51 +0200 Subject: [PATCH] feat: add sandbox command and vault guardian build - integrate Zig-built vault-guardian into cross-compile process - add `orama sandbox` for ephemeral Hetzner Cloud clusters - update docs for `orama node` subcommands and new guides --- README.md | 14 +- cmd/cli/root.go | 4 + docs/COMMON_PROBLEMS.md | 16 +- pkg/cli/build/builder.go | 108 +++- pkg/environments/production/config.go | 24 + pkg/environments/production/orchestrator.go | 28 +- pkg/environments/production/prebuilt.go | 2 + pkg/environments/production/provisioner.go | 2 + pkg/environments/production/services.go | 37 ++ pkg/gateway/gateway.go | 5 + pkg/gateway/handlers/vault/handlers.go | 132 +++++ pkg/gateway/handlers/vault/health_handler.go | 116 +++++ pkg/gateway/handlers/vault/pull_handler.go | 183 +++++++ pkg/gateway/handlers/vault/push_handler.go | 168 +++++++ pkg/gateway/handlers/vault/rate_limiter.go | 120 +++++ pkg/gateway/middleware.go | 5 + pkg/gateway/routes.go | 8 + pkg/shamir/field.go | 82 +++ pkg/shamir/shamir.go | 150 ++++++ pkg/shamir/shamir_test.go | 501 +++++++++++++++++++ 20 files changed, 1669 insertions(+), 36 deletions(-) create mode 100644 pkg/gateway/handlers/vault/handlers.go create mode 100644 pkg/gateway/handlers/vault/health_handler.go create mode 100644 pkg/gateway/handlers/vault/pull_handler.go create mode 100644 pkg/gateway/handlers/vault/push_handler.go create mode 100644 pkg/gateway/handlers/vault/rate_limiter.go create mode 100644 pkg/shamir/field.go create mode 100644 pkg/shamir/shamir.go create mode 100644 pkg/shamir/shamir_test.go diff --git a/README.md b/README.md index 2b416c2..d8119d6 100644 --- a/README.md +++ b/README.md @@ -349,13 +349,13 @@ All configuration lives in `~/.orama/`: ```bash # Check status -systemctl status orama-node +sudo orama node status # View logs -journalctl -u orama-node -f +orama node logs node --follow # Check log files -tail -f /opt/orama/.orama/logs/node.log +sudo orama node doctor ``` ### Port Conflicts @@ -417,9 +417,11 @@ See `openapi/gateway.yaml` for complete API specification. - **[Deployment Guide](docs/DEPLOYMENT_GUIDE.md)** - Deploy React, Next.js, Go apps and manage databases - **[Architecture Guide](docs/ARCHITECTURE.md)** - System architecture and design patterns - **[Client SDK](docs/CLIENT_SDK.md)** - Go SDK documentation and examples -- **[Gateway API](docs/GATEWAY_API.md)** - Complete HTTP API reference -- **[Security Deployment](docs/SECURITY_DEPLOYMENT_GUIDE.md)** - Production security hardening -- **[Testing Plan](docs/TESTING_PLAN.md)** - Comprehensive testing strategy and implementation +- **[Monitoring](docs/MONITORING.md)** - Cluster monitoring and health checks +- **[Inspector](docs/INSPECTOR.md)** - Deep subsystem health inspection +- **[Serverless Functions](docs/SERVERLESS.md)** - WASM serverless with host functions +- **[WebRTC](docs/WEBRTC.md)** - Real-time communication setup +- **[Common Problems](docs/COMMON_PROBLEMS.md)** - Troubleshooting known issues ## Resources diff --git a/cmd/cli/root.go b/cmd/cli/root.go index 266fc9b..0f27fdd 100644 --- a/cmd/cli/root.go +++ b/cmd/cli/root.go @@ -18,6 +18,7 @@ import ( "github.com/DeBrosOfficial/network/pkg/cli/cmd/monitorcmd" "github.com/DeBrosOfficial/network/pkg/cli/cmd/namespacecmd" "github.com/DeBrosOfficial/network/pkg/cli/cmd/node" + "github.com/DeBrosOfficial/network/pkg/cli/cmd/sandboxcmd" ) // version metadata populated via -ldflags at build time @@ -87,6 +88,9 @@ and interacting with the Orama distributed network.`, // Build command (cross-compile binary archive) rootCmd.AddCommand(buildcmd.Cmd) + // Sandbox command (ephemeral Hetzner Cloud clusters) + rootCmd.AddCommand(sandboxcmd.Cmd) + return rootCmd } diff --git a/docs/COMMON_PROBLEMS.md b/docs/COMMON_PROBLEMS.md index f54c938..ae6d9ff 100644 --- a/docs/COMMON_PROBLEMS.md +++ b/docs/COMMON_PROBLEMS.md @@ -32,7 +32,7 @@ wg set wg0 peer remove wg set wg0 peer endpoint :51820 allowed-ips /32 persistent-keepalive 25 ``` -Then restart services: `sudo orama prod restart` +Then restart services: `sudo orama node restart` You can find peer public keys with `wg show wg0`. @@ -46,7 +46,7 @@ cat /opt/orama/.orama/data/namespaces//configs/olric-*.yaml If `bindAddr` is `0.0.0.0`, the node will try to bind to IPv6 on dual-stack hosts, breaking memberlist gossip. -**Fix:** Edit the YAML to use the node's WireGuard IP (run `ip addr show wg0` to find it), then restart: `sudo orama prod restart` +**Fix:** Edit the YAML to use the node's WireGuard IP (run `ip addr show wg0` to find it), then restart: `sudo orama node restart` This was fixed in code (BindAddr validation in `SpawnOlric`), so new namespaces won't have this issue. @@ -82,7 +82,7 @@ olric_servers: - "10.0.0.Z:10002" ``` -Then: `sudo orama prod restart` +Then: `sudo orama node restart` This was fixed in code, so new namespaces get the correct config. @@ -90,7 +90,7 @@ This was fixed in code, so new namespaces get the correct config. ## 3. Namespace not restoring after restart (missing cluster-state.json) -**Symptom:** After `orama prod restart`, the namespace services don't come back because `RestoreLocalClustersFromDisk` has no state file. +**Symptom:** After `orama node restart`, the namespace services don't come back because `RestoreLocalClustersFromDisk` has no state file. **Check:** @@ -117,9 +117,9 @@ This was fixed in code — `ProvisionCluster` now saves state to all nodes (incl ## 4. Namespace gateway processes not restarting after upgrade -**Symptom:** After `orama upgrade --restart` or `orama prod restart`, namespace gateway/olric/rqlite services don't start. +**Symptom:** After `orama upgrade --restart` or `orama node restart`, namespace gateway/olric/rqlite services don't start. -**Cause:** `orama prod stop` disables systemd template services (`orama-namespace-gateway@.service`). They have `PartOf=orama-node.service`, but that only propagates restart to **enabled** services. +**Cause:** `orama node stop` disables systemd template services (`orama-namespace-gateway@.service`). They have `PartOf=orama-node.service`, but that only propagates restart to **enabled** services. **Fix:** Re-enable the services before restarting: @@ -127,7 +127,7 @@ This was fixed in code — `ProvisionCluster` now saves state to all nodes (incl systemctl enable orama-namespace-rqlite@.service systemctl enable orama-namespace-olric@.service systemctl enable orama-namespace-gateway@.service -sudo orama prod restart +sudo orama node restart ``` This was fixed in code — the upgrade orchestrator now re-enables `@` services before restarting. @@ -152,7 +152,7 @@ ssh -n user@host 'command' ## General Debugging Tips -- **Always use `sudo orama prod restart`** instead of raw `systemctl` commands +- **Always use `sudo orama node restart`** instead of raw `systemctl` commands - **Namespace data lives at:** `/opt/orama/.orama/data/namespaces//` - **Check service logs:** `journalctl -u orama-namespace-olric@.service --no-pager -n 50` - **Check WireGuard:** `wg show wg0` — look for recent handshakes and transfer bytes diff --git a/pkg/cli/build/builder.go b/pkg/cli/build/builder.go index de82016..4514f6b 100644 --- a/pkg/cli/build/builder.go +++ b/pkg/cli/build/builder.go @@ -71,48 +71,53 @@ func (b *Builder) Build() error { return fmt.Errorf("failed to build orama binaries: %w", err) } - // Step 2: Cross-compile Olric + // Step 2: Cross-compile Vault Guardian (Zig) + if err := b.buildVaultGuardian(); err != nil { + return fmt.Errorf("failed to build vault-guardian: %w", err) + } + + // Step 3: Cross-compile Olric if err := b.buildOlric(); err != nil { return fmt.Errorf("failed to build olric: %w", err) } - // Step 3: Cross-compile IPFS Cluster + // Step 4: Cross-compile IPFS Cluster if err := b.buildIPFSCluster(); err != nil { return fmt.Errorf("failed to build ipfs-cluster: %w", err) } - // Step 4: Build CoreDNS with RQLite plugin + // Step 5: Build CoreDNS with RQLite plugin if err := b.buildCoreDNS(); err != nil { return fmt.Errorf("failed to build coredns: %w", err) } - // Step 5: Build Caddy with Orama DNS module + // Step 6: Build Caddy with Orama DNS module if err := b.buildCaddy(); err != nil { return fmt.Errorf("failed to build caddy: %w", err) } - // Step 6: Download pre-built IPFS Kubo + // Step 7: Download pre-built IPFS Kubo if err := b.downloadIPFS(); err != nil { return fmt.Errorf("failed to download ipfs: %w", err) } - // Step 7: Download pre-built RQLite + // Step 8: Download pre-built RQLite if err := b.downloadRQLite(); err != nil { return fmt.Errorf("failed to download rqlite: %w", err) } - // Step 8: Copy systemd templates + // Step 9: Copy systemd templates if err := b.copySystemdTemplates(); err != nil { return fmt.Errorf("failed to copy systemd templates: %w", err) } - // Step 9: Generate manifest + // Step 10: Generate manifest manifest, err := b.generateManifest() if err != nil { return fmt.Errorf("failed to generate manifest: %w", err) } - // Step 10: Create archive + // Step 11: Create archive outputPath := b.flags.Output if outputPath == "" { outputPath = fmt.Sprintf("/tmp/orama-%s-linux-%s.tar.gz", b.version, b.flags.Arch) @@ -130,7 +135,7 @@ func (b *Builder) Build() error { } func (b *Builder) buildOramaBinaries() error { - fmt.Println("[1/7] Cross-compiling Orama binaries...") + fmt.Println("[1/8] Cross-compiling Orama binaries...") ldflags := fmt.Sprintf("-s -w -X 'main.version=%s' -X 'main.commit=%s' -X 'main.date=%s'", b.version, b.commit, b.date) @@ -177,8 +182,79 @@ func (b *Builder) buildOramaBinaries() error { return nil } +func (b *Builder) buildVaultGuardian() error { + fmt.Println("[2/8] Cross-compiling Vault Guardian (Zig)...") + + // Ensure zig is available + if _, err := exec.LookPath("zig"); err != nil { + return fmt.Errorf("zig not found in PATH — install from https://ziglang.org/download/") + } + + // Vault source is sibling to orama project + vaultDir := filepath.Join(b.projectDir, "..", "orama-vault") + if _, err := os.Stat(filepath.Join(vaultDir, "build.zig")); err != nil { + return fmt.Errorf("vault source not found at %s — expected orama-vault as sibling directory: %w", vaultDir, err) + } + + // Map Go arch to Zig target triple + var zigTarget string + switch b.flags.Arch { + case "amd64": + zigTarget = "x86_64-linux-musl" + case "arm64": + zigTarget = "aarch64-linux-musl" + default: + return fmt.Errorf("unsupported architecture for vault: %s", b.flags.Arch) + } + + if b.flags.Verbose { + fmt.Printf(" zig build -Dtarget=%s -Doptimize=ReleaseSafe\n", zigTarget) + } + + cmd := exec.Command("zig", "build", + fmt.Sprintf("-Dtarget=%s", zigTarget), + "-Doptimize=ReleaseSafe") + cmd.Dir = vaultDir + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + if err := cmd.Run(); err != nil { + return fmt.Errorf("zig build failed: %w", err) + } + + // Copy output binary to build bin dir + src := filepath.Join(vaultDir, "zig-out", "bin", "vault-guardian") + dst := filepath.Join(b.binDir, "vault-guardian") + if err := copyFile(src, dst); err != nil { + return fmt.Errorf("failed to copy vault-guardian binary: %w", err) + } + + fmt.Println(" ✓ vault-guardian") + return nil +} + +// copyFile copies a file from src to dst, preserving executable permissions. +func copyFile(src, dst string) error { + srcFile, err := os.Open(src) + if err != nil { + return err + } + defer srcFile.Close() + + dstFile, err := os.OpenFile(dst, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0755) + if err != nil { + return err + } + defer dstFile.Close() + + if _, err := srcFile.WriteTo(dstFile); err != nil { + return err + } + return nil +} + func (b *Builder) buildOlric() error { - fmt.Printf("[2/7] Cross-compiling Olric %s...\n", constants.OlricVersion) + fmt.Printf("[3/8] Cross-compiling Olric %s...\n", constants.OlricVersion) cmd := exec.Command("go", "install", fmt.Sprintf("github.com/olric-data/olric/cmd/olric-server@%s", constants.OlricVersion)) @@ -197,7 +273,7 @@ func (b *Builder) buildOlric() error { } func (b *Builder) buildIPFSCluster() error { - fmt.Printf("[3/7] Cross-compiling IPFS Cluster %s...\n", constants.IPFSClusterVersion) + fmt.Printf("[4/8] Cross-compiling IPFS Cluster %s...\n", constants.IPFSClusterVersion) cmd := exec.Command("go", "install", fmt.Sprintf("github.com/ipfs-cluster/ipfs-cluster/cmd/ipfs-cluster-service@%s", constants.IPFSClusterVersion)) @@ -216,7 +292,7 @@ func (b *Builder) buildIPFSCluster() error { } func (b *Builder) buildCoreDNS() error { - fmt.Printf("[4/7] Building CoreDNS %s with RQLite plugin...\n", constants.CoreDNSVersion) + fmt.Printf("[5/8] Building CoreDNS %s with RQLite plugin...\n", constants.CoreDNSVersion) buildDir := filepath.Join(b.tmpDir, "coredns-build") @@ -363,7 +439,7 @@ rqlite:rqlite } func (b *Builder) buildCaddy() error { - fmt.Printf("[5/7] Building Caddy %s with Orama DNS module...\n", constants.CaddyVersion) + fmt.Printf("[6/8] Building Caddy %s with Orama DNS module...\n", constants.CaddyVersion) // Ensure xcaddy is available if _, err := exec.LookPath("xcaddy"); err != nil { @@ -429,7 +505,7 @@ require ( } func (b *Builder) downloadIPFS() error { - fmt.Printf("[6/7] Downloading IPFS Kubo %s...\n", constants.IPFSKuboVersion) + fmt.Printf("[7/8] Downloading IPFS Kubo %s...\n", constants.IPFSKuboVersion) arch := b.flags.Arch tarball := fmt.Sprintf("kubo_%s_linux-%s.tar.gz", constants.IPFSKuboVersion, arch) @@ -450,7 +526,7 @@ func (b *Builder) downloadIPFS() error { } func (b *Builder) downloadRQLite() error { - fmt.Printf("[7/7] Downloading RQLite %s...\n", constants.RQLiteVersion) + fmt.Printf("[8/8] Downloading RQLite %s...\n", constants.RQLiteVersion) arch := b.flags.Arch tarball := fmt.Sprintf("rqlite-v%s-linux-%s.tar.gz", constants.RQLiteVersion, arch) diff --git a/pkg/environments/production/config.go b/pkg/environments/production/config.go index 3bf6cd1..8de319a 100644 --- a/pkg/environments/production/config.go +++ b/pkg/environments/production/config.go @@ -194,6 +194,30 @@ func (cg *ConfigGenerator) GenerateNodeConfig(peerAddresses []string, vpsIP stri return templates.RenderNodeConfig(data) } +// GenerateVaultConfig generates vault.yaml configuration for the Vault Guardian. +// The vault config uses key=value format (not YAML, despite the file extension). +// Peer discovery is dynamic via RQLite — no static peer list needed. +func (cg *ConfigGenerator) GenerateVaultConfig(vpsIP string) string { + dataDir := filepath.Join(cg.oramaDir, "data", "vault") + + // Bind to WireGuard IP so vault is only accessible over the overlay network. + // If no WG IP is provided, bind to localhost as a safe default. + bindAddr := "127.0.0.1" + if vpsIP != "" { + bindAddr = vpsIP + } + + return fmt.Sprintf(`# Vault Guardian Configuration +# Generated by orama node install + +listen_address = %s +client_port = 7500 +peer_port = 7501 +data_dir = %s +rqlite_url = http://127.0.0.1:5001 +`, bindAddr, dataDir) +} + // GenerateGatewayConfig generates gateway.yaml configuration func (cg *ConfigGenerator) GenerateGatewayConfig(peerAddresses []string, enableHTTPS bool, domain string, olricServers []string) (string, error) { tlsCacheDir := "" diff --git a/pkg/environments/production/orchestrator.go b/pkg/environments/production/orchestrator.go index 7e5d371..65dd7c8 100644 --- a/pkg/environments/production/orchestrator.go +++ b/pkg/environments/production/orchestrator.go @@ -573,6 +573,14 @@ func (ps *ProductionSetup) Phase4GenerateConfigs(peerAddresses []string, vpsIP s } ps.logf(" ✓ Olric config generated") + // Vault Guardian config + vaultConfig := ps.configGenerator.GenerateVaultConfig(vpsIP) + vaultConfigPath := filepath.Join(ps.oramaDir, "data", "vault", "vault.yaml") + if err := os.WriteFile(vaultConfigPath, []byte(vaultConfig), 0644); err != nil { + return fmt.Errorf("failed to save vault config: %w", err) + } + ps.logf(" ✓ Vault config generated") + // Configure CoreDNS (if baseDomain is provided - this is the zone name) // CoreDNS uses baseDomain (e.g., "dbrs.space") as the authoritative zone dnsZone := baseDomain @@ -667,6 +675,13 @@ func (ps *ProductionSetup) Phase5CreateSystemdServices(enableHTTPS bool) error { } ps.logf(" ✓ Node service created: orama-node.service (with embedded gateway)") + // Vault Guardian service + vaultUnit := ps.serviceGenerator.GenerateVaultService() + if err := ps.serviceController.WriteServiceUnit("orama-vault.service", vaultUnit); err != nil { + return fmt.Errorf("failed to write Vault service: %w", err) + } + ps.logf(" ✓ Vault service created: orama-vault.service") + // Anyone Relay service (only created when --anyone-relay flag is used) // A node must run EITHER relay OR client, never both. When writing one // mode's service, we remove the other to prevent conflicts (they share @@ -725,7 +740,7 @@ func (ps *ProductionSetup) Phase5CreateSystemdServices(enableHTTPS bool) error { // Enable services (unified names - no bootstrap/node distinction) // Note: orama-gateway.service is no longer needed - each node has an embedded gateway // Note: orama-rqlite.service is NOT created - RQLite is managed by each node internally - services := []string{"orama-ipfs.service", "orama-ipfs-cluster.service", "orama-olric.service", "orama-node.service"} + services := []string{"orama-ipfs.service", "orama-ipfs-cluster.service", "orama-olric.service", "orama-vault.service", "orama-node.service"} // Add Anyone service if configured (relay or client) if ps.IsAnyoneRelay() { @@ -756,8 +771,8 @@ func (ps *ProductionSetup) Phase5CreateSystemdServices(enableHTTPS bool) error { // services pick up new configs even if already running from a previous install) ps.logf(" Starting services...") - // Start infrastructure first (IPFS, Olric, Anyone) - RQLite is managed internally by each node - infraServices := []string{"orama-ipfs.service", "orama-olric.service"} + // Start infrastructure first (IPFS, Olric, Vault, Anyone) - RQLite is managed internally by each node + infraServices := []string{"orama-ipfs.service", "orama-olric.service", "orama-vault.service"} // Add Anyone service if configured (relay or client) if ps.IsAnyoneRelay() { @@ -977,12 +992,13 @@ func (ps *ProductionSetup) LogSetupComplete(peerID string) { ps.logf(" %s/logs/olric.log", ps.oramaDir) ps.logf(" %s/logs/node.log", ps.oramaDir) ps.logf(" %s/logs/gateway.log", ps.oramaDir) + ps.logf(" %s/logs/vault.log", ps.oramaDir) // Anyone mode-specific logs and commands if ps.IsAnyoneRelay() { ps.logf(" /var/log/anon/notices.log (Anyone Relay)") ps.logf("\nStart All Services:") - ps.logf(" systemctl start orama-ipfs orama-ipfs-cluster orama-olric orama-anyone-relay orama-node") + ps.logf(" systemctl start orama-ipfs orama-ipfs-cluster orama-olric orama-vault orama-anyone-relay orama-node") ps.logf("\nAnyone Relay Operator:") ps.logf(" ORPort: %d", ps.anyoneRelayConfig.ORPort) ps.logf(" Wallet: %s", ps.anyoneRelayConfig.Wallet) @@ -991,10 +1007,10 @@ func (ps *ProductionSetup) LogSetupComplete(peerID string) { ps.logf(" IMPORTANT: You need 100 $ANYONE tokens in your wallet to receive rewards") } else if ps.IsAnyoneClient() { ps.logf("\nStart All Services:") - ps.logf(" systemctl start orama-ipfs orama-ipfs-cluster orama-olric orama-anyone-client orama-node") + ps.logf(" systemctl start orama-ipfs orama-ipfs-cluster orama-olric orama-vault orama-anyone-client orama-node") } else { ps.logf("\nStart All Services:") - ps.logf(" systemctl start orama-ipfs orama-ipfs-cluster orama-olric orama-node") + ps.logf(" systemctl start orama-ipfs orama-ipfs-cluster orama-olric orama-vault orama-node") } ps.logf("\nVerify Installation:") diff --git a/pkg/environments/production/prebuilt.go b/pkg/environments/production/prebuilt.go index 689b8ba..04d4233 100644 --- a/pkg/environments/production/prebuilt.go +++ b/pkg/environments/production/prebuilt.go @@ -127,6 +127,8 @@ func (ps *ProductionSetup) deployPreBuiltBinaries(manifest *PreBuiltManifest) er {name: "coredns", dest: "/usr/local/bin/coredns"}, {name: "caddy", dest: "/usr/bin/caddy"}, } + // Note: vault-guardian stays at /opt/orama/bin/ (from archive extraction) + // and is referenced by absolute path in the systemd service — no copy needed. for _, bin := range binaries { srcPath := filepath.Join(OramaArchiveBin, bin.name) diff --git a/pkg/environments/production/provisioner.go b/pkg/environments/production/provisioner.go index 259e213..97e3089 100644 --- a/pkg/environments/production/provisioner.go +++ b/pkg/environments/production/provisioner.go @@ -34,6 +34,7 @@ func (fp *FilesystemProvisioner) EnsureDirectoryStructure() error { filepath.Join(fp.oramaDir, "data", "ipfs", "repo"), filepath.Join(fp.oramaDir, "data", "ipfs-cluster"), filepath.Join(fp.oramaDir, "data", "rqlite"), + filepath.Join(fp.oramaDir, "data", "vault"), filepath.Join(fp.oramaDir, "logs"), filepath.Join(fp.oramaDir, "tls-cache"), filepath.Join(fp.oramaDir, "backups"), @@ -65,6 +66,7 @@ func (fp *FilesystemProvisioner) EnsureDirectoryStructure() error { "ipfs.log", "ipfs-cluster.log", "node.log", + "vault.log", "anyone-client.log", } diff --git a/pkg/environments/production/services.go b/pkg/environments/production/services.go index 0070cb9..2e47f66 100644 --- a/pkg/environments/production/services.go +++ b/pkg/environments/production/services.go @@ -214,6 +214,43 @@ WantedBy=multi-user.target `, ssg.oramaHome, ssg.oramaDir, configFile, logFile) } +// GenerateVaultService generates the Orama Vault Guardian systemd unit. +// The vault guardian runs on every node, storing Shamir secret shares. +// It binds to the WireGuard overlay only (no public exposure). +func (ssg *SystemdServiceGenerator) GenerateVaultService() string { + logFile := filepath.Join(ssg.oramaDir, "logs", "vault.log") + dataDir := filepath.Join(ssg.oramaDir, "data", "vault") + + return fmt.Sprintf(`[Unit] +Description=Orama Vault Guardian +After=network-online.target wg-quick@wg0.service +Wants=network-online.target +Requires=wg-quick@wg0.service +PartOf=orama-node.service + +[Service] +Type=simple +ExecStart=%[1]s/bin/vault-guardian --config %[2]s/vault.yaml +Restart=on-failure +RestartSec=5 +StandardOutput=append:%[3]s +StandardError=append:%[3]s +SyslogIdentifier=orama-vault + +PrivateTmp=yes +ProtectSystem=strict +ReadWritePaths=%[2]s +NoNewPrivileges=yes +LimitMEMLOCK=67108864 +MemoryMax=512M +TimeoutStopSec=30 +KillMode=mixed + +[Install] +WantedBy=multi-user.target +`, ssg.oramaHome, dataDir, logFile) +} + // GenerateGatewayService generates the Orama Gateway systemd unit func (ssg *SystemdServiceGenerator) GenerateGatewayService() string { logFile := filepath.Join(ssg.oramaDir, "logs", "gateway.log") diff --git a/pkg/gateway/gateway.go b/pkg/gateway/gateway.go index c597343..d7088c3 100644 --- a/pkg/gateway/gateway.go +++ b/pkg/gateway/gateway.go @@ -31,6 +31,7 @@ import ( serverlesshandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/serverless" joinhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/join" webrtchandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/webrtc" + vaulthandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/vault" wireguardhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/wireguard" sqlitehandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/sqlite" "github.com/DeBrosOfficial/network/pkg/gateway/handlers/storage" @@ -162,6 +163,9 @@ type Gateway struct { // Shared HTTP transport for proxy connections (connection pooling) proxyTransport *http.Transport + // Vault proxy handlers + vaultHandlers *vaulthandlers.Handlers + // Namespace health state (local service probes + hourly reconciliation) nsHealth *namespaceHealthState } @@ -395,6 +399,7 @@ func New(logger *logging.ColoredLogger, cfg *Config) (*Gateway, error) { if deps.ORMClient != nil { gw.wireguardHandler = wireguardhandlers.NewHandler(logger.Logger, deps.ORMClient, cfg.ClusterSecret) gw.joinHandler = joinhandlers.NewHandler(logger.Logger, deps.ORMClient, cfg.DataDir) + gw.vaultHandlers = vaulthandlers.NewHandlers(logger, deps.Client) } // Initialize deployment system diff --git a/pkg/gateway/handlers/vault/handlers.go b/pkg/gateway/handlers/vault/handlers.go new file mode 100644 index 0000000..ec80dcb --- /dev/null +++ b/pkg/gateway/handlers/vault/handlers.go @@ -0,0 +1,132 @@ +// Package vault provides HTTP handlers for vault proxy operations. +// +// The gateway acts as a smart proxy between RootWallet clients and +// vault guardian nodes on the WireGuard overlay network. It handles +// Shamir split/combine so clients make a single HTTPS call. +package vault + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/DeBrosOfficial/network/pkg/client" + "github.com/DeBrosOfficial/network/pkg/logging" +) + +const ( + // VaultGuardianPort is the port vault guardians listen on (client API). + VaultGuardianPort = 7500 + + // guardianTimeout is the per-guardian HTTP request timeout. + guardianTimeout = 5 * time.Second + + // overallTimeout is the maximum time for the full fan-out operation. + overallTimeout = 15 * time.Second + + // maxPushBodySize limits push request bodies (1 MiB). + maxPushBodySize = 1 << 20 + + // maxPullBodySize limits pull request bodies (4 KiB). + maxPullBodySize = 4 << 10 +) + +// Handlers provides HTTP handlers for vault proxy operations. +type Handlers struct { + logger *logging.ColoredLogger + dbClient client.NetworkClient + rateLimiter *IdentityRateLimiter + httpClient *http.Client +} + +// NewHandlers creates vault proxy handlers. +func NewHandlers(logger *logging.ColoredLogger, dbClient client.NetworkClient) *Handlers { + h := &Handlers{ + logger: logger, + dbClient: dbClient, + rateLimiter: NewIdentityRateLimiter( + 30, // 30 pushes per hour per identity + 120, // 120 pulls per hour per identity + ), + httpClient: &http.Client{ + Timeout: guardianTimeout, + Transport: &http.Transport{ + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + }, + }, + } + h.rateLimiter.StartCleanup(10*time.Minute, 1*time.Hour) + return h +} + +// guardian represents a reachable vault guardian node. +type guardian struct { + IP string + Port int +} + +// discoverGuardians queries dns_nodes for all active nodes. +// Every Orama node runs a vault guardian, so every active node is a guardian. +func (h *Handlers) discoverGuardians(ctx context.Context) ([]guardian, error) { + db := h.dbClient.Database() + internalCtx := client.WithInternalAuth(ctx) + + query := "SELECT COALESCE(internal_ip, ip_address) FROM dns_nodes WHERE status = 'active'" + result, err := db.Query(internalCtx, query) + if err != nil { + return nil, fmt.Errorf("vault: failed to query guardian nodes: %w", err) + } + if result == nil || len(result.Rows) == 0 { + return nil, fmt.Errorf("vault: no active guardian nodes found") + } + + guardians := make([]guardian, 0, len(result.Rows)) + for _, row := range result.Rows { + if len(row) == 0 { + continue + } + ip := getString(row[0]) + if ip == "" { + continue + } + guardians = append(guardians, guardian{IP: ip, Port: VaultGuardianPort}) + } + if len(guardians) == 0 { + return nil, fmt.Errorf("vault: no guardian nodes with valid IPs found") + } + return guardians, nil +} + +func writeJSON(w http.ResponseWriter, status int, v interface{}) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + json.NewEncoder(w).Encode(v) +} + +func writeError(w http.ResponseWriter, status int, msg string) { + writeJSON(w, status, map[string]string{"error": msg}) +} + +func getString(v interface{}) string { + if s, ok := v.(string); ok { + return s + } + return "" +} + +// isValidIdentity checks that identity is exactly 64 hex characters. +func isValidIdentity(identity string) bool { + if len(identity) != 64 { + return false + } + for _, c := range identity { + if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F')) { + return false + } + } + return true +} diff --git a/pkg/gateway/handlers/vault/health_handler.go b/pkg/gateway/handlers/vault/health_handler.go new file mode 100644 index 0000000..e5dd702 --- /dev/null +++ b/pkg/gateway/handlers/vault/health_handler.go @@ -0,0 +1,116 @@ +package vault + +import ( + "context" + "fmt" + "io" + "net/http" + "sync" + "sync/atomic" + + "github.com/DeBrosOfficial/network/pkg/shamir" +) + +// HealthResponse is returned for GET /v1/vault/health. +type HealthResponse struct { + Status string `json:"status"` // "healthy", "degraded", "unavailable" +} + +// StatusResponse is returned for GET /v1/vault/status. +type StatusResponse struct { + Guardians int `json:"guardians"` // Total guardian nodes + Healthy int `json:"healthy"` // Reachable guardians + Threshold int `json:"threshold"` // Read quorum (K) + WriteQuorum int `json:"write_quorum"` // Write quorum (W) +} + +// HandleHealth processes GET /v1/vault/health. +func (h *Handlers) HandleHealth(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + writeError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + + guardians, err := h.discoverGuardians(r.Context()) + if err != nil { + writeJSON(w, http.StatusOK, HealthResponse{Status: "unavailable"}) + return + } + + n := len(guardians) + healthy := h.probeGuardians(r.Context(), guardians) + + k := shamir.AdaptiveThreshold(n) + wq := shamir.WriteQuorum(n) + + status := "healthy" + if healthy < wq { + if healthy >= k { + status = "degraded" + } else { + status = "unavailable" + } + } + + writeJSON(w, http.StatusOK, HealthResponse{Status: status}) +} + +// HandleStatus processes GET /v1/vault/status. +func (h *Handlers) HandleStatus(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + writeError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + + guardians, err := h.discoverGuardians(r.Context()) + if err != nil { + writeJSON(w, http.StatusOK, StatusResponse{}) + return + } + + n := len(guardians) + healthy := h.probeGuardians(r.Context(), guardians) + + writeJSON(w, http.StatusOK, StatusResponse{ + Guardians: n, + Healthy: healthy, + Threshold: shamir.AdaptiveThreshold(n), + WriteQuorum: shamir.WriteQuorum(n), + }) +} + +// probeGuardians checks health of all guardians in parallel and returns the healthy count. +func (h *Handlers) probeGuardians(ctx context.Context, guardians []guardian) int { + ctx, cancel := context.WithTimeout(ctx, guardianTimeout) + defer cancel() + + var healthyCount atomic.Int32 + var wg sync.WaitGroup + wg.Add(len(guardians)) + + for _, g := range guardians { + go func(gd guardian) { + defer wg.Done() + + url := fmt.Sprintf("http://%s:%d/v1/vault/health", gd.IP, gd.Port) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return + } + + resp, err := h.httpClient.Do(req) + if err != nil { + return + } + defer resp.Body.Close() + io.Copy(io.Discard, resp.Body) + + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + healthyCount.Add(1) + } + }(g) + } + + wg.Wait() + return int(healthyCount.Load()) +} diff --git a/pkg/gateway/handlers/vault/pull_handler.go b/pkg/gateway/handlers/vault/pull_handler.go new file mode 100644 index 0000000..2164487 --- /dev/null +++ b/pkg/gateway/handlers/vault/pull_handler.go @@ -0,0 +1,183 @@ +package vault + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "sync" + + "github.com/DeBrosOfficial/network/pkg/logging" + "github.com/DeBrosOfficial/network/pkg/shamir" + "go.uber.org/zap" +) + +// PullRequest is the client-facing request body. +type PullRequest struct { + Identity string `json:"identity"` // 64 hex chars +} + +// PullResponse is returned to the client. +type PullResponse struct { + Envelope string `json:"envelope"` // base64-encoded reconstructed envelope + Collected int `json:"collected"` // Number of shares collected + Threshold int `json:"threshold"` // K threshold used +} + +// guardianPullRequest is sent to each vault guardian. +type guardianPullRequest struct { + Identity string `json:"identity"` +} + +// guardianPullResponse is the response from a guardian. +type guardianPullResponse struct { + Share string `json:"share"` // base64([x:1byte][y:rest]) +} + +// HandlePull processes POST /v1/vault/pull. +func (h *Handlers) HandlePull(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + writeError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + + body, err := io.ReadAll(io.LimitReader(r.Body, maxPullBodySize)) + if err != nil { + writeError(w, http.StatusBadRequest, "failed to read request body") + return + } + + var req PullRequest + if err := json.Unmarshal(body, &req); err != nil { + writeError(w, http.StatusBadRequest, "invalid JSON") + return + } + + if !isValidIdentity(req.Identity) { + writeError(w, http.StatusBadRequest, "identity must be 64 hex characters") + return + } + + if !h.rateLimiter.AllowPull(req.Identity) { + w.Header().Set("Retry-After", "30") + writeError(w, http.StatusTooManyRequests, "pull rate limit exceeded for this identity") + return + } + + guardians, err := h.discoverGuardians(r.Context()) + if err != nil { + h.logger.ComponentError(logging.ComponentGeneral, "Vault pull: guardian discovery failed", zap.Error(err)) + writeError(w, http.StatusServiceUnavailable, "no guardian nodes available") + return + } + + n := len(guardians) + k := shamir.AdaptiveThreshold(n) + + // Fan out pull requests to all guardians. + ctx, cancel := context.WithTimeout(r.Context(), overallTimeout) + defer cancel() + + type shareResult struct { + share shamir.Share + ok bool + } + + results := make([]shareResult, n) + var wg sync.WaitGroup + wg.Add(n) + + for i, g := range guardians { + go func(idx int, gd guardian) { + defer wg.Done() + + guardianReq := guardianPullRequest{Identity: req.Identity} + reqBody, _ := json.Marshal(guardianReq) + + url := fmt.Sprintf("http://%s:%d/v1/vault/pull", gd.IP, gd.Port) + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(reqBody)) + if err != nil { + return + } + httpReq.Header.Set("Content-Type", "application/json") + + resp, err := h.httpClient.Do(httpReq) + if err != nil { + return + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + io.Copy(io.Discard, resp.Body) + return + } + + var pullResp guardianPullResponse + if err := json.NewDecoder(resp.Body).Decode(&pullResp); err != nil { + return + } + + shareBytes, err := base64.StdEncoding.DecodeString(pullResp.Share) + if err != nil || len(shareBytes) < 2 { + return + } + + results[idx] = shareResult{ + share: shamir.Share{ + X: shareBytes[0], + Y: shareBytes[1:], + }, + ok: true, + } + }(i, g) + } + + wg.Wait() + + // Collect successful shares. + shares := make([]shamir.Share, 0, n) + for _, r := range results { + if r.ok { + shares = append(shares, r.share) + } + } + + if len(shares) < k { + h.logger.ComponentError(logging.ComponentGeneral, "Vault pull: not enough shares", + zap.Int("collected", len(shares)), zap.Int("total", n), zap.Int("threshold", k)) + writeError(w, http.StatusServiceUnavailable, + fmt.Sprintf("not enough shares: collected %d of %d required (contacted %d guardians)", len(shares), k, n)) + return + } + + // Shamir combine to reconstruct envelope. + envelope, err := shamir.Combine(shares[:k]) + if err != nil { + h.logger.ComponentError(logging.ComponentGeneral, "Vault pull: Shamir combine failed", zap.Error(err)) + writeError(w, http.StatusInternalServerError, "failed to reconstruct envelope") + return + } + + // Wipe collected shares. + for i := range shares { + for j := range shares[i].Y { + shares[i].Y[j] = 0 + } + } + + envelopeB64 := base64.StdEncoding.EncodeToString(envelope) + + // Wipe envelope. + for i := range envelope { + envelope[i] = 0 + } + + writeJSON(w, http.StatusOK, PullResponse{ + Envelope: envelopeB64, + Collected: len(shares), + Threshold: k, + }) +} diff --git a/pkg/gateway/handlers/vault/push_handler.go b/pkg/gateway/handlers/vault/push_handler.go new file mode 100644 index 0000000..b3e729d --- /dev/null +++ b/pkg/gateway/handlers/vault/push_handler.go @@ -0,0 +1,168 @@ +package vault + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "sync" + "sync/atomic" + + "github.com/DeBrosOfficial/network/pkg/logging" + "github.com/DeBrosOfficial/network/pkg/shamir" + "go.uber.org/zap" +) + +// PushRequest is the client-facing request body. +type PushRequest struct { + Identity string `json:"identity"` // 64 hex chars (SHA-256) + Envelope string `json:"envelope"` // base64-encoded encrypted envelope + Version uint64 `json:"version"` // Anti-rollback version counter +} + +// PushResponse is returned to the client. +type PushResponse struct { + Status string `json:"status"` // "ok" or "partial" + AckCount int `json:"ack_count"` + Total int `json:"total"` + Quorum int `json:"quorum"` + Threshold int `json:"threshold"` +} + +// guardianPushRequest is sent to each vault guardian. +type guardianPushRequest struct { + Identity string `json:"identity"` + Share string `json:"share"` // base64([x:1byte][y:rest]) + Version uint64 `json:"version"` +} + +// HandlePush processes POST /v1/vault/push. +func (h *Handlers) HandlePush(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + writeError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + + body, err := io.ReadAll(io.LimitReader(r.Body, maxPushBodySize)) + if err != nil { + writeError(w, http.StatusBadRequest, "failed to read request body") + return + } + + var req PushRequest + if err := json.Unmarshal(body, &req); err != nil { + writeError(w, http.StatusBadRequest, "invalid JSON") + return + } + + if !isValidIdentity(req.Identity) { + writeError(w, http.StatusBadRequest, "identity must be 64 hex characters") + return + } + + envelopeBytes, err := base64.StdEncoding.DecodeString(req.Envelope) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid base64 envelope") + return + } + if len(envelopeBytes) == 0 { + writeError(w, http.StatusBadRequest, "envelope must not be empty") + return + } + + if !h.rateLimiter.AllowPush(req.Identity) { + w.Header().Set("Retry-After", "120") + writeError(w, http.StatusTooManyRequests, "push rate limit exceeded for this identity") + return + } + + guardians, err := h.discoverGuardians(r.Context()) + if err != nil { + h.logger.ComponentError(logging.ComponentGeneral, "Vault push: guardian discovery failed", zap.Error(err)) + writeError(w, http.StatusServiceUnavailable, "no guardian nodes available") + return + } + + n := len(guardians) + k := shamir.AdaptiveThreshold(n) + quorum := shamir.WriteQuorum(n) + + shares, err := shamir.Split(envelopeBytes, n, k) + if err != nil { + h.logger.ComponentError(logging.ComponentGeneral, "Vault push: Shamir split failed", zap.Error(err)) + writeError(w, http.StatusInternalServerError, "failed to split envelope") + return + } + + // Fan out to guardians in parallel. + ctx, cancel := context.WithTimeout(r.Context(), overallTimeout) + defer cancel() + + var ackCount atomic.Int32 + var wg sync.WaitGroup + wg.Add(n) + + for i, g := range guardians { + go func(idx int, gd guardian) { + defer wg.Done() + + share := shares[idx] + // Serialize: [x:1byte][y:rest] + shareBytes := make([]byte, 1+len(share.Y)) + shareBytes[0] = share.X + copy(shareBytes[1:], share.Y) + shareB64 := base64.StdEncoding.EncodeToString(shareBytes) + + guardianReq := guardianPushRequest{ + Identity: req.Identity, + Share: shareB64, + Version: req.Version, + } + reqBody, _ := json.Marshal(guardianReq) + + url := fmt.Sprintf("http://%s:%d/v1/vault/push", gd.IP, gd.Port) + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(reqBody)) + if err != nil { + return + } + httpReq.Header.Set("Content-Type", "application/json") + + resp, err := h.httpClient.Do(httpReq) + if err != nil { + return + } + defer resp.Body.Close() + io.Copy(io.Discard, resp.Body) + + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + ackCount.Add(1) + } + }(i, g) + } + + wg.Wait() + + // Wipe share data. + for i := range shares { + for j := range shares[i].Y { + shares[i].Y[j] = 0 + } + } + + ack := int(ackCount.Load()) + status := "ok" + if ack < quorum { + status = "partial" + } + + writeJSON(w, http.StatusOK, PushResponse{ + Status: status, + AckCount: ack, + Total: n, + Quorum: quorum, + Threshold: k, + }) +} diff --git a/pkg/gateway/handlers/vault/rate_limiter.go b/pkg/gateway/handlers/vault/rate_limiter.go new file mode 100644 index 0000000..9a69821 --- /dev/null +++ b/pkg/gateway/handlers/vault/rate_limiter.go @@ -0,0 +1,120 @@ +package vault + +import ( + "sync" + "time" +) + +// IdentityRateLimiter provides per-identity-hash rate limiting for vault operations. +// Push and pull have separate rate limits since push is more expensive. +type IdentityRateLimiter struct { + pushBuckets sync.Map // identity -> *tokenBucket + pullBuckets sync.Map // identity -> *tokenBucket + pushRate float64 // tokens per second + pushBurst int + pullRate float64 // tokens per second + pullBurst int + stopCh chan struct{} +} + +type tokenBucket struct { + mu sync.Mutex + tokens float64 + lastCheck time.Time +} + +// NewIdentityRateLimiter creates a per-identity rate limiter. +// pushPerHour and pullPerHour are sustained rates; burst is 1/6th of the hourly rate. +func NewIdentityRateLimiter(pushPerHour, pullPerHour int) *IdentityRateLimiter { + pushBurst := pushPerHour / 6 + if pushBurst < 1 { + pushBurst = 1 + } + pullBurst := pullPerHour / 6 + if pullBurst < 1 { + pullBurst = 1 + } + return &IdentityRateLimiter{ + pushRate: float64(pushPerHour) / 3600.0, + pushBurst: pushBurst, + pullRate: float64(pullPerHour) / 3600.0, + pullBurst: pullBurst, + } +} + +// AllowPush checks if a push for this identity is allowed. +func (rl *IdentityRateLimiter) AllowPush(identity string) bool { + return rl.allow(&rl.pushBuckets, identity, rl.pushRate, rl.pushBurst) +} + +// AllowPull checks if a pull for this identity is allowed. +func (rl *IdentityRateLimiter) AllowPull(identity string) bool { + return rl.allow(&rl.pullBuckets, identity, rl.pullRate, rl.pullBurst) +} + +func (rl *IdentityRateLimiter) allow(buckets *sync.Map, identity string, rate float64, burst int) bool { + val, _ := buckets.LoadOrStore(identity, &tokenBucket{ + tokens: float64(burst), + lastCheck: time.Now(), + }) + b := val.(*tokenBucket) + + b.mu.Lock() + defer b.mu.Unlock() + + now := time.Now() + elapsed := now.Sub(b.lastCheck).Seconds() + b.tokens += elapsed * rate + if b.tokens > float64(burst) { + b.tokens = float64(burst) + } + b.lastCheck = now + + if b.tokens >= 1 { + b.tokens-- + return true + } + return false +} + +// StartCleanup runs periodic cleanup of stale identity entries. +func (rl *IdentityRateLimiter) StartCleanup(interval, maxAge time.Duration) { + rl.stopCh = make(chan struct{}) + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + rl.cleanup(maxAge) + case <-rl.stopCh: + return + } + } + }() +} + +// Stop terminates the background cleanup goroutine. +func (rl *IdentityRateLimiter) Stop() { + if rl.stopCh != nil { + close(rl.stopCh) + } +} + +func (rl *IdentityRateLimiter) cleanup(maxAge time.Duration) { + cutoff := time.Now().Add(-maxAge) + cleanMap := func(m *sync.Map) { + m.Range(func(key, value interface{}) bool { + b := value.(*tokenBucket) + b.mu.Lock() + stale := b.lastCheck.Before(cutoff) + b.mu.Unlock() + if stale { + m.Delete(key) + } + return true + }) + } + cleanMap(&rl.pushBuckets) + cleanMap(&rl.pullBuckets) +} diff --git a/pkg/gateway/middleware.go b/pkg/gateway/middleware.go index 8e25840..65567ae 100644 --- a/pkg/gateway/middleware.go +++ b/pkg/gateway/middleware.go @@ -417,6 +417,11 @@ func isPublicPath(p string) bool { return true } + // Vault proxy endpoints (no auth — rate-limited per identity hash within handler) + if strings.HasPrefix(p, "/v1/vault/") { + return true + } + // Phantom auth endpoints are public (session creation, status polling, completion) if strings.HasPrefix(p, "/v1/auth/phantom/") { return true diff --git a/pkg/gateway/routes.go b/pkg/gateway/routes.go index 4e49910..a791eda 100644 --- a/pkg/gateway/routes.go +++ b/pkg/gateway/routes.go @@ -114,6 +114,14 @@ func (g *Gateway) Routes() http.Handler { mux.HandleFunc("/v1/pubsub/presence", g.pubsubHandlers.PresenceHandler) } + // vault proxy (public, rate-limited per identity within handler) + if g.vaultHandlers != nil { + mux.HandleFunc("/v1/vault/push", g.vaultHandlers.HandlePush) + mux.HandleFunc("/v1/vault/pull", g.vaultHandlers.HandlePull) + mux.HandleFunc("/v1/vault/health", g.vaultHandlers.HandleHealth) + mux.HandleFunc("/v1/vault/status", g.vaultHandlers.HandleStatus) + } + // webrtc if g.webrtcHandlers != nil { mux.HandleFunc("/v1/webrtc/turn/credentials", g.webrtcHandlers.CredentialsHandler) diff --git a/pkg/shamir/field.go b/pkg/shamir/field.go new file mode 100644 index 0000000..2dd4d97 --- /dev/null +++ b/pkg/shamir/field.go @@ -0,0 +1,82 @@ +// Package shamir implements Shamir's Secret Sharing over GF(2^8). +// +// Uses the AES irreducible polynomial x^8 + x^4 + x^3 + x + 1 (0x11B) +// with generator 3. Precomputed log/exp tables for O(1) field arithmetic. +// +// Cross-platform compatible with the Zig (orama-vault) and TypeScript +// (network-ts-sdk) implementations using identical field parameters. +package shamir + +import "errors" + +// ErrDivisionByZero is returned when dividing by zero in GF(2^8). +var ErrDivisionByZero = errors.New("shamir: division by zero in GF(2^8)") + +// Irreducible polynomial: x^8 + x^4 + x^3 + x + 1. +const irreducible = 0x11B + +// expTable[i] = generator^i mod polynomial, for i in 0..511. +// Extended to 512 entries so Mul can use (logA + logB) without modular reduction. +var expTable [512]byte + +// logTable[a] = i where generator^i = a, for a in 1..255. +// logTable[0] is unused (log of zero is undefined). +var logTable [256]byte + +func init() { + x := uint16(1) + for i := 0; i < 512; i++ { + if i < 256 { + expTable[i] = byte(x) + logTable[byte(x)] = byte(i) + } else { + expTable[i] = expTable[i-255] + } + + if i < 255 { + // Multiply by generator (3): x*3 = x*2 XOR x + x2 := x << 1 + x3 := x2 ^ x + if x3&0x100 != 0 { + x3 ^= irreducible + } + x = x3 + } + } +} + +// Add returns a XOR b (addition in GF(2^8)). +func Add(a, b byte) byte { + return a ^ b +} + +// Mul returns a * b in GF(2^8) via log/exp tables. +func Mul(a, b byte) byte { + if a == 0 || b == 0 { + return 0 + } + logSum := uint16(logTable[a]) + uint16(logTable[b]) + return expTable[logSum] +} + +// Inv returns the multiplicative inverse of a in GF(2^8). +// Returns ErrDivisionByZero if a == 0. +func Inv(a byte) (byte, error) { + if a == 0 { + return 0, ErrDivisionByZero + } + return expTable[255-uint16(logTable[a])], nil +} + +// Div returns a / b in GF(2^8). +// Returns ErrDivisionByZero if b == 0. +func Div(a, b byte) (byte, error) { + if b == 0 { + return 0, ErrDivisionByZero + } + if a == 0 { + return 0, nil + } + logDiff := uint16(logTable[a]) + 255 - uint16(logTable[b]) + return expTable[logDiff], nil +} diff --git a/pkg/shamir/shamir.go b/pkg/shamir/shamir.go new file mode 100644 index 0000000..0ba260a --- /dev/null +++ b/pkg/shamir/shamir.go @@ -0,0 +1,150 @@ +package shamir + +import ( + "crypto/rand" + "errors" + "fmt" +) + +var ( + ErrThresholdTooSmall = errors.New("shamir: threshold K must be at least 2") + ErrShareCountTooSmall = errors.New("shamir: share count N must be >= threshold K") + ErrTooManyShares = errors.New("shamir: maximum 255 shares (GF(2^8) limit)") + ErrEmptySecret = errors.New("shamir: secret must not be empty") + ErrNotEnoughShares = errors.New("shamir: need at least 2 shares to reconstruct") + ErrMismatchedShareLen = errors.New("shamir: all shares must have the same data length") + ErrZeroShareIndex = errors.New("shamir: share index must not be 0") + ErrDuplicateShareIndex = errors.New("shamir: duplicate share indices") +) + +// Share represents a single Shamir share. +type Share struct { + X byte // Evaluation point (1..255, never 0) + Y []byte // Share data (same length as original secret) +} + +// Split divides secret into n shares with threshold k. +// Any k shares can reconstruct the secret; k-1 reveal nothing. +func Split(secret []byte, n, k int) ([]Share, error) { + if k < 2 { + return nil, ErrThresholdTooSmall + } + if n < k { + return nil, ErrShareCountTooSmall + } + if n > 255 { + return nil, ErrTooManyShares + } + if len(secret) == 0 { + return nil, ErrEmptySecret + } + + shares := make([]Share, n) + for i := range shares { + shares[i] = Share{ + X: byte(i + 1), + Y: make([]byte, len(secret)), + } + } + + // Temporary buffer for polynomial coefficients. + coeffs := make([]byte, k) + defer func() { + for i := range coeffs { + coeffs[i] = 0 + } + }() + + for byteIdx := 0; byteIdx < len(secret); byteIdx++ { + coeffs[0] = secret[byteIdx] + // Fill degrees 1..k-1 with random bytes. + if _, err := rand.Read(coeffs[1:]); err != nil { + return nil, fmt.Errorf("shamir: random generation failed: %w", err) + } + for i := range shares { + shares[i].Y[byteIdx] = evaluatePolynomial(coeffs, shares[i].X) + } + } + + return shares, nil +} + +// Combine reconstructs the secret from k or more shares via Lagrange interpolation. +func Combine(shares []Share) ([]byte, error) { + if len(shares) < 2 { + return nil, ErrNotEnoughShares + } + + secretLen := len(shares[0].Y) + seen := make(map[byte]bool, len(shares)) + for _, s := range shares { + if s.X == 0 { + return nil, ErrZeroShareIndex + } + if len(s.Y) != secretLen { + return nil, ErrMismatchedShareLen + } + if seen[s.X] { + return nil, ErrDuplicateShareIndex + } + seen[s.X] = true + } + + result := make([]byte, secretLen) + for byteIdx := 0; byteIdx < secretLen; byteIdx++ { + var value byte + for i, si := range shares { + // Lagrange basis polynomial L_i evaluated at 0: + // L_i(0) = product over j!=i of (0 - x_j)/(x_i - x_j) + // = product over j!=i of x_j / (x_i XOR x_j) + var basis byte = 1 + for j, sj := range shares { + if i == j { + continue + } + num := sj.X + den := Add(si.X, sj.X) // x_i - x_j = x_i XOR x_j in GF(2^8) + d, err := Div(num, den) + if err != nil { + return nil, err + } + basis = Mul(basis, d) + } + value = Add(value, Mul(si.Y[byteIdx], basis)) + } + result[byteIdx] = value + } + + return result, nil +} + +// AdaptiveThreshold returns max(3, floor(n/3)). +// This is the read quorum: minimum shares needed to reconstruct. +func AdaptiveThreshold(n int) int { + t := n / 3 + if t < 3 { + return 3 + } + return t +} + +// WriteQuorum returns ceil(2n/3). +// This is the write quorum: minimum ACKs needed for a successful push. +func WriteQuorum(n int) int { + if n == 0 { + return 0 + } + if n <= 2 { + return n + } + return (2*n + 2) / 3 +} + +// evaluatePolynomial evaluates p(x) = coeffs[0] + coeffs[1]*x + ... using Horner's method. +func evaluatePolynomial(coeffs []byte, x byte) byte { + var result byte + for i := len(coeffs) - 1; i >= 0; i-- { + result = Add(Mul(result, x), coeffs[i]) + } + return result +} diff --git a/pkg/shamir/shamir_test.go b/pkg/shamir/shamir_test.go new file mode 100644 index 0000000..2e57cc9 --- /dev/null +++ b/pkg/shamir/shamir_test.go @@ -0,0 +1,501 @@ +package shamir + +import ( + "testing" +) + +// ── GF(2^8) Field Tests ──────────────────────────────────────────────────── + +func TestExpTable_Cycle(t *testing.T) { + // g^0 = 1, g^255 = 1 (cyclic group of order 255) + if expTable[0] != 1 { + t.Errorf("exp[0] = %d, want 1", expTable[0]) + } + if expTable[255] != 1 { + t.Errorf("exp[255] = %d, want 1", expTable[255]) + } +} + +func TestExpTable_AllNonzeroAppear(t *testing.T) { + var seen [256]bool + for i := 0; i < 255; i++ { + v := expTable[i] + if seen[v] { + t.Fatalf("duplicate value %d at index %d", v, i) + } + seen[v] = true + } + for v := 1; v < 256; v++ { + if !seen[v] { + t.Errorf("value %d not seen in exp[0..255]", v) + } + } + if seen[0] { + t.Error("zero should not appear in exp[0..254]") + } +} + +// Cross-platform test vectors from orama-vault/src/sss/test_cross_platform.zig +func TestExpTable_CrossPlatform(t *testing.T) { + vectors := [][2]int{ + {0, 1}, {10, 114}, {20, 216}, {30, 102}, + {40, 106}, {50, 4}, {60, 211}, {70, 77}, + {80, 131}, {90, 179}, {100, 16}, {110, 97}, + {120, 47}, {130, 58}, {140, 250}, {150, 64}, + {160, 159}, {170, 188}, {180, 232}, {190, 197}, + {200, 27}, {210, 74}, {220, 198}, {230, 141}, + {240, 57}, {250, 108}, {254, 246}, {255, 1}, + } + for _, v := range vectors { + if got := expTable[v[0]]; got != byte(v[1]) { + t.Errorf("exp[%d] = %d, want %d", v[0], got, v[1]) + } + } +} + +func TestMul_CrossPlatform(t *testing.T) { + vectors := [][3]byte{ + {1, 1, 1}, {1, 2, 2}, {1, 3, 3}, + {1, 42, 42}, {1, 127, 127}, {1, 170, 170}, {1, 255, 255}, + {2, 1, 2}, {2, 2, 4}, {2, 3, 6}, + {2, 42, 84}, {2, 127, 254}, {2, 170, 79}, {2, 255, 229}, + {3, 1, 3}, {3, 2, 6}, {3, 3, 5}, + {3, 42, 126}, {3, 127, 129}, {3, 170, 229}, {3, 255, 26}, + {42, 1, 42}, {42, 2, 84}, {42, 3, 126}, + {42, 42, 40}, {42, 127, 82}, {42, 170, 244}, {42, 255, 142}, + {127, 1, 127}, {127, 2, 254}, {127, 3, 129}, + {127, 42, 82}, {127, 127, 137}, {127, 170, 173}, {127, 255, 118}, + {170, 1, 170}, {170, 2, 79}, {170, 3, 229}, + {170, 42, 244}, {170, 127, 173}, {170, 170, 178}, {170, 255, 235}, + {255, 1, 255}, {255, 2, 229}, {255, 3, 26}, + {255, 42, 142}, {255, 127, 118}, {255, 170, 235}, {255, 255, 19}, + } + for _, v := range vectors { + if got := Mul(v[0], v[1]); got != v[2] { + t.Errorf("Mul(%d, %d) = %d, want %d", v[0], v[1], got, v[2]) + } + } +} + +func TestMul_Zero(t *testing.T) { + for a := 0; a < 256; a++ { + if Mul(byte(a), 0) != 0 { + t.Errorf("Mul(%d, 0) != 0", a) + } + if Mul(0, byte(a)) != 0 { + t.Errorf("Mul(0, %d) != 0", a) + } + } +} + +func TestMul_Identity(t *testing.T) { + for a := 0; a < 256; a++ { + if Mul(byte(a), 1) != byte(a) { + t.Errorf("Mul(%d, 1) = %d", a, Mul(byte(a), 1)) + } + } +} + +func TestMul_Commutative(t *testing.T) { + for a := 1; a < 256; a += 7 { + for b := 1; b < 256; b += 11 { + ab := Mul(byte(a), byte(b)) + ba := Mul(byte(b), byte(a)) + if ab != ba { + t.Errorf("Mul(%d,%d)=%d != Mul(%d,%d)=%d", a, b, ab, b, a, ba) + } + } + } +} + +func TestInv_CrossPlatform(t *testing.T) { + vectors := [][2]byte{ + {1, 1}, {2, 141}, {3, 246}, {5, 82}, + {7, 209}, {16, 116}, {42, 152}, {127, 130}, + {128, 131}, {170, 18}, {200, 169}, {255, 28}, + } + for _, v := range vectors { + got, err := Inv(v[0]) + if err != nil { + t.Errorf("Inv(%d) returned error: %v", v[0], err) + continue + } + if got != v[1] { + t.Errorf("Inv(%d) = %d, want %d", v[0], got, v[1]) + } + } +} + +func TestInv_SelfInverse(t *testing.T) { + for a := 1; a < 256; a++ { + inv1, _ := Inv(byte(a)) + inv2, _ := Inv(inv1) + if inv2 != byte(a) { + t.Errorf("Inv(Inv(%d)) = %d, want %d", a, inv2, a) + } + } +} + +func TestInv_Product(t *testing.T) { + for a := 1; a < 256; a++ { + inv1, _ := Inv(byte(a)) + if Mul(byte(a), inv1) != 1 { + t.Errorf("Mul(%d, Inv(%d)) != 1", a, a) + } + } +} + +func TestInv_Zero(t *testing.T) { + _, err := Inv(0) + if err != ErrDivisionByZero { + t.Errorf("Inv(0) should return ErrDivisionByZero, got %v", err) + } +} + +func TestDiv_CrossPlatform(t *testing.T) { + vectors := [][3]byte{ + {1, 1, 1}, {1, 2, 141}, {1, 3, 246}, + {1, 42, 152}, {1, 127, 130}, {1, 170, 18}, {1, 255, 28}, + {2, 1, 2}, {2, 2, 1}, {2, 3, 247}, + {3, 1, 3}, {3, 2, 140}, {3, 3, 1}, + {42, 1, 42}, {42, 2, 21}, {42, 42, 1}, + {127, 1, 127}, {127, 127, 1}, + {170, 1, 170}, {170, 170, 1}, + {255, 1, 255}, {255, 255, 1}, + } + for _, v := range vectors { + got, err := Div(v[0], v[1]) + if err != nil { + t.Errorf("Div(%d, %d) returned error: %v", v[0], v[1], err) + continue + } + if got != v[2] { + t.Errorf("Div(%d, %d) = %d, want %d", v[0], v[1], got, v[2]) + } + } +} + +func TestDiv_ByZero(t *testing.T) { + _, err := Div(42, 0) + if err != ErrDivisionByZero { + t.Errorf("Div(42, 0) should return ErrDivisionByZero, got %v", err) + } +} + +// ── Polynomial evaluation ────────────────────────────────────────────────── + +func TestEvaluatePolynomial_CrossPlatform(t *testing.T) { + // p(x) = 42 + 5x + 7x^2 + coeffs0 := []byte{42, 5, 7} + vectors0 := [][2]byte{ + {1, 40}, {2, 60}, {3, 62}, {4, 78}, + {5, 76}, {10, 207}, {100, 214}, {255, 125}, + } + for _, v := range vectors0 { + if got := evaluatePolynomial(coeffs0, v[0]); got != v[1] { + t.Errorf("p(%d) = %d, want %d [coeffs: 42,5,7]", v[0], got, v[1]) + } + } + + // p(x) = 0 + 0xAB*x + 0xCD*x^2 + coeffs1 := []byte{0, 0xAB, 0xCD} + vectors1 := [][2]byte{ + {1, 102}, {3, 50}, {5, 152}, {7, 204}, {200, 96}, + } + for _, v := range vectors1 { + if got := evaluatePolynomial(coeffs1, v[0]); got != v[1] { + t.Errorf("p(%d) = %d, want %d [coeffs: 0,AB,CD]", v[0], got, v[1]) + } + } + + // p(x) = 0xFF (constant) + coeffs2 := []byte{0xFF} + for _, x := range []byte{1, 2, 255} { + if got := evaluatePolynomial(coeffs2, x); got != 0xFF { + t.Errorf("constant p(%d) = %d, want 255", x, got) + } + } + + // p(x) = 128 + 64x + 32x^2 + 16x^3 + coeffs3 := []byte{128, 64, 32, 16} + vectors3 := [][2]byte{ + {1, 240}, {2, 0}, {3, 16}, {4, 193}, {5, 234}, + } + for _, v := range vectors3 { + if got := evaluatePolynomial(coeffs3, v[0]); got != v[1] { + t.Errorf("p(%d) = %d, want %d [coeffs: 128,64,32,16]", v[0], got, v[1]) + } + } +} + +// ── Lagrange combine (cross-platform) ───────────────────────────────────── + +func TestCombine_CrossPlatform_SingleByte(t *testing.T) { + // p(x) = 42 + 5x + 7x^2, secret = 42 + // Shares: (1,40) (2,60) (3,62) (4,78) (5,76) + allShares := []Share{ + {X: 1, Y: []byte{40}}, + {X: 2, Y: []byte{60}}, + {X: 3, Y: []byte{62}}, + {X: 4, Y: []byte{78}}, + {X: 5, Y: []byte{76}}, + } + + subsets := [][]int{ + {0, 1, 2}, // {1,2,3} + {0, 2, 4}, // {1,3,5} + {1, 3, 4}, // {2,4,5} + {2, 3, 4}, // {3,4,5} + } + + for _, subset := range subsets { + shares := make([]Share, len(subset)) + for i, idx := range subset { + shares[i] = allShares[idx] + } + result, err := Combine(shares) + if err != nil { + t.Fatalf("Combine failed for subset %v: %v", subset, err) + } + if result[0] != 42 { + t.Errorf("Combine(subset %v) = %d, want 42", subset, result[0]) + } + } +} + +func TestCombine_CrossPlatform_MultiByte(t *testing.T) { + // 2-byte secret [42, 0] + // byte0: 42 + 5x + 7x^2 → shares at x=1,3,5: 40, 62, 76 + // byte1: 0 + 0xAB*x + 0xCD*x^2 → shares at x=1,3,5: 102, 50, 152 + shares := []Share{ + {X: 1, Y: []byte{40, 102}}, + {X: 3, Y: []byte{62, 50}}, + {X: 5, Y: []byte{76, 152}}, + } + result, err := Combine(shares) + if err != nil { + t.Fatalf("Combine failed: %v", err) + } + if result[0] != 42 || result[1] != 0 { + t.Errorf("Combine = %v, want [42, 0]", result) + } +} + +// ── Split/Combine round-trip ────────────────────────────────────────────── + +func TestSplitCombine_RoundTrip_2of3(t *testing.T) { + secret := []byte("hello world") + shares, err := Split(secret, 3, 2) + if err != nil { + t.Fatalf("Split: %v", err) + } + if len(shares) != 3 { + t.Fatalf("got %d shares, want 3", len(shares)) + } + + // Any 2 shares should reconstruct + for i := 0; i < 3; i++ { + for j := i + 1; j < 3; j++ { + result, err := Combine([]Share{shares[i], shares[j]}) + if err != nil { + t.Fatalf("Combine(%d,%d): %v", i, j, err) + } + if string(result) != string(secret) { + t.Errorf("Combine(%d,%d) = %q, want %q", i, j, result, secret) + } + } + } +} + +func TestSplitCombine_RoundTrip_3of5(t *testing.T) { + secret := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} + shares, err := Split(secret, 5, 3) + if err != nil { + t.Fatalf("Split: %v", err) + } + + // All C(5,3)=10 subsets should reconstruct + count := 0 + for i := 0; i < 5; i++ { + for j := i + 1; j < 5; j++ { + for k := j + 1; k < 5; k++ { + result, err := Combine([]Share{shares[i], shares[j], shares[k]}) + if err != nil { + t.Fatalf("Combine(%d,%d,%d): %v", i, j, k, err) + } + for idx := range secret { + if result[idx] != secret[idx] { + t.Errorf("Combine(%d,%d,%d)[%d] = %d, want %d", i, j, k, idx, result[idx], secret[idx]) + } + } + count++ + } + } + } + if count != 10 { + t.Errorf("tested %d subsets, want 10", count) + } +} + +func TestSplitCombine_RoundTrip_LargeSecret(t *testing.T) { + secret := make([]byte, 256) + for i := range secret { + secret[i] = byte(i) + } + shares, err := Split(secret, 10, 5) + if err != nil { + t.Fatalf("Split: %v", err) + } + + // Use first 5 shares + result, err := Combine(shares[:5]) + if err != nil { + t.Fatalf("Combine: %v", err) + } + for i := range secret { + if result[i] != secret[i] { + t.Errorf("result[%d] = %d, want %d", i, result[i], secret[i]) + break + } + } +} + +func TestSplitCombine_AllZeros(t *testing.T) { + secret := make([]byte, 10) + shares, err := Split(secret, 5, 3) + if err != nil { + t.Fatalf("Split: %v", err) + } + result, err := Combine(shares[:3]) + if err != nil { + t.Fatalf("Combine: %v", err) + } + for i, b := range result { + if b != 0 { + t.Errorf("result[%d] = %d, want 0", i, b) + } + } +} + +func TestSplitCombine_AllOnes(t *testing.T) { + secret := make([]byte, 10) + for i := range secret { + secret[i] = 0xFF + } + shares, err := Split(secret, 5, 3) + if err != nil { + t.Fatalf("Split: %v", err) + } + result, err := Combine(shares[:3]) + if err != nil { + t.Fatalf("Combine: %v", err) + } + for i, b := range result { + if b != 0xFF { + t.Errorf("result[%d] = %d, want 255", i, b) + } + } +} + +// ── Share indices ───────────────────────────────────────────────────────── + +func TestSplit_ShareIndices(t *testing.T) { + shares, err := Split([]byte{42}, 5, 3) + if err != nil { + t.Fatalf("Split: %v", err) + } + for i, s := range shares { + if s.X != byte(i+1) { + t.Errorf("shares[%d].X = %d, want %d", i, s.X, i+1) + } + } +} + +// ── Error cases ─────────────────────────────────────────────────────────── + +func TestSplit_Errors(t *testing.T) { + tests := []struct { + name string + secret []byte + n, k int + want error + }{ + {"k < 2", []byte{1}, 3, 1, ErrThresholdTooSmall}, + {"n < k", []byte{1}, 2, 3, ErrShareCountTooSmall}, + {"n > 255", []byte{1}, 256, 3, ErrTooManyShares}, + {"empty secret", []byte{}, 3, 2, ErrEmptySecret}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := Split(tt.secret, tt.n, tt.k) + if err != tt.want { + t.Errorf("Split() error = %v, want %v", err, tt.want) + } + }) + } +} + +func TestCombine_Errors(t *testing.T) { + t.Run("not enough shares", func(t *testing.T) { + _, err := Combine([]Share{{X: 1, Y: []byte{1}}}) + if err != ErrNotEnoughShares { + t.Errorf("got %v, want ErrNotEnoughShares", err) + } + }) + + t.Run("zero index", func(t *testing.T) { + _, err := Combine([]Share{ + {X: 0, Y: []byte{1}}, + {X: 1, Y: []byte{2}}, + }) + if err != ErrZeroShareIndex { + t.Errorf("got %v, want ErrZeroShareIndex", err) + } + }) + + t.Run("mismatched lengths", func(t *testing.T) { + _, err := Combine([]Share{ + {X: 1, Y: []byte{1, 2}}, + {X: 2, Y: []byte{3}}, + }) + if err != ErrMismatchedShareLen { + t.Errorf("got %v, want ErrMismatchedShareLen", err) + } + }) + + t.Run("duplicate indices", func(t *testing.T) { + _, err := Combine([]Share{ + {X: 1, Y: []byte{1}}, + {X: 1, Y: []byte{2}}, + }) + if err != ErrDuplicateShareIndex { + t.Errorf("got %v, want ErrDuplicateShareIndex", err) + } + }) +} + +// ── Threshold / Quorum ──────────────────────────────────────────────────── + +func TestAdaptiveThreshold(t *testing.T) { + tests := [][2]int{ + {1, 3}, {2, 3}, {3, 3}, {5, 3}, {8, 3}, {9, 3}, + {10, 3}, {12, 4}, {15, 5}, {30, 10}, {100, 33}, + } + for _, tt := range tests { + if got := AdaptiveThreshold(tt[0]); got != tt[1] { + t.Errorf("AdaptiveThreshold(%d) = %d, want %d", tt[0], got, tt[1]) + } + } +} + +func TestWriteQuorum(t *testing.T) { + tests := [][2]int{ + {0, 0}, {1, 1}, {2, 2}, {3, 2}, {4, 3}, {5, 4}, + {6, 4}, {10, 7}, {14, 10}, {100, 67}, + } + for _, tt := range tests { + if got := WriteQuorum(tt[0]); got != tt[1] { + t.Errorf("WriteQuorum(%d) = %d, want %d", tt[0], got, tt[1]) + } + } +}