mirror of
https://github.com/DeBrosOfficial/orama.git
synced 2026-06-16 23:54:13 +00:00
feat(gateway): implement ntfy cluster fan-out and improve secrets encryption
- Add `ntfyFanoutResolver` to distribute push notifications across all active cluster nodes, ensuring delivery when nodes lack shared state.
- Refactor secrets encryption key derivation to use cluster-wide secrets via HKDF, replacing ephemeral per-node keys to fix cross-node decryption issues.
- Add unit tests for fan-out resolution logic and caching behavior.
This commit is contained in:
parent
4ae8fa941d
commit
34f9da6f8d
@ -5,6 +5,7 @@ import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@ -478,15 +479,21 @@ func initializeServerless(logger *logging.ColoredLogger, cfg *Config, deps *Depe
|
||||
|
||||
// Create secrets manager for serverless functions (AES-256-GCM encrypted).
|
||||
//
|
||||
// The encryption key comes from the gateway Config (loaded from
|
||||
// ~/.orama/secrets/secrets-encryption-key), NOT from engineCfg — engineCfg
|
||||
// never has the key set, so passing it always produced a per-process
|
||||
// ephemeral key and made get_secret return undecryptable values
|
||||
// (bugboard #837). allowEphemeral=false: a missing/invalid key fails
|
||||
// The encryption key is DERIVED from the cluster secret via HKDF
|
||||
// (resolveSecretsEncryptionKeyHex), so every gateway in the cluster computes
|
||||
// the identical key and a secret written on one node decrypts on every other
|
||||
// node and survives rolling upgrades. This replaces the old per-node
|
||||
// crypto/rand key file, whose divergence across an upgraded cluster kept
|
||||
// get_secret broken (bugboard #837). The file key (cfg.SecretsEncryptionKey)
|
||||
// remains only as a fallback when no cluster secret is available (legacy /
|
||||
// single-node test rigs). allowEphemeral=false: a missing/invalid key fails
|
||||
// loudly here and disables get_secret rather than silently corrupting
|
||||
// secrets.
|
||||
var secretsMgr serverless.SecretsManager
|
||||
if smImpl, secretsErr := hostfunctions.NewDBSecretsManager(deps.ORMClient, cfg.SecretsEncryptionKey, false, logger.Logger); secretsErr != nil {
|
||||
if secretsKeyHex, keyErr := resolveSecretsEncryptionKeyHex(cfg.ClusterSecret, cfg.SecretsEncryptionKey); keyErr != nil {
|
||||
logger.ComponentWarn(logging.ComponentGeneral, "Failed to derive secrets encryption key; get_secret will be unavailable",
|
||||
zap.Error(keyErr))
|
||||
} else if smImpl, secretsErr := hostfunctions.NewDBSecretsManager(deps.ORMClient, secretsKeyHex, false, logger.Logger); secretsErr != nil {
|
||||
logger.ComponentWarn(logging.ComponentGeneral, "Failed to initialize secrets manager; get_secret will be unavailable",
|
||||
zap.Error(secretsErr))
|
||||
} else {
|
||||
@ -504,7 +511,7 @@ func initializeServerless(logger *logging.ColoredLogger, cfg *Config, deps *Depe
|
||||
//
|
||||
// PushDispatcher (legacy) is set only when YAML defaults exist —
|
||||
// kept for back-compat with code that hasn't migrated to Manager.
|
||||
pushDispatcher, pushStore, pushManager, pushCfgStore, pushCredManager, err := buildPushDispatcher(cfg, deps.ORMClient, logger)
|
||||
pushDispatcher, pushStore, pushManager, pushCfgStore, pushCredManager, err := buildPushDispatcher(cfg, deps.ORMClient, deps.Client, logger)
|
||||
if err != nil {
|
||||
// Non-fatal: log and continue. Functions calling push_send will get nil
|
||||
// (silent no-op) and HTTP /v1/push/* endpoints return 503.
|
||||
@ -921,6 +928,7 @@ func appendRQLiteQueryParams(dsn string) string {
|
||||
func buildPushDispatcher(
|
||||
cfg *Config,
|
||||
db rqlite.Client,
|
||||
globalDB client.NetworkClient,
|
||||
logger *logging.ColoredLogger,
|
||||
) (*push.PushDispatcher, push.PushDeviceStore, *push.Manager, push.ConfigStore, *pushcreds.Manager, error) {
|
||||
if cfg.ClusterSecret == "" {
|
||||
@ -957,6 +965,25 @@ func buildPushDispatcher(
|
||||
pushcreds.Register(pushapns.NewValidator())
|
||||
pushcreds.Register(pushntfy.NewValidator())
|
||||
|
||||
// ntfy cluster fan-out (bugboard #858): the default push infra runs an
|
||||
// independent ntfy per node with no shared store, so a publish must reach
|
||||
// EVERY active node for the subscriber's instance (picked by round-robin
|
||||
// DNS) to receive it. Build a resolver over the global dns_nodes table; the
|
||||
// factory attaches it only to providers using the shared default base URL
|
||||
// (a namespace pointing ntfy at its own server is never fanned across our
|
||||
// cluster). nil globalDB or an unparseable base URL → no fan-out (provider
|
||||
// falls back to the single base URL).
|
||||
var ntfyFanout *ntfyFanoutResolver
|
||||
var ntfyFanoutHost string
|
||||
if globalDB != nil {
|
||||
if base := strings.TrimSpace(cfg.NtfyBaseURL); base != "" {
|
||||
if u, perr := url.Parse(base); perr == nil && u.Hostname() != "" {
|
||||
ntfyFanoutHost = u.Hostname()
|
||||
ntfyFanout = newNtfyFanoutResolver(globalDB, u.Scheme, u.Port(), defaultNtfyFanoutTTL)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ProviderFactory turns a resolved Config into the right set of
|
||||
// provider instances. Lives here in dependencies.go because this is
|
||||
// the only place that imports both the manager package and the
|
||||
@ -997,6 +1024,13 @@ func buildPushDispatcher(
|
||||
}
|
||||
}
|
||||
if ntfyCfg.BaseURL != "" {
|
||||
// Fan out across all push nodes ONLY for the shared default infra.
|
||||
// A namespace that overrode BaseURL with its own ntfy server keeps
|
||||
// single-host delivery (its server, not our cluster).
|
||||
if ntfyFanout != nil && ntfyCfg.BaseURL == cfg.NtfyBaseURL {
|
||||
ntfyCfg.FanoutResolver = ntfyFanout.Hosts
|
||||
ntfyCfg.FanoutHostHeader = ntfyFanoutHost
|
||||
}
|
||||
ps = append(ps, pushntfy.New(ntfyCfg, logger.Logger))
|
||||
}
|
||||
if c.ExpoAccessToken != "" {
|
||||
|
||||
95
core/pkg/gateway/push_fanout.go
Normal file
95
core/pkg/gateway/push_fanout.go
Normal file
@ -0,0 +1,95 @@
|
||||
package gateway
|
||||
|
||||
import (
	"context"
	"fmt"
	"strings"
	"sync"
	"time"

	"github.com/DeBrosOfficial/network/pkg/client"
)
|
||||
|
||||
// defaultNtfyFanoutTTL is how long the resolved list of active push nodes is
// reused before dns_nodes is queried again. It matches the DNS heartbeat
// cadence: a node joining or leaving is noticed within one heartbeat, and
// rqlite is not queried on every single push.
const defaultNtfyFanoutTTL = 30 * time.Second
|
||||
|
||||
// ntfyFanoutResolver produces the ntfy publish base URLs used for fan-out
// delivery — one URL per active push node — and memoizes them for a short TTL.
// Every node runs its own ntfy with no shared store, so a publish only reaches
// a subscriber's instance if it is sent to every node (bugboard #858).
type ntfyFanoutResolver struct {
	// query yields the public IPs of the push nodes that are active right now.
	// It is a field rather than a hard-wired call so the caching and URL
	// building can be unit-tested without a live cluster.
	query func(ctx context.Context) ([]string, error)

	// scheme comes from the configured base URL: "https" (prod) / "http" (dev).
	scheme string
	// port is the explicit port from the base URL, or "" for the scheme default.
	port string

	// ttl is how long a successfully resolved list is served from the cache.
	ttl time.Duration

	// mu guards cached and cachedAt.
	mu sync.Mutex
	// cached holds the base URLs built from the last successful query.
	cached []string
	// cachedAt records when cached was stored.
	cachedAt time.Time
}
|
||||
|
||||
// newNtfyFanoutResolver builds a resolver backed by the global dns_nodes table.
|
||||
func newNtfyFanoutResolver(globalDB client.NetworkClient, scheme, port string, ttl time.Duration) *ntfyFanoutResolver {
|
||||
return &ntfyFanoutResolver{
|
||||
scheme: scheme,
|
||||
port: port,
|
||||
ttl: ttl,
|
||||
query: func(ctx context.Context) ([]string, error) {
|
||||
db := globalDB.Database()
|
||||
res, err := db.Query(client.WithInternalAuth(ctx), "SELECT ip_address FROM dns_nodes WHERE status = 'active'")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query active push nodes: %w", err)
|
||||
}
|
||||
if res == nil {
|
||||
return nil, nil
|
||||
}
|
||||
ips := make([]string, 0, len(res.Rows))
|
||||
for _, row := range res.Rows {
|
||||
if len(row) == 0 {
|
||||
continue
|
||||
}
|
||||
if ip, ok := row[0].(string); ok && ip != "" {
|
||||
ips = append(ips, ip)
|
||||
}
|
||||
}
|
||||
return ips, nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Hosts returns the cached fan-out base URLs, refreshing from the query when the
|
||||
// cache is stale. On a query error it returns the last-known list (possibly nil)
|
||||
// alongside the error, so the caller can decide to fall back to its base URL
|
||||
// rather than dropping a push.
|
||||
func (r *ntfyFanoutResolver) Hosts(ctx context.Context) ([]string, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if r.cached != nil && time.Since(r.cachedAt) < r.ttl {
|
||||
return r.cached, nil
|
||||
}
|
||||
|
||||
ips, err := r.query(ctx)
|
||||
if err != nil {
|
||||
return r.cached, err
|
||||
}
|
||||
|
||||
hosts := make([]string, 0, len(ips))
|
||||
suffix := ""
|
||||
if r.port != "" {
|
||||
suffix = ":" + r.port
|
||||
}
|
||||
for _, ip := range ips {
|
||||
if ip == "" {
|
||||
continue
|
||||
}
|
||||
hosts = append(hosts, r.scheme+"://"+ip+suffix)
|
||||
}
|
||||
r.cached = hosts
|
||||
r.cachedAt = time.Now()
|
||||
return hosts, nil
|
||||
}
|
||||
125
core/pkg/gateway/push_fanout_test.go
Normal file
125
core/pkg/gateway/push_fanout_test.go
Normal file
@ -0,0 +1,125 @@
|
||||
package gateway
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Bugboard #858 — the fan-out resolver turns active dns_nodes into ntfy publish
|
||||
// base URLs and caches them for a short TTL. These pin the transform + caching.
|
||||
|
||||
func TestNtfyFanoutResolver_buildsSchemeHostPort(t *testing.T) {
|
||||
r := &ntfyFanoutResolver{
|
||||
scheme: "https",
|
||||
port: "",
|
||||
ttl: time.Minute,
|
||||
query: func(context.Context) ([]string, error) { return []string{"1.2.3.4", "5.6.7.8"}, nil },
|
||||
}
|
||||
hosts, err := r.Hosts(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Hosts: %v", err)
|
||||
}
|
||||
want := []string{"https://1.2.3.4", "https://5.6.7.8"}
|
||||
if len(hosts) != len(want) {
|
||||
t.Fatalf("got %v; want %v", hosts, want)
|
||||
}
|
||||
for i := range want {
|
||||
if hosts[i] != want[i] {
|
||||
t.Errorf("host[%d] = %q; want %q", i, hosts[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNtfyFanoutResolver_includesExplicitPort(t *testing.T) {
|
||||
r := &ntfyFanoutResolver{
|
||||
scheme: "http",
|
||||
port: "8090",
|
||||
ttl: time.Minute,
|
||||
query: func(context.Context) ([]string, error) { return []string{"10.0.0.6"}, nil },
|
||||
}
|
||||
hosts, _ := r.Hosts(context.Background())
|
||||
if len(hosts) != 1 || hosts[0] != "http://10.0.0.6:8090" {
|
||||
t.Errorf("got %v; want [http://10.0.0.6:8090]", hosts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNtfyFanoutResolver_skipsEmptyIPs(t *testing.T) {
|
||||
r := &ntfyFanoutResolver{
|
||||
scheme: "https",
|
||||
ttl: time.Minute,
|
||||
query: func(context.Context) ([]string, error) { return []string{"", "1.2.3.4", ""}, nil },
|
||||
}
|
||||
hosts, _ := r.Hosts(context.Background())
|
||||
if len(hosts) != 1 || hosts[0] != "https://1.2.3.4" {
|
||||
t.Errorf("got %v; want only the non-empty IP", hosts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNtfyFanoutResolver_cachesWithinTTL(t *testing.T) {
|
||||
calls := 0
|
||||
r := &ntfyFanoutResolver{
|
||||
scheme: "https",
|
||||
ttl: time.Minute,
|
||||
query: func(context.Context) ([]string, error) {
|
||||
calls++
|
||||
return []string{"1.2.3.4"}, nil
|
||||
},
|
||||
}
|
||||
for i := 0; i < 3; i++ {
|
||||
if _, err := r.Hosts(context.Background()); err != nil {
|
||||
t.Fatalf("Hosts: %v", err)
|
||||
}
|
||||
}
|
||||
if calls != 1 {
|
||||
t.Errorf("query called %d times; want 1 (cached within TTL)", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNtfyFanoutResolver_requeriesAfterTTL(t *testing.T) {
|
||||
calls := 0
|
||||
r := &ntfyFanoutResolver{
|
||||
scheme: "https",
|
||||
ttl: time.Nanosecond, // expire immediately
|
||||
query: func(context.Context) ([]string, error) {
|
||||
calls++
|
||||
return []string{"1.2.3.4"}, nil
|
||||
},
|
||||
}
|
||||
_, _ = r.Hosts(context.Background())
|
||||
time.Sleep(time.Millisecond)
|
||||
_, _ = r.Hosts(context.Background())
|
||||
if calls != 2 {
|
||||
t.Errorf("query called %d times; want 2 (TTL expired between calls)", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNtfyFanoutResolver_queryError_returnsStaleCache(t *testing.T) {
|
||||
fail := false
|
||||
r := &ntfyFanoutResolver{
|
||||
scheme: "https",
|
||||
ttl: time.Nanosecond,
|
||||
query: func(context.Context) ([]string, error) {
|
||||
if fail {
|
||||
return nil, errors.New("rqlite unreachable")
|
||||
}
|
||||
return []string{"1.2.3.4"}, nil
|
||||
},
|
||||
}
|
||||
// Prime the cache.
|
||||
if _, err := r.Hosts(context.Background()); err != nil {
|
||||
t.Fatalf("prime: %v", err)
|
||||
}
|
||||
time.Sleep(time.Millisecond)
|
||||
// Now the query fails — Hosts must return the stale cache alongside the error
|
||||
// so the caller can fall back rather than drop the push.
|
||||
fail = true
|
||||
hosts, err := r.Hosts(context.Background())
|
||||
if err == nil {
|
||||
t.Fatal("want the query error surfaced")
|
||||
}
|
||||
if len(hosts) != 1 || hosts[0] != "https://1.2.3.4" {
|
||||
t.Errorf("want the stale cache returned on error; got %v", hosts)
|
||||
}
|
||||
}
|
||||
49
core/pkg/gateway/secrets_key.go
Normal file
49
core/pkg/gateway/secrets_key.go
Normal file
@ -0,0 +1,49 @@
|
||||
package gateway
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
|
||||
"github.com/DeBrosOfficial/network/pkg/secrets"
|
||||
)
|
||||
|
||||
// secretsEncryptionDerivePurpose is the HKDF info label under which the
// function-secrets AES-256 key is derived from the cluster secret.
//
// Deriving the key, rather than generating a per-node crypto/rand key file,
// means every gateway in the cluster computes the IDENTICAL key: a secret
// written on one node decrypts on every other node and survives rolling
// upgrades. That removes the key-divergence / convergence-window class of
// failures that kept get_secret broken for days (bugboard #837). The cluster-
// wide JWT signing key (jwtEdDSADerivePurpose) and the TURN encryption key
// ("turn-encryption") follow the same pattern.
//
// Changing the version suffix (e.g. to "...-v2") is a DELIBERATE rotation: it
// invalidates every stored function secret, each of which must then be
// re-`set`. Never change it casually.
const secretsEncryptionDerivePurpose = "orama-secrets-encryption-v1"
|
||||
|
||||
// resolveSecretsEncryptionKeyHex returns the hex-encoded AES-256 key the
|
||||
// serverless secrets manager should use to encrypt/decrypt function secrets.
|
||||
//
|
||||
// Primary: derive deterministically from the cluster secret via HKDF, so the
|
||||
// key is identical on every gateway in the cluster and stable across restarts
|
||||
// and rolling upgrades. The cluster secret is TrimSpace'd first so a stray
|
||||
// trailing newline on one node's secret file can't silently diverge its derived
|
||||
// key from the rest of the cluster (the host gateway reads the file untrimmed
|
||||
// while the namespace gateway trims it — without this they could derive
|
||||
// different keys and reintroduce #837).
|
||||
//
|
||||
// Fallback: when no cluster secret is available (single-node test rigs / legacy
|
||||
// deployments without a shared secret), fall back to an explicitly-configured
|
||||
// key file. An empty result then makes the production secrets manager fail loud
|
||||
// (NewDBSecretsManager with allowEphemeral=false), rather than silently using a
|
||||
// per-process ephemeral key.
|
||||
func resolveSecretsEncryptionKeyHex(clusterSecret, fileKeyHex string) (string, error) {
|
||||
if cs := strings.TrimSpace(clusterSecret); cs != "" {
|
||||
key, err := secrets.DeriveKey(cs, secretsEncryptionDerivePurpose)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(key), nil
|
||||
}
|
||||
return strings.TrimSpace(fileKeyHex), nil
|
||||
}
|
||||
95
core/pkg/gateway/secrets_key_test.go
Normal file
95
core/pkg/gateway/secrets_key_test.go
Normal file
@ -0,0 +1,95 @@
|
||||
package gateway
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"testing"
|
||||
|
||||
"github.com/DeBrosOfficial/network/pkg/secrets"
|
||||
)
|
||||
|
||||
// Bugboard #837 — the function-secrets AES key must be DERIVED from the cluster
|
||||
// secret (not a per-node random file), so every gateway computes the identical
|
||||
// key and stored secrets survive rolling upgrades. These pin the derivation.
|
||||
|
||||
func TestResolveSecretsEncryptionKeyHex_deterministic(t *testing.T) {
|
||||
// Same cluster secret → byte-identical key, every time. This is the whole
|
||||
// point: any gateway in the cluster derives the same key, so a secret set on
|
||||
// one node decrypts on all others.
|
||||
const cs = "cluster-secret-abc123"
|
||||
a, err := resolveSecretsEncryptionKeyHex(cs, "")
|
||||
if err != nil {
|
||||
t.Fatalf("resolve: %v", err)
|
||||
}
|
||||
b, err := resolveSecretsEncryptionKeyHex(cs, "")
|
||||
if err != nil {
|
||||
t.Fatalf("resolve: %v", err)
|
||||
}
|
||||
if a == "" || a != b {
|
||||
t.Fatalf("derivation not deterministic: %q vs %q", a, b)
|
||||
}
|
||||
// Valid AES-256 key: 32 bytes = 64 hex chars.
|
||||
raw, err := hex.DecodeString(a)
|
||||
if err != nil || len(raw) != 32 {
|
||||
t.Errorf("derived key is not 32-byte hex: len(raw)=%d err=%v", len(raw), err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveSecretsEncryptionKeyHex_trimInvariant(t *testing.T) {
|
||||
// A trailing newline on one node's cluster-secret file must NOT change the
|
||||
// derived key — otherwise the host gateway (reads untrimmed) and a namespace
|
||||
// gateway (reads trimmed) would diverge and reintroduce #837.
|
||||
trimmed, _ := resolveSecretsEncryptionKeyHex("cluster-secret-abc123", "")
|
||||
withNL, _ := resolveSecretsEncryptionKeyHex("cluster-secret-abc123\n", "")
|
||||
withSpaces, _ := resolveSecretsEncryptionKeyHex(" cluster-secret-abc123\t\n", "")
|
||||
if trimmed != withNL || trimmed != withSpaces {
|
||||
t.Errorf("derived key is not whitespace-invariant: %q / %q / %q", trimmed, withNL, withSpaces)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveSecretsEncryptionKeyHex_distinctSecretsDistinctKeys(t *testing.T) {
|
||||
a, _ := resolveSecretsEncryptionKeyHex("cluster-secret-A", "")
|
||||
b, _ := resolveSecretsEncryptionKeyHex("cluster-secret-B", "")
|
||||
if a == b {
|
||||
t.Errorf("distinct cluster secrets must derive distinct keys; both = %q", a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveSecretsEncryptionKeyHex_purposeSeparatedFromTURN(t *testing.T) {
|
||||
// The secrets key must NOT equal the TURN key derived from the same cluster
|
||||
// secret — domain separation via the HKDF info label.
|
||||
const cs = "cluster-secret-abc123"
|
||||
secretsHex, _ := resolveSecretsEncryptionKeyHex(cs, "")
|
||||
turnKey, err := secrets.DeriveKey(cs, "turn-encryption")
|
||||
if err != nil {
|
||||
t.Fatalf("derive turn key: %v", err)
|
||||
}
|
||||
if secretsHex == hex.EncodeToString(turnKey) {
|
||||
t.Error("secrets key collides with the TURN key — HKDF purpose label not providing domain separation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveSecretsEncryptionKeyHex_emptyClusterSecretUsesFileKey(t *testing.T) {
|
||||
// Legacy/test rigs with no cluster secret fall back to the explicitly
|
||||
// configured file key (trimmed).
|
||||
const fileKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
|
||||
got, err := resolveSecretsEncryptionKeyHex("", fileKey+"\n")
|
||||
if err != nil {
|
||||
t.Fatalf("resolve: %v", err)
|
||||
}
|
||||
if got != fileKey {
|
||||
t.Errorf("empty cluster secret should return the trimmed file key; got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveSecretsEncryptionKeyHex_emptyBothReturnsEmpty(t *testing.T) {
|
||||
// No cluster secret AND no file key → empty result, which makes the
|
||||
// production secrets manager fail loud (allowEphemeral=false) instead of
|
||||
// silently using an ephemeral key.
|
||||
got, err := resolveSecretsEncryptionKeyHex("", "")
|
||||
if err != nil {
|
||||
t.Fatalf("resolve: %v", err)
|
||||
}
|
||||
if got != "" {
|
||||
t.Errorf("want empty result when neither source has a key; got %q", got)
|
||||
}
|
||||
}
|
||||
@ -1785,15 +1785,24 @@ func (cm *ClusterManager) saveLocalState(state *ClusterLocalState) error {
|
||||
return fmt.Errorf("failed to marshal state: %w", err)
|
||||
}
|
||||
path := filepath.Join(dir, "cluster-state.json")
|
||||
// 0600: this file now carries the namespace TURN shared secret for
|
||||
// cold-start resilience (bugboard #130), so it must not be world/group
|
||||
// readable. WriteFile's mode only applies on create — chmod explicitly so a
|
||||
// file written 0644 by an older release is tightened on the next rewrite.
|
||||
if err := os.WriteFile(path, data, 0600); err != nil {
|
||||
return fmt.Errorf("failed to write state file: %w", err)
|
||||
// Atomic write: this file now carries the namespace TURN shared secret
|
||||
// (bugboard #130) and is rewritten from multiple converge paths. Write a
|
||||
// temp file then rename over the target so a reader (or a concurrent
|
||||
// writer) never observes a half-written secret — rename is atomic on the
|
||||
// same filesystem. 0600 + chmod on the temp file keeps the secret out of
|
||||
// world/group read; the rename then makes the live file 0600 too, which
|
||||
// also tightens a file an older release left at 0644.
|
||||
tmp := path + ".tmp"
|
||||
if err := os.WriteFile(tmp, data, 0600); err != nil {
|
||||
return fmt.Errorf("failed to write temp state file: %w", err)
|
||||
}
|
||||
if err := os.Chmod(path, 0600); err != nil {
|
||||
return fmt.Errorf("failed to set state file permissions: %w", err)
|
||||
if err := os.Chmod(tmp, 0600); err != nil {
|
||||
os.Remove(tmp)
|
||||
return fmt.Errorf("failed to set temp state file permissions: %w", err)
|
||||
}
|
||||
if err := os.Rename(tmp, path); err != nil {
|
||||
os.Remove(tmp)
|
||||
return fmt.Errorf("failed to rename state file into place: %w", err)
|
||||
}
|
||||
cm.logger.Info("Saved cluster local state", zap.String("namespace", state.NamespaceName), zap.String("path", path))
|
||||
return nil
|
||||
@ -1855,6 +1864,29 @@ const (
|
||||
webrtcResolveRetryDelay = 2 * time.Second
|
||||
)
|
||||
|
||||
// resolveWebRTCConfigWithRetry calls fetch up to `retries` times, sleeping
|
||||
// `delay` between attempts, and returns the first result whose error is nil. A
|
||||
// distant/just-restarted node's namespace rqlite can take a few seconds to
|
||||
// become readable; without the retry the read fails once and the gateway comes
|
||||
// up with TURN disabled (bugboard #130). A genuine decrypt failure (stale
|
||||
// cluster-secret) also errors and exhausts the retries, returning the final
|
||||
// error so the caller can mark the result unresolved. `sleep` is injected so
|
||||
// unit tests exercise the loop without real delay.
|
||||
func resolveWebRTCConfigWithRetry(retries int, delay time.Duration, sleep func(time.Duration), fetch func() (*WebRTCConfig, error)) (*WebRTCConfig, error) {
|
||||
var cfg *WebRTCConfig
|
||||
var err error
|
||||
for attempt := 0; attempt < retries; attempt++ {
|
||||
cfg, err = fetch()
|
||||
if err == nil {
|
||||
return cfg, nil
|
||||
}
|
||||
if attempt < retries-1 {
|
||||
sleep(delay)
|
||||
}
|
||||
}
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
// applyResolvedWebRTCToState copies a freshly-resolved WebRTC config into the
|
||||
// local cluster state so a future cold start can read the TURN secret from disk
|
||||
// instead of the (possibly-slow) namespace rqlite (bugboard #130). Returns true
|
||||
@ -1886,12 +1918,13 @@ type restoreWebRTC struct {
|
||||
turnDomain string
|
||||
turnSecret string
|
||||
stealthDomain string // feat-124: empty when webrtc stealth is disabled
|
||||
// unresolved is true when the state file had no TURN secret AND the DB
|
||||
// fallback ERRORED (vs. resolved-but-not-enabled). The caller must NOT
|
||||
// write a WebRTC-disabled gateway config off an unresolved lookup — that
|
||||
// silently kills turn.credentials on a node that should serve TURN
|
||||
// (bugboard #130: a decrypt failure after cluster-secret rotation was
|
||||
// swallowed into "disabled"). enabled is always false when unresolved.
|
||||
// unresolved is true when the DB lookup ERRORED (vs. resolved-but-not-
|
||||
// enabled) AND the local cache had no secret to fall back to. The caller
|
||||
// must NOT write a WebRTC-disabled gateway config off an unresolved
|
||||
// lookup — that silently kills turn.credentials on a node that should
|
||||
// serve TURN (bugboard #130: a decrypt failure after cluster-secret
|
||||
// rotation was swallowed into "disabled"). enabled is always false when
|
||||
// unresolved.
|
||||
unresolved bool
|
||||
}
|
||||
|
||||
@ -1905,9 +1938,18 @@ type restoreWebRTC struct {
|
||||
// - SFU (sfuPort) is PER-NODE — non-zero only when this node runs a
|
||||
// local SFU (for /v1/webrtc/signal + /rooms proxying).
|
||||
//
|
||||
// Precedence: prefer the local state file; fall back to the DB (source of
|
||||
// truth) when the state file lacks the TURN secret (the namespace-wide
|
||||
// "webrtc is enabled" marker). dbFetch is lazy — only hit when needed.
|
||||
// Precedence: DB-FIRST. The namespace_webrtc_config row is the source of
|
||||
// truth for the CURRENT TURN secret, so we always consult it. The local
|
||||
// cluster-state.json cache (dbFetch's counterpart) is a FALLBACK ONLY —
|
||||
// used when the DB read fails (a slow/just-restarted node whose namespace
|
||||
// rqlite has not synced yet). This is the bugboard #130 FOLLOW-UP fix: the
|
||||
// earlier state-FIRST read short-circuited the DB whenever the cache held a
|
||||
// secret and so NEVER re-validated a present-but-stale cached secret. If a
|
||||
// secret was rotated (disable→enable) while a node was offline, that node
|
||||
// kept serving the OLD secret indefinitely. DB-first means a stale cache
|
||||
// can survive at most until the DB becomes readable on the next converge —
|
||||
// never indefinitely — while still letting a genuinely DB-down node come up
|
||||
// on TURN via the cache (the #130 resilience the cache was added for).
|
||||
//
|
||||
// `enabled` is true when EITHER a TURN secret OR an SFU port is present,
|
||||
// so the caller knows to write a webrtc block. A non-SFU gateway gets
|
||||
@ -1920,52 +1962,45 @@ func chooseRestoreWebRTC(
|
||||
stateHasSFU bool, stateSFUPort int, stateTURNDomain, stateTURNSecret, stateStealthDomain string,
|
||||
dbFetch func() (turnSecret, turnDomain, stealthDomain string, sfuPort int, resolved bool),
|
||||
) restoreWebRTC {
|
||||
turnSecret := stateTURNSecret
|
||||
turnDomain := stateTURNDomain
|
||||
stealthDomain := stateStealthDomain
|
||||
// DB-first: consult the source of truth before trusting the local cache.
|
||||
dbSecret, dbDomain, dbStealth, dbSFU, resolved := dbFetch()
|
||||
if resolved {
|
||||
// The DB read landed and is authoritative. dbSecret == "" means the
|
||||
// namespace genuinely has no WebRTC enabled — honor that (disable),
|
||||
// do NOT fall back to a possibly-stale cached secret. A present
|
||||
// secret is the CURRENT one and wins over any cached value.
|
||||
if dbSecret == "" {
|
||||
return restoreWebRTC{}
|
||||
}
|
||||
return restoreWebRTC{
|
||||
enabled: true,
|
||||
sfuPort: dbSFU,
|
||||
turnDomain: dbDomain,
|
||||
turnSecret: dbSecret,
|
||||
stealthDomain: dbStealth,
|
||||
}
|
||||
}
|
||||
|
||||
// The DB/decrypt lookup ERRORED (slow node whose namespace rqlite is not
|
||||
// readable yet, or a decrypt failure after a cluster-secret rotation).
|
||||
// Fall back to the locally-cached secret so TURN still comes up — possibly
|
||||
// stale, but functional, and self-correcting on the next converge once the
|
||||
// DB is readable (NOT indefinite). If the cache is empty too, signal
|
||||
// unresolved so the caller preserves the running gateway config instead of
|
||||
// blanking TURN (bugboard #130).
|
||||
sfuPort := 0
|
||||
if stateHasSFU && stateSFUPort > 0 {
|
||||
sfuPort = stateSFUPort
|
||||
}
|
||||
|
||||
// Fall back to the DB when the state file has no TURN secret — that's
|
||||
// the marker that the namespace has WebRTC enabled at all. The state
|
||||
// file is not updated by EnableWebRTC, so a namespace enabled after
|
||||
// the state file was written reaches here with an empty secret.
|
||||
// (Stealth toggles DO rewrite cluster state on every node, so the
|
||||
// state-first read stays fresh for stealthDomain too.)
|
||||
unresolved := false
|
||||
if turnSecret == "" {
|
||||
dbSecret, dbDomain, dbStealth, dbSFU, resolved := dbFetch()
|
||||
switch {
|
||||
case !resolved:
|
||||
// The DB/decrypt lookup ERRORED — we do not know whether WebRTC
|
||||
// is enabled. This is DISTINCT from resolved-but-empty (genuinely
|
||||
// disabled). Signal unresolved so the caller preserves the
|
||||
// running config instead of writing a TURN-disabled one
|
||||
// (bugboard #130).
|
||||
unresolved = true
|
||||
case dbSecret != "":
|
||||
turnSecret = dbSecret
|
||||
if turnDomain == "" {
|
||||
turnDomain = dbDomain
|
||||
}
|
||||
if stealthDomain == "" {
|
||||
stealthDomain = dbStealth
|
||||
}
|
||||
if sfuPort == 0 {
|
||||
sfuPort = dbSFU
|
||||
}
|
||||
}
|
||||
if stateTURNSecret == "" && sfuPort == 0 {
|
||||
return restoreWebRTC{unresolved: true}
|
||||
}
|
||||
|
||||
return restoreWebRTC{
|
||||
enabled: !unresolved && (turnSecret != "" || sfuPort > 0),
|
||||
unresolved: unresolved,
|
||||
enabled: stateTURNSecret != "" || sfuPort > 0,
|
||||
sfuPort: sfuPort,
|
||||
turnDomain: turnDomain,
|
||||
turnSecret: turnSecret,
|
||||
stealthDomain: stealthDomain,
|
||||
turnDomain: stateTURNDomain,
|
||||
turnSecret: stateTURNSecret,
|
||||
stealthDomain: stateStealthDomain,
|
||||
}
|
||||
}
|
||||
|
||||
@ -2114,12 +2149,12 @@ func (cm *ClusterManager) restoreClusterFromState(ctx context.Context, state *Cl
|
||||
SecretsEncryptionKey: cm.secretsEncryptionKey,
|
||||
}
|
||||
|
||||
// Resolve WebRTC config. Prefer the local state file; fall back to
|
||||
// the DB (source of truth) to self-heal stale state. Bugboard #25 —
|
||||
// the state file is NOT updated by EnableWebRTC, so a namespace
|
||||
// enabled AFTER its state file was written carries no SFU/TURN
|
||||
// fields here. The lazy dbFetch only hits the DB when the state
|
||||
// file is incomplete.
|
||||
// Resolve WebRTC config. DB-FIRST (source of truth for the CURRENT
|
||||
// secret); the local state cache is consulted only when the DB read
|
||||
// fails (bugboard #130 follow-up — see chooseRestoreWebRTC). Bugboard
|
||||
// #25 — the state file is NOT updated by EnableWebRTC, so a namespace
|
||||
// enabled AFTER its state file was written carries no SFU/TURN fields
|
||||
// here; reading the DB re-materializes them.
|
||||
wr := chooseRestoreWebRTC(
|
||||
state.HasSFU, state.SFUSignalingPort, state.TURNDomain, state.TURNSharedSecret, state.TURNStealthDomain,
|
||||
func() (turnSecret, turnDomain, stealthDomain string, sfuPort int, resolved bool) {
|
||||
@ -2132,17 +2167,11 @@ func (cm *ClusterManager) restoreClusterFromState(ctx context.Context, state *Cl
|
||||
// decrypt failure (stale key) also errors here and will exhaust
|
||||
// the retries → unresolved → the caller preserves the running
|
||||
// config rather than blanking it.
|
||||
var webrtcCfg *WebRTCConfig
|
||||
var err error
|
||||
for attempt := 0; attempt < webrtcResolveRetries; attempt++ {
|
||||
webrtcCfg, err = cm.GetWebRTCConfig(ctx, state.NamespaceName)
|
||||
if err == nil {
|
||||
break // success — webrtcCfg may be nil (genuinely disabled)
|
||||
}
|
||||
if attempt < webrtcResolveRetries-1 {
|
||||
time.Sleep(webrtcResolveRetryDelay)
|
||||
}
|
||||
}
|
||||
webrtcCfg, err := resolveWebRTCConfigWithRetry(
|
||||
webrtcResolveRetries, webrtcResolveRetryDelay, time.Sleep,
|
||||
func() (*WebRTCConfig, error) {
|
||||
return cm.GetWebRTCConfig(ctx, state.NamespaceName)
|
||||
})
|
||||
if err != nil {
|
||||
// Persistent error after retries (slow read that never
|
||||
// landed, or a decrypt failure). Do NOT swallow into
|
||||
@ -2182,13 +2211,15 @@ func (cm *ClusterManager) restoreClusterFromState(ctx context.Context, state *Cl
|
||||
gwCfg.TURNSecret = wr.turnSecret
|
||||
gwCfg.TURNStealthDomain = wr.stealthDomain
|
||||
|
||||
// Cache the resolved secret into THIS node's local state so the
|
||||
// NEXT cold start reads it from disk (state-first in
|
||||
// chooseRestoreWebRTC short-circuits before the DB) instead of
|
||||
// depending on a live, possibly-slow namespace-rqlite read — which
|
||||
// is exactly what left a distant/slow node's gateway with TURN
|
||||
// disabled on restart (bugboard #130). Each node self-heals its own
|
||||
// cache on a successful resolve; nothing is sent cross-node.
|
||||
// Cache the resolved secret into THIS node's local state so that if
|
||||
// the NEXT cold start can't read the namespace rqlite (a distant/
|
||||
// slow node whose follower hasn't synced), chooseRestoreWebRTC can
|
||||
// fall back to this on-disk secret instead of coming up with TURN
|
||||
// disabled (bugboard #130). The cache is a FALLBACK — DB-first
|
||||
// resolution still prefers the live DB secret whenever it's
|
||||
// readable, so this cached value can never pin the node to a stale
|
||||
// secret. Each node self-heals its own cache on a successful
|
||||
// resolve; nothing is sent cross-node.
|
||||
if applyResolvedWebRTCToState(state, wr) {
|
||||
if err := cm.saveLocalState(state); err != nil {
|
||||
cm.logger.Warn("Failed to cache resolved WebRTC config to local state (cold start may fall back to the DB read next boot)",
|
||||
@ -2198,6 +2229,24 @@ func (cm *ClusterManager) restoreClusterFromState(ctx context.Context, state *Cl
|
||||
zap.String("namespace", state.NamespaceName))
|
||||
}
|
||||
}
|
||||
} else if !wr.unresolved {
|
||||
// The DB read RESOLVED that this namespace has NO WebRTC (disabled).
|
||||
// Clear any stale cached secret from local state so a future cold
|
||||
// start that hits a transient DB error can't fall back to it and
|
||||
// resurrect TURN for a disabled namespace — the hole being: a node
|
||||
// that was offline during DisableWebRTC never received the cleared
|
||||
// state push and would otherwise keep serving the old secret. Only
|
||||
// do this on a RESOLVED-disabled read, NEVER on an unresolved
|
||||
// (DB-error) one — there the cache IS the fallback and must survive.
|
||||
if applyResolvedWebRTCToState(state, restoreWebRTC{}) {
|
||||
if err := cm.saveLocalState(state); err != nil {
|
||||
cm.logger.Warn("Failed to clear stale cached WebRTC secret from local state after DB reported the namespace disabled",
|
||||
zap.String("namespace", state.NamespaceName), zap.Error(err))
|
||||
} else {
|
||||
cm.logger.Info("Cleared stale cached WebRTC secret from local state (namespace disabled in DB)",
|
||||
zap.String("namespace", state.NamespaceName))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := http.Get(fmt.Sprintf("http://localhost:%d/v1/health", pb.GatewayHTTPPort))
|
||||
|
||||
71
core/pkg/namespace/cluster_state_perms_test.go
Normal file
71
core/pkg/namespace/cluster_state_perms_test.go
Normal file
@ -0,0 +1,71 @@
|
||||
package namespace
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Bugboard #130 — cluster-state.json carries the namespace TURN shared secret
|
||||
// (plaintext HMAC), so every writer of it must produce a 0600 file and tighten
|
||||
// any pre-existing world-readable file on rewrite. SaveClusterState is the
|
||||
// RECEIVER-side writer that persists state pushed from the coordinator to a
|
||||
// remote namespace node; without this it landed 0644.
|
||||
|
||||
func TestSaveClusterState_writes0600(t *testing.T) {
|
||||
base := t.TempDir()
|
||||
s := &SystemdSpawner{namespaceBase: base, logger: zap.NewNop()}
|
||||
|
||||
if err := s.SaveClusterState("ns-test", []byte(`{"turn_shared_secret":"sek-123"}`)); err != nil {
|
||||
t.Fatalf("SaveClusterState: %v", err)
|
||||
}
|
||||
|
||||
path := filepath.Join(base, "ns-test", "cluster-state.json")
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
t.Fatalf("stat cluster-state.json: %v", err)
|
||||
}
|
||||
if perm := info.Mode().Perm(); perm != 0600 {
|
||||
t.Errorf("cluster-state.json mode = %o; want 0600 (it carries the TURN secret)", perm)
|
||||
}
|
||||
// No leftover temp file from the atomic write.
|
||||
if _, err := os.Stat(path + ".tmp"); !os.IsNotExist(err) {
|
||||
t.Errorf("temp file should not survive a successful save; stat err = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveClusterState_tightensExisting0644(t *testing.T) {
|
||||
base := t.TempDir()
|
||||
s := &SystemdSpawner{namespaceBase: base, logger: zap.NewNop()}
|
||||
|
||||
// Simulate a file an older release wrote world-readable.
|
||||
dir := filepath.Join(base, "ns-test")
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
path := filepath.Join(dir, "cluster-state.json")
|
||||
if err := os.WriteFile(path, []byte(`{"old":true}`), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := s.SaveClusterState("ns-test", []byte(`{"turn_shared_secret":"sek-new"}`)); err != nil {
|
||||
t.Fatalf("SaveClusterState: %v", err)
|
||||
}
|
||||
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
t.Fatalf("stat cluster-state.json: %v", err)
|
||||
}
|
||||
if perm := info.Mode().Perm(); perm != 0600 {
|
||||
t.Errorf("rewrite did not tighten perms: mode = %o; want 0600", perm)
|
||||
}
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(data) != `{"turn_shared_secret":"sek-new"}` {
|
||||
t.Errorf("content not replaced atomically: %s", data)
|
||||
}
|
||||
}
|
||||
@ -1,15 +1,23 @@
|
||||
package namespace
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Bugboard #25 — WebRTC config drift on restart + TURN/SFU decouple.
|
||||
// Bugboard #130 follow-up — DB-FIRST resolution so a stale cached secret can
|
||||
// never be served indefinitely.
|
||||
//
|
||||
// chooseRestoreWebRTC resolves a restored gateway's WebRTC config from the
|
||||
// local state file (which EnableWebRTC does NOT update) with a DB fallback
|
||||
// (source of truth). It also DECOUPLES the two aspects: TURN (secret +
|
||||
// domain) is namespace-wide so ANY gateway can serve credentials; the SFU
|
||||
// port is per-node (0 on a gateway-only node). Pins both the drift
|
||||
// fallback and the non-SFU-gateway case.
|
||||
// chooseRestoreWebRTC resolves a restored gateway's WebRTC config DB-FIRST
|
||||
// (the namespace_webrtc_config row is the source of truth for the current
|
||||
// secret); the local cluster-state.json cache is a FALLBACK consulted only
|
||||
// when the DB read fails (a slow node whose namespace rqlite hasn't synced).
|
||||
// It also DECOUPLES the two aspects: TURN (secret + domain) is namespace-wide
|
||||
// so ANY gateway can serve credentials; the SFU port is per-node (0 on a
|
||||
// gateway-only node). Pins the drift fallback, the non-SFU-gateway case, and
|
||||
// the DB-first precedence (DB secret wins over a cached/stale one).
|
||||
|
||||
// dbFetch signature: () -> (turnSecret, turnDomain, stealthDomain string, sfuPort int, resolved bool).
|
||||
// resolved=true means the lookup completed (with or without a config);
|
||||
@ -23,22 +31,39 @@ func dbFull(secret, domain string, sfuPort int) func() (string, string, string,
|
||||
return func() (string, string, string, int, bool) { return secret, domain, "", sfuPort, true }
|
||||
}
|
||||
|
||||
func TestChooseRestoreWebRTC_stateFileCompleteWins(t *testing.T) {
|
||||
// State file has TURN secret → use it, and NEVER consult the DB
|
||||
// (the lazy dbFetch must not be called — saves a query on the hot
|
||||
// restart path).
|
||||
dbCalled := false
|
||||
got := chooseRestoreWebRTC(true, 7800, "turn.ns-x.dbrs.space", "state-secret", "",
|
||||
func() (string, string, string, int, bool) { dbCalled = true; return dbNone() })
|
||||
func TestChooseRestoreWebRTC_dbSecretWinsOverCachedState(t *testing.T) {
|
||||
// THE #130 FOLLOW-UP (staleness) case. The state file holds a cached
|
||||
// secret, but the DB (source of truth) has a DIFFERENT, current secret —
|
||||
// e.g. the secret was rotated (disable→enable) while this node was offline.
|
||||
// DB-first MUST serve the current DB secret, NOT the stale cached one. The
|
||||
// old state-first logic short-circuited the DB here and served "old-secret"
|
||||
// indefinitely.
|
||||
got := chooseRestoreWebRTC(true, 7800, "turn.ns-x.dbrs.space", "old-secret", "cdn-old.dbrs.space",
|
||||
dbFull("new-secret", "turn.ns-x.dbrs.space", 7800))
|
||||
|
||||
if dbCalled {
|
||||
t.Error("DB fetch was called even though the state file had the TURN secret (should short-circuit)")
|
||||
if !got.enabled {
|
||||
t.Fatal("DB has a current secret; result must be enabled")
|
||||
}
|
||||
if !got.enabled || got.sfuPort != 7800 || got.turnSecret != "state-secret" {
|
||||
t.Errorf("want state-file values; got %+v", got)
|
||||
if got.turnSecret != "new-secret" {
|
||||
t.Errorf("BUG #130 STALENESS: turnSecret = %q; want new-secret (the current DB value, not the stale cache)", got.turnSecret)
|
||||
}
|
||||
if got.turnDomain != "turn.ns-x.dbrs.space" {
|
||||
t.Errorf("turnDomain = %q; want state-file value", got.turnDomain)
|
||||
if got.sfuPort != 7800 || got.turnDomain != "turn.ns-x.dbrs.space" {
|
||||
t.Errorf("want DB-derived block; got %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChooseRestoreWebRTC_dbDisabledOverridesCachedSecret(t *testing.T) {
|
||||
// The cache holds a secret but the DB read completes and reports NO WebRTC
|
||||
// (the namespace was disabled while this node was offline). DB-first must
|
||||
// honor the disable, NOT keep serving the stale cached secret.
|
||||
got := chooseRestoreWebRTC(true, 7800, "turn.ns-x.dbrs.space", "stale-secret", "",
|
||||
dbNone) // dbNone = resolved, no config
|
||||
|
||||
if got.enabled {
|
||||
t.Errorf("DB reports disabled: must not keep serving the cached secret; got %+v", got)
|
||||
}
|
||||
if got.unresolved {
|
||||
t.Error("a clean resolved-but-disabled lookup must not be marked unresolved")
|
||||
}
|
||||
}
|
||||
|
||||
@ -84,19 +109,19 @@ func TestChooseRestoreWebRTC_nonSFUGatewayGetsTURNOnly(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestChooseRestoreWebRTC_stateHasTURNButNoSFU(t *testing.T) {
|
||||
// State file for a non-SFU node: it has the TURN secret but HasSFU is
|
||||
// false / port 0. Must use the state TURN secret with sfuPort=0 and
|
||||
// NOT consult the DB (TURN secret present = complete enough).
|
||||
dbCalled := false
|
||||
got := chooseRestoreWebRTC(false, 0, "turn.ns-x.dbrs.space", "state-secret", "",
|
||||
func() (string, string, string, int, bool) { dbCalled = true; return dbNone() })
|
||||
func TestChooseRestoreWebRTC_cachedTurnOnlyFallbackOnDBError(t *testing.T) {
|
||||
// A non-SFU node holds a cached TURN secret (HasSFU false / port 0) and the
|
||||
// DB read ERRORS (its namespace rqlite isn't readable yet at cold start).
|
||||
// DB-first falls back to the cached secret so the gateway still serves TURN
|
||||
// credentials — sfuPort stays 0 (no local SFU). This is the #130 resilience
|
||||
// the cache exists for.
|
||||
got := chooseRestoreWebRTC(false, 0, "turn.ns-x.dbrs.space", "state-secret", "", dbError)
|
||||
|
||||
if dbCalled {
|
||||
t.Error("DB fetch called even though state file had the TURN secret")
|
||||
}
|
||||
if !got.enabled || got.sfuPort != 0 || got.turnSecret != "state-secret" {
|
||||
t.Errorf("want TURN-only from state (sfuPort 0); got %+v", got)
|
||||
t.Errorf("want cached TURN-only fallback (sfuPort 0); got %+v", got)
|
||||
}
|
||||
if got.unresolved {
|
||||
t.Error("a usable cached secret must not be marked unresolved")
|
||||
}
|
||||
}
|
||||
|
||||
@ -127,16 +152,14 @@ func TestChooseRestoreWebRTC_dbNoSecretStaysDisabled(t *testing.T) {
|
||||
|
||||
// --- feat-124 stealth domain restore precedence ---
|
||||
|
||||
func TestChooseRestoreWebRTC_stealthFromStateFile(t *testing.T) {
|
||||
// Stealth toggles rewrite cluster state, so a fresh state file carries
|
||||
// the stealth domain and must win without a DB call.
|
||||
got := chooseRestoreWebRTC(true, 7800, "turn.ns-x.dbrs.space", "state-secret", "cdn-abc123def456.dbrs.space",
|
||||
func() (string, string, string, int, bool) {
|
||||
t.Error("DB fetch called even though state file was complete")
|
||||
return dbNone()
|
||||
})
|
||||
if got.stealthDomain != "cdn-abc123def456.dbrs.space" {
|
||||
t.Errorf("stealthDomain = %q; want state-file value", got.stealthDomain)
|
||||
func TestChooseRestoreWebRTC_stealthFromCacheOnDBError(t *testing.T) {
|
||||
// When the DB read errors, the cache fallback carries the whole block —
|
||||
// including the cached stealth domain — so a stealth-enabled namespace
|
||||
// keeps advertising its stealth rung on a cold start that can't reach the
|
||||
// DB yet.
|
||||
got := chooseRestoreWebRTC(true, 7800, "turn.ns-x.dbrs.space", "state-secret", "cdn-abc123def456.dbrs.space", dbError)
|
||||
if !got.enabled || got.stealthDomain != "cdn-abc123def456.dbrs.space" {
|
||||
t.Errorf("stealthDomain = %q; want cached value on DB-error fallback; got %+v", got.stealthDomain, got)
|
||||
}
|
||||
}
|
||||
|
||||
@ -188,27 +211,24 @@ func TestChooseRestoreWebRTC_resolvedEmptyIsDisabledNotUnresolved(t *testing.T)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChooseRestoreWebRTC_stateSecretWinsOverDBError(t *testing.T) {
|
||||
// A node that already holds the TURN secret in its state file must NOT be
|
||||
// affected by a DB error — it short-circuits before dbFetch and stays
|
||||
// enabled/resolved. Guards against the #130 fix accidentally disabling
|
||||
// healthy nodes when the DB is flaky.
|
||||
got := chooseRestoreWebRTC(true, 7800, "turn.ns-x.dbrs.space", "state-secret", "",
|
||||
func() (string, string, string, int, bool) {
|
||||
t.Error("DB fetch must not be called when the state file has the secret")
|
||||
return dbError()
|
||||
})
|
||||
func TestChooseRestoreWebRTC_cachedSecretSurvivesDBError(t *testing.T) {
|
||||
// A node that holds the TURN secret in its state file must NOT be disabled
|
||||
// by a flaky/unsynced DB — when the DB read errors, DB-first falls back to
|
||||
// the cached secret and stays enabled (not unresolved). Guards against the
|
||||
// #130 fix accidentally disabling nodes when the DB is briefly unreadable.
|
||||
got := chooseRestoreWebRTC(true, 7800, "turn.ns-x.dbrs.space", "state-secret", "", dbError)
|
||||
if got.unresolved || !got.enabled || got.turnSecret != "state-secret" {
|
||||
t.Errorf("state-file secret must win and stay enabled/resolved; got %+v", got)
|
||||
t.Errorf("cached secret must survive a DB error and stay enabled; got %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChooseRestoreWebRTC_noStealthStaysEmpty(t *testing.T) {
|
||||
// Stealth disabled everywhere → empty stealthDomain (gateway advertises
|
||||
// the baseline 3-rung ladder only).
|
||||
got := chooseRestoreWebRTC(true, 7800, "turn.ns-x.dbrs.space", "state-secret", "", dbNone)
|
||||
if got.stealthDomain != "" {
|
||||
t.Errorf("stealthDomain = %q; want empty when stealth is disabled", got.stealthDomain)
|
||||
// Stealth disabled → empty stealthDomain (gateway advertises the baseline
|
||||
// 3-rung ladder only). Uses the cache-fallback path (DB error) so an
|
||||
// enabled-but-no-stealth config is exercised end to end.
|
||||
got := chooseRestoreWebRTC(true, 7800, "turn.ns-x.dbrs.space", "state-secret", "", dbError)
|
||||
if !got.enabled || got.stealthDomain != "" {
|
||||
t.Errorf("stealthDomain = %q; want empty when stealth is disabled; got %+v", got.stealthDomain, got)
|
||||
}
|
||||
}
|
||||
|
||||
@ -232,16 +252,12 @@ func TestApplyResolvedWebRTCToState_populatesAndReportsChange(t *testing.T) {
|
||||
t.Errorf("state not fully populated: %+v", st)
|
||||
}
|
||||
|
||||
// The whole point: a SECOND boot now reads the secret from state and must
|
||||
// NOT consult the DB (chooseRestoreWebRTC short-circuits).
|
||||
dbCalled := false
|
||||
got := chooseRestoreWebRTC(st.HasSFU, st.SFUSignalingPort, st.TURNDomain, st.TURNSharedSecret, st.TURNStealthDomain,
|
||||
func() (string, string, string, int, bool) { dbCalled = true; return dbError() })
|
||||
if dbCalled {
|
||||
t.Error("BUG #130: cold start still hit the DB even though the secret was cached in local state")
|
||||
}
|
||||
// The whole point of caching: on a SECOND boot where the DB read fails
|
||||
// (slow node, namespace rqlite not synced), the cached secret lets the
|
||||
// gateway still come up on TURN (DB-first falls back to the cache).
|
||||
got := chooseRestoreWebRTC(st.HasSFU, st.SFUSignalingPort, st.TURNDomain, st.TURNSharedSecret, st.TURNStealthDomain, dbError)
|
||||
if !got.enabled || got.unresolved || got.turnSecret != "sek-123" {
|
||||
t.Errorf("cached cold start should resolve enabled from state; got %+v", got)
|
||||
t.Errorf("cached cold start should fall back to the state secret on a DB error; got %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
@ -264,3 +280,97 @@ func TestApplyResolvedWebRTCToState_turnOnlyNode_noSFU(t *testing.T) {
|
||||
t.Errorf("turn-only node: want HasTURN=true HasSFU=false secret cached; got %+v", st)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyResolvedWebRTCToState_clearsCacheOnDisable(t *testing.T) {
|
||||
// When the DB resolves the namespace as DISABLED, the caller applies an
|
||||
// empty restoreWebRTC to wipe any stale cached secret from local state — so
|
||||
// a node that was offline during DisableWebRTC can't later fall back to the
|
||||
// old secret on a transient DB error and resurrect TURN for a disabled
|
||||
// namespace. Must report change=true and zero out the cached fields.
|
||||
st := &ClusterLocalState{HasTURN: true, HasSFU: true, TURNSharedSecret: "stale-secret", TURNDomain: "turn.ns-x.dbrs.space", SFUSignalingPort: 7800}
|
||||
|
||||
if !applyResolvedWebRTCToState(st, restoreWebRTC{}) {
|
||||
t.Fatal("disable: want change=true when clearing a cached secret")
|
||||
}
|
||||
if st.TURNSharedSecret != "" || st.HasTURN || st.HasSFU || st.SFUSignalingPort != 0 || st.TURNDomain != "" {
|
||||
t.Errorf("cache not fully cleared on disable: %+v", st)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyResolvedWebRTCToState_secretRotationReportsChange(t *testing.T) {
|
||||
// Secret rotation: the state holds an OLD cached secret and a fresh resolve
|
||||
// brings the NEW (rotated) secret. applyResolvedWebRTCToState MUST report
|
||||
// change=true and overwrite the cache, so the node's fallback secret tracks
|
||||
// the rotation instead of persisting a stale value on disk (bugboard #130
|
||||
// follow-up — the cache must never lag the rotated secret).
|
||||
st := &ClusterLocalState{HasTURN: true, TURNSharedSecret: "old-secret", TURNDomain: "turn.ns-x.dbrs.space"}
|
||||
wr := restoreWebRTC{enabled: true, turnSecret: "new-secret", turnDomain: "turn.ns-x.dbrs.space"}
|
||||
|
||||
if !applyResolvedWebRTCToState(st, wr) {
|
||||
t.Fatal("rotation: want change=true when the resolved secret differs from the cached one")
|
||||
}
|
||||
if st.TURNSharedSecret != "new-secret" {
|
||||
t.Errorf("cache not updated to the rotated secret: got %q; want new-secret", st.TURNSharedSecret)
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Bugboard #130 — the cold-start read retries so a slow node's namespace
|
||||
// rqlite read lands once the follower syncs, instead of failing once and
|
||||
// coming up with TURN disabled.
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func TestResolveWebRTCConfigWithRetry_succeedsOnNthAttempt(t *testing.T) {
|
||||
// The read errors on the first two attempts (rqlite not readable yet) then
|
||||
// succeeds — the retry must return the config and not surface the earlier
|
||||
// transient errors.
|
||||
calls := 0
|
||||
slept := 0
|
||||
cfg, err := resolveWebRTCConfigWithRetry(5, time.Millisecond, func(time.Duration) { slept++ },
|
||||
func() (*WebRTCConfig, error) {
|
||||
calls++
|
||||
if calls < 3 {
|
||||
return nil, errors.New("rqlite not readable yet")
|
||||
}
|
||||
return &WebRTCConfig{TURNSharedSecret: "sek-123"}, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("want success on the 3rd attempt; got err %v", err)
|
||||
}
|
||||
if cfg == nil || cfg.TURNSharedSecret != "sek-123" {
|
||||
t.Fatalf("want resolved config; got %+v", cfg)
|
||||
}
|
||||
if calls != 3 {
|
||||
t.Errorf("want exactly 3 fetch attempts; got %d", calls)
|
||||
}
|
||||
if slept != 2 {
|
||||
t.Errorf("want a sleep between each of the 2 failed attempts; got %d", slept)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveWebRTCConfigWithRetry_exhaustsAndReturnsError(t *testing.T) {
|
||||
// A persistent error (e.g. a decrypt failure after cluster-secret rotation)
|
||||
// must exhaust all attempts and return the final error — the caller maps
|
||||
// that to unresolved (NOT disabled). No sleep after the final attempt.
|
||||
calls := 0
|
||||
slept := 0
|
||||
cfg, err := resolveWebRTCConfigWithRetry(4, time.Millisecond, func(time.Duration) { slept++ },
|
||||
func() (*WebRTCConfig, error) {
|
||||
calls++
|
||||
return nil, errors.New("decrypt failed")
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("want the final error after exhausting retries; got nil")
|
||||
}
|
||||
if cfg != nil {
|
||||
t.Errorf("want nil config on exhaustion; got %+v", cfg)
|
||||
}
|
||||
if calls != 4 {
|
||||
t.Errorf("want 4 attempts (all retries used); got %d", calls)
|
||||
}
|
||||
if slept != 3 {
|
||||
t.Errorf("want a sleep between attempts but not after the last; got %d", slept)
|
||||
}
|
||||
}
|
||||
|
||||
@ -801,16 +801,24 @@ func (s *SystemdSpawner) SaveClusterState(namespace string, data []byte) error {
|
||||
return fmt.Errorf("failed to create namespace dir: %w", err)
|
||||
}
|
||||
path := filepath.Join(dir, "cluster-state.json")
|
||||
// 0600 + chmod: cluster-state.json carries the namespace TURN shared secret
|
||||
// for cold-start resilience (bugboard #130), so it must not be world/group
|
||||
// readable on the receiving node either. WriteFile's mode only applies on
|
||||
// create, so chmod explicitly to tighten a file an older release wrote 0644.
|
||||
if err := os.WriteFile(path, data, 0600); err != nil {
|
||||
return fmt.Errorf("failed to write cluster state: %w", err)
|
||||
// Atomic write to a temp file + rename: cluster-state.json carries the
|
||||
// namespace TURN shared secret (bugboard #130), so it must not be
|
||||
// world/group readable on the receiving node either, and a reader must
|
||||
// never see a half-written secret. 0600 + chmod on the temp file keeps the
|
||||
// secret private; the rename then makes the live file 0600 too, tightening
|
||||
// a file an older release wrote 0644.
|
||||
tmp := path + ".tmp"
|
||||
if err := os.WriteFile(tmp, data, 0600); err != nil {
|
||||
return fmt.Errorf("failed to write temp cluster state: %w", err)
|
||||
}
|
||||
if err := os.Chmod(path, 0600); err != nil {
|
||||
if err := os.Chmod(tmp, 0600); err != nil {
|
||||
os.Remove(tmp)
|
||||
return fmt.Errorf("failed to set cluster state permissions: %w", err)
|
||||
}
|
||||
if err := os.Rename(tmp, path); err != nil {
|
||||
os.Remove(tmp)
|
||||
return fmt.Errorf("failed to rename cluster state into place: %w", err)
|
||||
}
|
||||
s.logger.Info("Saved cluster state from coordinator",
|
||||
zap.String("namespace", namespace),
|
||||
zap.String("path", path))
|
||||
|
||||
@ -23,12 +23,14 @@ package ntfy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/DeBrosOfficial/network/pkg/push"
|
||||
@ -45,14 +47,34 @@ type Config struct {
|
||||
AuthToken string
|
||||
// Timeout bounds each Send call. 0 selects 5 seconds.
|
||||
Timeout time.Duration
|
||||
|
||||
// FanoutResolver, when set, returns the set of ntfy publish base URLs to
|
||||
// deliver EACH publish to — one per active push node. The cluster runs an
|
||||
// independent ntfy per node with NO shared message store, while subscribers
|
||||
// are scattered across nodes by round-robin DNS; a publish that lands on one
|
||||
// node only reaches subscribers on that node, losing ~(N-1)/N (bugboard
|
||||
// #858). Fanning a publish to EVERY node guarantees it reaches whichever
|
||||
// instance the subscriber's connection landed on. When nil, or it returns no
|
||||
// hosts (or errors), Send falls back to the single BaseURL — so push never
|
||||
// breaks if node discovery is unavailable.
|
||||
FanoutResolver func(ctx context.Context) ([]string, error)
|
||||
// FanoutHostHeader, when set, overrides the HTTP Host header and TLS SNI on
|
||||
// fan-out requests. Needed because FanoutResolver returns per-node addresses
|
||||
// (IPs) but each node's reverse proxy (Caddy) routes by — and serves its TLS
|
||||
// cert for — the public push hostname. Empty: no override (tests /
|
||||
// homogeneous hosts).
|
||||
FanoutHostHeader string
|
||||
}
|
||||
|
||||
// Provider is the ntfy push.PushProvider implementation.
|
||||
type Provider struct {
|
||||
baseURL string
|
||||
authToken string
|
||||
httpClient *http.Client
|
||||
logger *zap.Logger
|
||||
baseURL string
|
||||
authToken string
|
||||
httpClient *http.Client
|
||||
fanoutClient *http.Client
|
||||
fanoutResolver func(ctx context.Context) ([]string, error)
|
||||
fanoutHostHeader string
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// New creates a Provider with the given config.
|
||||
@ -64,18 +86,37 @@ func New(cfg Config, logger *zap.Logger) *Provider {
|
||||
if timeout <= 0 {
|
||||
timeout = 5 * time.Second
|
||||
}
|
||||
return &Provider{
|
||||
baseURL: strings.TrimRight(cfg.BaseURL, "/"),
|
||||
authToken: cfg.AuthToken,
|
||||
httpClient: &http.Client{Timeout: timeout},
|
||||
logger: logger.Named("ntfy"),
|
||||
p := &Provider{
|
||||
baseURL: strings.TrimRight(cfg.BaseURL, "/"),
|
||||
authToken: cfg.AuthToken,
|
||||
httpClient: &http.Client{Timeout: timeout},
|
||||
fanoutResolver: cfg.FanoutResolver,
|
||||
fanoutHostHeader: cfg.FanoutHostHeader,
|
||||
logger: logger.Named("ntfy"),
|
||||
}
|
||||
if cfg.FanoutResolver != nil {
|
||||
// Fan-out requests dial per-node addresses but must present the public
|
||||
// push hostname for SNI so each node's Caddy serves the right cert and
|
||||
// routes to its local ntfy. A dedicated client carries that fixed SNI.
|
||||
tr := &http.Transport{}
|
||||
if cfg.FanoutHostHeader != "" {
|
||||
tr.TLSClientConfig = &tls.Config{ServerName: cfg.FanoutHostHeader}
|
||||
}
|
||||
p.fanoutClient = &http.Client{Timeout: timeout, Transport: tr}
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// Name implements push.PushProvider.
|
||||
func (p *Provider) Name() string { return "ntfy" }
|
||||
|
||||
// Send delivers a push notification to the device's ntfy topic.
|
||||
//
|
||||
// When a FanoutResolver is configured, the publish is delivered to EVERY active
|
||||
// push node (the ntfy instances don't share state, so the subscriber's instance
|
||||
// — whichever the round-robin LB picked — must be among the targets), and Send
|
||||
// succeeds as long as at least one instance accepted it (bugboard #858).
|
||||
// Otherwise it publishes to the single configured BaseURL.
|
||||
func (p *Provider) Send(ctx context.Context, msg push.PushMessage) error {
|
||||
if msg.DeviceToken == "" {
|
||||
return push.ErrEmptyToken
|
||||
@ -84,7 +125,7 @@ func (p *Provider) Send(ctx context.Context, msg push.PushMessage) error {
|
||||
return fmt.Errorf("ntfy: base URL not configured")
|
||||
}
|
||||
|
||||
endpointURL, err := p.resolveEndpoint(msg.DeviceToken)
|
||||
topic, err := p.resolveTopic(msg.DeviceToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -102,10 +143,73 @@ func (p *Provider) Send(ctx context.Context, msg push.PushMessage) error {
|
||||
body = string(b)
|
||||
}
|
||||
|
||||
// Resolve the set of base URLs to publish to. Default: the single base URL.
|
||||
// With a fan-out resolver, publish to every active push node so the
|
||||
// subscriber's instance is always covered. Resolver failure is non-fatal —
|
||||
// fall back to the base URL so push keeps working.
|
||||
bases := []string{p.baseURL}
|
||||
httpClient := p.httpClient
|
||||
hostHeader := ""
|
||||
if p.fanoutResolver != nil {
|
||||
if hosts, rerr := p.fanoutResolver(ctx); rerr != nil {
|
||||
p.logger.Warn("ntfy fan-out node resolution failed; publishing to base URL only", zap.Error(rerr))
|
||||
} else if len(hosts) > 0 {
|
||||
bases = hosts
|
||||
httpClient = p.fanoutClient
|
||||
hostHeader = p.fanoutHostHeader
|
||||
}
|
||||
}
|
||||
|
||||
if len(bases) == 1 {
|
||||
return p.postOne(ctx, httpClient, bases[0], topic, body, msg, hostHeader)
|
||||
}
|
||||
|
||||
// Fan out concurrently. Success = at least one instance accepted the
|
||||
// publish (the message is in the cluster). A node that's down is logged but
|
||||
// does not fail the Send, since the message still reaches every reachable
|
||||
// instance — including, in the common case, the subscriber's.
|
||||
var wg sync.WaitGroup
|
||||
errs := make([]error, len(bases))
|
||||
for i, base := range bases {
|
||||
wg.Add(1)
|
||||
go func(i int, base string) {
|
||||
defer wg.Done()
|
||||
errs[i] = p.postOne(ctx, httpClient, base, topic, body, msg, hostHeader)
|
||||
}(i, base)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
okCount := 0
|
||||
var firstErr error
|
||||
for _, e := range errs {
|
||||
if e == nil {
|
||||
okCount++
|
||||
} else if firstErr == nil {
|
||||
firstErr = e
|
||||
}
|
||||
}
|
||||
if okCount == 0 {
|
||||
return fmt.Errorf("ntfy: fan-out to all %d push nodes failed: %w", len(bases), firstErr)
|
||||
}
|
||||
if okCount < len(bases) {
|
||||
p.logger.Warn("ntfy fan-out partial failure (message still delivered to the reachable instances)",
|
||||
zap.Int("delivered", okCount), zap.Int("total", len(bases)), zap.Error(firstErr))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// postOne publishes a single (already-resolved) topic+body to one ntfy base URL.
|
||||
// hostHeader, when non-empty, overrides the HTTP Host header so a request dialed
|
||||
// at a node IP is still routed by the node's proxy as the public push hostname.
|
||||
func (p *Provider) postOne(ctx context.Context, httpClient *http.Client, base, topic, body string, msg push.PushMessage, hostHeader string) error {
|
||||
endpointURL := strings.TrimRight(base, "/") + "/" + topic
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, strings.NewReader(body))
|
||||
if err != nil {
|
||||
return fmt.Errorf("ntfy: build request: %w", err)
|
||||
}
|
||||
if hostHeader != "" {
|
||||
req.Host = hostHeader
|
||||
}
|
||||
|
||||
if msg.Title != "" {
|
||||
req.Header.Set("Title", msg.Title)
|
||||
@ -127,15 +231,15 @@ func (p *Provider) Send(ctx context.Context, msg push.PushMessage) error {
|
||||
req.Header.Set("Authorization", "Bearer "+p.authToken)
|
||||
}
|
||||
|
||||
resp, err := p.httpClient.Do(req)
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ntfy: post: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 512))
|
||||
return fmt.Errorf("ntfy: http %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 512))
|
||||
return fmt.Errorf("ntfy: http %d: %s", resp.StatusCode, strings.TrimSpace(string(errBody)))
|
||||
}
|
||||
|
||||
// Drain body to allow connection reuse.
|
||||
@ -143,20 +247,21 @@ func (p *Provider) Send(ctx context.Context, msg push.PushMessage) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolveEndpoint maps a device token to the ntfy publish URL.
|
||||
// resolveTopic maps a device token to the escaped ntfy topic path (without the
|
||||
// base URL), so the same topic can be published to one or many push nodes.
|
||||
//
|
||||
// The token is one of two shapes:
|
||||
//
|
||||
// - A plain ntfy topic (possibly hierarchical, e.g. "ns/myapp/user-1") —
|
||||
// published to "<baseURL>/<topic>", with each path segment escaped so a
|
||||
// crafted token can't break out of the topic path.
|
||||
// each path segment is escaped so a crafted token can't break out of the
|
||||
// topic path.
|
||||
// - A full UnifiedPush endpoint URL handed to the client by the ntfy
|
||||
// distributor (e.g. "https://push.example.com/up<random>"). UnifiedPush
|
||||
// requires the application server to POST to that endpoint verbatim, so we
|
||||
// use it as-is — but ONLY after verifying its scheme+host match the
|
||||
// configured base URL. That check turns a device-supplied token into an
|
||||
// SSRF only against our own push host, never an arbitrary one.
|
||||
func (p *Provider) resolveEndpoint(token string) (string, error) {
|
||||
// requires the application server to POST to that endpoint, so we accept it
|
||||
// — but ONLY after verifying its scheme+host match the configured base URL,
|
||||
// then take only its path as the topic. That turns a device-supplied token
|
||||
// into a publish only against our own push host, never an arbitrary one.
|
||||
func (p *Provider) resolveTopic(token string) (string, error) {
|
||||
topic := token
|
||||
if isAbsoluteHTTPURL(token) {
|
||||
u, err := url.Parse(token)
|
||||
@ -173,10 +278,7 @@ func (p *Provider) resolveEndpoint(token string) (string, error) {
|
||||
return "", fmt.Errorf("ntfy: endpoint host %q does not match configured push host %q", u.Host, base.Host)
|
||||
}
|
||||
// Confine the URL form to the SAME publish surface as a bare topic:
|
||||
// take only the path as the topic and re-build through the per-segment
|
||||
// escaping below, dropping any query/fragment. So a UnifiedPush
|
||||
// endpoint token can publish a topic but can't gain arbitrary path or
|
||||
// query control on the push host beyond what a plain topic already has.
|
||||
// take only the path as the topic, dropping any query/fragment.
|
||||
topic = strings.TrimPrefix(u.Path, "/")
|
||||
if topic == "" {
|
||||
return "", fmt.Errorf("ntfy: endpoint url %q has no topic path", token)
|
||||
@ -188,7 +290,7 @@ func (p *Provider) resolveEndpoint(token string) (string, error) {
|
||||
for i, seg := range parts {
|
||||
parts[i] = url.PathEscape(seg)
|
||||
}
|
||||
return p.baseURL + "/" + strings.Join(parts, "/"), nil
|
||||
return strings.Join(parts, "/"), nil
|
||||
}
|
||||
|
||||
// isAbsoluteHTTPURL reports whether s looks like an absolute http(s) URL (the
|
||||
|
||||
@ -8,6 +8,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -306,3 +307,136 @@ func TestName(t *testing.T) {
|
||||
t.Errorf("expected Name=ntfy, got %s", p.Name())
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Bugboard #858 — cluster fan-out. Each push node runs an independent ntfy with
|
||||
// no shared store, so a publish must reach EVERY node for the subscriber's
|
||||
// instance (round-robin DNS picks one) to receive it.
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// fanoutRecorder collects the publishes received by one fake ntfy node created
// with newFanoutNode.
type fanoutRecorder struct {
	// mu is held while topics is appended to or counted.
	mu sync.Mutex
	// topics holds one entry per received request: the request path with its
	// leading "/" removed, in arrival order.
	topics []string
}
|
||||
|
||||
func newFanoutNode(t *testing.T) (*httptest.Server, *fanoutRecorder) {
|
||||
t.Helper()
|
||||
rec := &fanoutRecorder{}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
rec.mu.Lock()
|
||||
rec.topics = append(rec.topics, strings.TrimPrefix(r.URL.Path, "/"))
|
||||
rec.mu.Unlock()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
return srv, rec
|
||||
}
|
||||
|
||||
func (r *fanoutRecorder) count() int {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
return len(r.topics)
|
||||
}
|
||||
|
||||
func TestSend_fanout_publishesToAllNodes(t *testing.T) {
|
||||
s1, r1 := newFanoutNode(t)
|
||||
defer s1.Close()
|
||||
s2, r2 := newFanoutNode(t)
|
||||
defer s2.Close()
|
||||
s3, r3 := newFanoutNode(t)
|
||||
defer s3.Close()
|
||||
|
||||
p := New(Config{
|
||||
BaseURL: s1.URL, // base URL still required; fan-out targets come from the resolver
|
||||
FanoutResolver: func(context.Context) ([]string, error) {
|
||||
return []string{s1.URL, s2.URL, s3.URL}, nil
|
||||
},
|
||||
}, nil)
|
||||
|
||||
if err := p.Send(context.Background(), push.PushMessage{DeviceToken: "user-1", Body: "hi"}); err != nil {
|
||||
t.Fatalf("Send: %v", err)
|
||||
}
|
||||
for i, r := range []*fanoutRecorder{r1, r2, r3} {
|
||||
if r.count() != 1 {
|
||||
t.Errorf("node %d received %d publishes; want exactly 1 (the publish must reach every node)", i+1, r.count())
|
||||
}
|
||||
if r.count() == 1 && r.topics[0] != "user-1" {
|
||||
t.Errorf("node %d got topic %q; want user-1", i+1, r.topics[0])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSend_fanout_oneNodeDown_stillSucceeds(t *testing.T) {
|
||||
up, rUp := newFanoutNode(t)
|
||||
defer up.Close()
|
||||
down, _ := newFanoutNode(t)
|
||||
down.Close() // unreachable
|
||||
|
||||
p := New(Config{
|
||||
BaseURL: up.URL,
|
||||
FanoutResolver: func(context.Context) ([]string, error) {
|
||||
return []string{up.URL, down.URL}, nil
|
||||
},
|
||||
}, nil)
|
||||
|
||||
// At least one node accepted it → Send succeeds; the message still reached
|
||||
// the reachable instances.
|
||||
if err := p.Send(context.Background(), push.PushMessage{DeviceToken: "t", Body: "x"}); err != nil {
|
||||
t.Fatalf("Send should succeed when at least one node is up; got %v", err)
|
||||
}
|
||||
if rUp.count() != 1 {
|
||||
t.Errorf("the up node should have received the publish; got %d", rUp.count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSend_fanout_allNodesDown_returnsError(t *testing.T) {
|
||||
d1, _ := newFanoutNode(t)
|
||||
d1.Close()
|
||||
d2, _ := newFanoutNode(t)
|
||||
d2.Close()
|
||||
|
||||
p := New(Config{
|
||||
BaseURL: "http://127.0.0.1:1", // unused for posting; just non-empty
|
||||
FanoutResolver: func(context.Context) ([]string, error) {
|
||||
return []string{d1.URL, d2.URL}, nil
|
||||
},
|
||||
}, nil)
|
||||
|
||||
if err := p.Send(context.Background(), push.PushMessage{DeviceToken: "t", Body: "x"}); err == nil {
|
||||
t.Fatal("Send should fail when every node is unreachable")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSend_fanout_resolverEmpty_fallsBackToBaseURL(t *testing.T) {
|
||||
base, rBase := newFanoutNode(t)
|
||||
defer base.Close()
|
||||
|
||||
p := New(Config{
|
||||
BaseURL: base.URL,
|
||||
FanoutResolver: func(context.Context) ([]string, error) { return nil, nil }, // no active nodes
|
||||
}, nil)
|
||||
|
||||
if err := p.Send(context.Background(), push.PushMessage{DeviceToken: "t", Body: "x"}); err != nil {
|
||||
t.Fatalf("Send: %v", err)
|
||||
}
|
||||
if rBase.count() != 1 {
|
||||
t.Errorf("empty resolver must fall back to the base URL; base got %d publishes", rBase.count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSend_fanout_resolverError_fallsBackToBaseURL(t *testing.T) {
|
||||
base, rBase := newFanoutNode(t)
|
||||
defer base.Close()
|
||||
|
||||
p := New(Config{
|
||||
BaseURL: base.URL,
|
||||
FanoutResolver: func(context.Context) ([]string, error) { return nil, context.DeadlineExceeded },
|
||||
}, nil)
|
||||
|
||||
if err := p.Send(context.Background(), push.PushMessage{DeviceToken: "t", Body: "x"}); err != nil {
|
||||
t.Fatalf("resolver error must not fail the push (fall back to base URL); got %v", err)
|
||||
}
|
||||
if rBase.count() != 1 {
|
||||
t.Errorf("resolver error must fall back to the base URL; base got %d publishes", rBase.count())
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user