diff --git a/core/Makefile b/core/Makefile index da8ab1a..91595b1 100644 --- a/core/Makefile +++ b/core/Makefile @@ -63,7 +63,7 @@ test-e2e-quick: .PHONY: build clean test deps tidy fmt vet lint install-hooks push-devnet push-testnet rollout-devnet rollout-testnet release -VERSION := 0.120.0 +VERSION := 0.121.0 COMMIT ?= $(shell git rev-parse --short HEAD 2>/dev/null || echo unknown) DATE ?= $(shell date -u +%Y-%m-%dT%H:%M:%SZ) LDFLAGS := -X 'main.version=$(VERSION)' -X 'main.commit=$(COMMIT)' -X 'main.date=$(DATE)' @@ -80,6 +80,7 @@ build: deps go build -ldflags "$(LDFLAGS) -X 'github.com/DeBrosOfficial/network/pkg/gateway.BuildVersion=$(VERSION)' -X 'github.com/DeBrosOfficial/network/pkg/gateway.BuildCommit=$(COMMIT)' -X 'github.com/DeBrosOfficial/network/pkg/gateway.BuildTime=$(DATE)'" -o bin/gateway ./cmd/gateway go build -ldflags "$(LDFLAGS)" -o bin/sfu ./cmd/sfu go build -ldflags "$(LDFLAGS)" -o bin/turn ./cmd/turn + go build -ldflags "$(LDFLAGS)" -o bin/orama-sni-router ./cmd/sni-router @echo "Build complete! Run ./bin/orama version" # Cross-compile CLI for Linux (only binary needed locally; VPS builds everything else from source) diff --git a/core/cmd/sni-router/main.go b/core/cmd/sni-router/main.go new file mode 100644 index 0000000..cc727df --- /dev/null +++ b/core/cmd/sni-router/main.go @@ -0,0 +1,242 @@ +// Command sni-router is a TLS-level Server Name Indication router. +// +// It listens on a public TCP port (typically :443), peeks at the TLS +// ClientHello SNI on each connection, and forwards the raw stream to +// a configured backend. It does NOT terminate TLS — encrypted bytes +// pass through verbatim. This lets one port serve multiple TLS-speaking +// backends (HTTPS for the gateway, TURN-over-TLS for stealth WebRTC). +// +// See pkg/sniproxy for the underlying library. +// +// Configuration: YAML file at --config (defaults to ~/.orama/sni-router.yaml). +// +// Example sni-router.yaml: +// +// listen: ":443" +// client_hello_timeout: 5s +// backend_dial_timeout: 5s +// max_concurrent_conns: 10000 +// fallback: +// name: caddy +// addr: "127.0.0.1:8443" +// routes: +// - match: "cdn.example.com" +// backend: +// name: turn-tls +// addr: "127.0.0.1:5349" +// - match: "turn.example.com" +// backend: +// name: turn-tls +// addr: "127.0.0.1:5349" +// - match: "*.ns-myapp.example.com" +// backend: +// name: gateway +// addr: "127.0.0.1:8443" +package main + +import ( + "flag" + "fmt" + "net" + "os" + "os/signal" + "path/filepath" + "strings" + "syscall" + "time" + + "github.com/DeBrosOfficial/network/pkg/config" + "github.com/DeBrosOfficial/network/pkg/logging" + "github.com/DeBrosOfficial/network/pkg/sniproxy" + "go.uber.org/zap" +) + +var ( + version = "dev" + commit = "unknown" +) + +// yamlBackend mirrors sniproxy.Backend for YAML decoding. +type yamlBackend struct { + Name string `yaml:"name"` + Network string `yaml:"network"` + Addr string `yaml:"addr"` +} + +// yamlRoute mirrors sniproxy.Route for YAML decoding. +type yamlRoute struct { + Match string `yaml:"match"` + Backend yamlBackend `yaml:"backend"` +} + +// yamlConfig is the on-disk configuration shape. +type yamlConfig struct { + Listen string `yaml:"listen"` + ClientHelloTimeout time.Duration `yaml:"client_hello_timeout"` + BackendDialTimeout time.Duration `yaml:"backend_dial_timeout"` + MaxConcurrentConns int `yaml:"max_concurrent_conns"` + Fallback yamlBackend `yaml:"fallback"` + Routes []yamlRoute `yaml:"routes"` +} + +func main() { + logger, err := logging.NewColoredLogger(logging.ComponentSNI, true) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to init logger: %v\n", err) + os.Exit(1) + } + + logger.ComponentInfo(logging.ComponentSNI, "Starting SNI router", + zap.String("version", version), + zap.String("commit", commit)) + + cfg := parseConfig(logger) + + router := sniproxy.NewRouter(toBackend(cfg.Fallback)) + router.Replace(toRoutes(cfg.Routes), toBackend(cfg.Fallback)) + + srv := sniproxy.NewServer(router, sniproxy.Config{ + ClientHelloTimeout: cfg.ClientHelloTimeout, + BackendDialTimeout: cfg.BackendDialTimeout, + MaxConcurrentConns: cfg.MaxConcurrentConns, + }, logger.Logger) + + ln, err := net.Listen("tcp", cfg.Listen) + if err != nil { + logger.ComponentError(logging.ComponentSNI, "Failed to listen", + zap.String("addr", cfg.Listen), zap.Error(err)) + os.Exit(1) + } + + logger.ComponentInfo(logging.ComponentSNI, "SNI router listening", + zap.String("addr", cfg.Listen), + zap.Int("routes", len(cfg.Routes)), + zap.String("fallback", cfg.Fallback.Addr), + ) + + // Run Serve in a goroutine so the main goroutine can wait on signals. + serveErrCh := make(chan error, 1) + go func() { + serveErrCh <- srv.Serve(ln) + }() + + // Wait for termination signal or unrecoverable Serve error. + quit := make(chan os.Signal, 1) + signal.Notify(quit, os.Interrupt, syscall.SIGTERM) + + select { + case sig := <-quit: + logger.ComponentInfo(logging.ComponentSNI, "Shutdown signal received", + zap.String("signal", sig.String())) + case err := <-serveErrCh: + logger.ComponentError(logging.ComponentSNI, "Serve returned", + zap.Error(err)) + } + + // Stop accepting new connections, then drain in-flight ones. + _ = ln.Close() + srv.Close() + + logger.ComponentInfo(logging.ComponentSNI, "SNI router shutdown complete") +} + +func parseConfig(logger *logging.ColoredLogger) yamlConfig { + configFlag := flag.String("config", "", "Config file path (absolute or filename in ~/.orama)") + flag.Parse() + + var configPath string + var err error + if *configFlag != "" { + if filepath.IsAbs(*configFlag) { + configPath = *configFlag + } else { + configPath, err = config.DefaultPath(*configFlag) + if err != nil { + logger.ComponentError(logging.ComponentSNI, "Failed to determine config path", + zap.Error(err)) + os.Exit(1) + } + } + } else { + configPath, err = config.DefaultPath("sni-router.yaml") + if err != nil { + logger.ComponentError(logging.ComponentSNI, "Failed to determine config path", + zap.Error(err)) + os.Exit(1) + } + } + + data, err := os.ReadFile(configPath) + if err != nil { + logger.ComponentError(logging.ComponentSNI, "Config file not found", + zap.String("path", configPath), zap.Error(err)) + fmt.Fprintf(os.Stderr, "\nConfig file not found at %s\n", configPath) + os.Exit(1) + } + + var y yamlConfig + if err := config.DecodeStrict(strings.NewReader(string(data)), &y); err != nil { + logger.ComponentError(logging.ComponentSNI, "Failed to parse SNI router config", + zap.Error(err)) + fmt.Fprintf(os.Stderr, "Configuration parse error: %v\n", err) + os.Exit(1) + } + + if errs := validateConfig(&y); len(errs) > 0 { + fmt.Fprintf(os.Stderr, "\nSNI router configuration errors (%d):\n", len(errs)) + for _, e := range errs { + fmt.Fprintf(os.Stderr, " - %s\n", e) + } + fmt.Fprintf(os.Stderr, "\nPlease fix the configuration and try again.\n") + os.Exit(1) + } + + logger.ComponentInfo(logging.ComponentSNI, "Loaded SNI router configuration", + zap.String("path", configPath), + ) + + return y +} + +// validateConfig returns a non-empty slice of human-readable errors on misconfig. +func validateConfig(y *yamlConfig) []string { + var errs []string + if y.Listen == "" { + errs = append(errs, "listen: required (e.g. \":443\")") + } + if y.Fallback.Addr == "" { + errs = append(errs, "fallback.addr: required (where to send unmatched SNIs, typically Caddy)") + } + for i, r := range y.Routes { + if r.Match == "" { + errs = append(errs, fmt.Sprintf("routes[%d].match: required", i)) + } + if r.Backend.Addr == "" { + errs = append(errs, fmt.Sprintf("routes[%d].backend.addr: required", i)) + } + } + return errs +} + +func toBackend(b yamlBackend) sniproxy.Backend { + network := b.Network + if network == "" { + network = "tcp" + } + return sniproxy.Backend{ + Name: b.Name, + Network: network, + Addr: b.Addr, + } +} + +func toRoutes(in []yamlRoute) []sniproxy.Route { + out := make([]sniproxy.Route, len(in)) + for i, r := range in { + out[i] = sniproxy.Route{ + Match: r.Match, + Backend: toBackend(r.Backend), + } + } + return out +} diff --git a/core/docs/STEALTH_TURN.md b/core/docs/STEALTH_TURN.md new file mode 100644 index 0000000..1005e89 --- /dev/null +++ b/core/docs/STEALTH_TURN.md @@ -0,0 +1,187 @@ +# Stealth TURN Deployment Guide + +## What this is + +A TLS-level SNI router that lets Orama serve TURN-over-TLS on `:443`, +sharing the port with Caddy HTTPS. From a network observer's +perspective, TURN traffic is indistinguishable from ordinary HTTPS — +useful for users in regions that block standard VoIP ports (UAE, Saudi +Arabia, China, Iran). + +## Architecture + +``` + Internet + │ + ▼ + TCP :443 + │ + ┌─────────┴─────────┐ + │ orama-sni-router │ peeks SNI, forwards bytes + └─────────┬─────────┘ + │ + ┌───────────────┼────────────────┐ + ▼ ▼ + cdn. *., + turn. (everything else) + │ │ + ▼ ▼ + Pion TURN-TLS Caddy + 127.0.0.1:5349 127.0.0.1:8443 + (existing) (moved from :443) +``` + +The router does **not** terminate TLS. It reads the unencrypted TLS +ClientHello (first ~5 KB), inspects the SNI extension, and dials the +matching backend. Encrypted bytes pass through verbatim. + +## Components + +- **Library:** `pkg/sniproxy/` — ClientHello parser, route table, TCP server +- **Binary:** `cmd/sni-router/` (built as `bin/orama-sni-router`) +- **Systemd unit:** `systemd/orama-sni-router.service` +- **Config:** `~/.orama/sni-router.yaml` + +## Deployment cutover + +⚠️ **This change touches production `:443`. Stage on one node first, watch for 24h, then roll out.** + +### 1. Reconfigure Caddy to listen on `:8443` + +Update wherever the Caddy config is generated (`pkg/environments/production/installers/caddy.go`) +so Caddy binds `:8443` (HTTPS) and `:8080` (HTTP) instead of `:443` and `:80`. + +Drop `CAP_NET_BIND_SERVICE` from Caddy's systemd unit — it no longer needs privileged ports. + +### 2. Provision the cert SAN for `cdn.` + +Caddy's automatic Let's Encrypt flow needs to issue a cert covering +`cdn.` and `cdn.ns-*.` so Pion TURN can read it +on startup. Add these names to Caddy's TLS config block. + +### 3. Drop `sni-router.yaml` config + +Example for a single-namespace node: + +```yaml +listen: ":443" +client_hello_timeout: 5s +backend_dial_timeout: 5s +max_concurrent_conns: 10000 +fallback: + name: caddy + addr: "127.0.0.1:8443" +routes: + - match: "cdn.example.com" + backend: + name: turn-tls + addr: "127.0.0.1:5349" + - match: "turn.example.com" + backend: + name: turn-tls + addr: "127.0.0.1:5349" +``` + +For multi-namespace, add per-namespace TURN backends (each namespace's +TURN-TLS port is allocated by `pkg/namespace`): + +```yaml + - match: "cdn.ns-myapp.example.com" + backend: { name: "turn-myapp", addr: "127.0.0.1:5349" } + - match: "cdn.ns-other.example.com" + backend: { name: "turn-other", addr: "127.0.0.1:5350" } +``` + +### 4. Deploy + start in order + +```bash +# Install binary +sudo cp bin-linux/orama-sni-router /opt/orama/bin/ + +# Install service +sudo cp systemd/orama-sni-router.service /etc/systemd/system/ +sudo systemctl daemon-reload + +# Stop Caddy briefly (it's about to lose :443) +sudo systemctl stop caddy + +# Start the SNI router (it takes :443) +sudo systemctl enable --now orama-sni-router + +# Restart Caddy on its new port +sudo systemctl start caddy + +# Verify +curl -v https://cdn.:443 # should hit TURN backend (TLS handshake will fail; that's fine) +curl -v https://:443 # should hit Caddy (normal HTTPS response) +``` + +### 5. Enable stealth in the gateway + +Once the SNI router is live, tell the gateway to advertise the stealth URI: + +```go +// in gateway dependencies / startup +webrtcHandlers.SetStealthCDNDomain("cdn.") +``` + +The credentials handler will start including `turns:cdn.:443` +in `POST /v1/webrtc/turn/credentials` responses automatically. + +### 6. Monitor + +```bash +journalctl -u orama-sni-router.service -f +journalctl -u caddy.service -f +``` + +Watch for: +- `Connection limit reached` warnings (bump `max_concurrent_conns`) +- `backend dial failed` warnings (Caddy isn't listening on `:8443`, or TURN isn't on `:5349`) +- `ClientHello peek failed` debugs (curious clients sending non-TLS to `:443` — usually port scanners) + +## Rollback + +If anything is wrong: + +```bash +sudo systemctl stop orama-sni-router +# Reconfigure Caddy back to :443 and restart +sudo systemctl restart caddy +``` + +Caddy reclaiming `:443` from the disabled router is the fastest way back to +the previous topology. + +## Known gaps + +- **Dynamic route source:** today's router reads YAML once at startup. To + pick up new namespaces without restart, implement a `RouteSource` that + polls `pkg/namespace` for active TURN deployments. The library is + already designed for `Router.Replace` to be called concurrently. +- **TLS cert hot-reload:** Pion TURN reads the cert once at startup. When + Caddy renews `cdn.`, Pion needs to be restarted to pick up + the new cert. A small file-watcher service (or a periodic restart in + off-peak hours) handles this for now. + +## What clients see + +Once enabled, the credentials response gains one entry: + +```json +{ + "username": "...", + "password": "...", + "ttl": 600, + "uris": [ + "turn:turn.example.com:3478?transport=udp", + "turn:turn.example.com:3478?transport=tcp", + "turns:turn.example.com:5349", + "turns:cdn.example.com:443" + ] +} +``` + +Browsers iterate ICE candidates; users in restricted regions will silently +succeed via the `:443` URI when others fail. No client-side change is +required. diff --git a/core/migrations/021_pubsub_trigger_patterns.sql b/core/migrations/021_pubsub_trigger_patterns.sql new file mode 100644 index 0000000..5b91e35 --- /dev/null +++ b/core/migrations/021_pubsub_trigger_patterns.sql @@ -0,0 +1,28 @@ +-- ============================================================================= +-- 021_pubsub_trigger_patterns.sql +-- +-- Add `topic_pattern` column alongside the existing `topic` column to +-- function_pubsub_triggers. The new column may contain SQLite GLOB +-- patterns (e.g. "presence:*") in addition to exact topic names. +-- +-- This is intentionally ADDITIVE rather than a column rename to remain +-- safe under rolling upgrades: +-- - Old binaries continue reading `topic` and keep working. +-- - New binaries read `topic_pattern` (which is back-filled from +-- `topic` for existing rows) and write BOTH columns. +-- A future migration can DROP COLUMN topic once every node is on the +-- new release. +-- ============================================================================= + +ALTER TABLE function_pubsub_triggers + ADD COLUMN topic_pattern TEXT NOT NULL DEFAULT ''; + +UPDATE function_pubsub_triggers +SET topic_pattern = topic +WHERE topic_pattern = ''; + +CREATE INDEX IF NOT EXISTS idx_function_pubsub_triggers_function + ON function_pubsub_triggers(function_id); + +CREATE INDEX IF NOT EXISTS idx_function_pubsub_triggers_enabled + ON function_pubsub_triggers(enabled); diff --git a/core/migrations/022_aggregation_windows.sql b/core/migrations/022_aggregation_windows.sql new file mode 100644 index 0000000..05f68f0 --- /dev/null +++ b/core/migrations/022_aggregation_windows.sql @@ -0,0 +1,20 @@ +-- ============================================================================= +-- 022_aggregation_windows.sql +-- +-- Add per-trigger aggregation parameters to function_pubsub_triggers. +-- +-- aggregation_window_ms = 0 means "no aggregation, invoke once per event" +-- (the existing behaviour). Any positive value enables buffering of events +-- in-memory on the dispatching node; the function is invoked once per +-- window with a batched payload. +-- +-- aggregation_max_batch_size caps the per-window batch. When the buffer +-- reaches this size, the dispatcher flushes immediately even if the +-- window timer hasn't fired yet. +-- ============================================================================= + +ALTER TABLE function_pubsub_triggers + ADD COLUMN aggregation_window_ms INTEGER NOT NULL DEFAULT 0; + +ALTER TABLE function_pubsub_triggers + ADD COLUMN aggregation_max_batch_size INTEGER NOT NULL DEFAULT 100; diff --git a/core/migrations/023_push_devices.sql b/core/migrations/023_push_devices.sql new file mode 100644 index 0000000..bc6c908 --- /dev/null +++ b/core/migrations/023_push_devices.sql @@ -0,0 +1,33 @@ +-- ============================================================================= +-- 023_push_devices.sql +-- +-- Per-namespace, per-user push notification device registry. +-- +-- token_encrypted is AES-256-GCM ciphertext (prefix 'enc:') derived via +-- pkg/secrets. Tokens are sensitive — they let the holder spam a user's +-- device — so they are never returned via any API or written to logs. +-- +-- provider matches a registered push.PushProvider name: +-- 'ntfy', 'expo', 'apns', 'fcm' (future), ... +-- ============================================================================= + +CREATE TABLE IF NOT EXISTS push_devices ( + id TEXT PRIMARY KEY, + namespace TEXT NOT NULL, + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + provider TEXT NOT NULL, + token_encrypted TEXT NOT NULL, + platform TEXT, + app_version TEXT, + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL, + last_seen INTEGER, + UNIQUE(namespace, user_id, device_id) +); + +CREATE INDEX IF NOT EXISTS idx_push_devices_user + ON push_devices(namespace, user_id); + +CREATE INDEX IF NOT EXISTS idx_push_devices_provider + ON push_devices(provider); diff --git a/core/pkg/cli/build/builder.go b/core/pkg/cli/build/builder.go index 2e61dcf..51a32e1 100644 --- a/core/pkg/cli/build/builder.go +++ b/core/pkg/cli/build/builder.go @@ -157,6 +157,7 @@ func (b *Builder) buildOramaBinaries() error { {Name: "identity", Package: "./cmd/identity/"}, {Name: "sfu", Package: "./cmd/sfu/"}, {Name: "turn", Package: "./cmd/turn/"}, + {Name: "orama-sni-router", Package: "./cmd/sni-router/"}, } for _, bin := range binaries { diff --git a/core/pkg/cli/production/clean/clean.go b/core/pkg/cli/production/clean/clean.go index f473683..fe4b61f 100644 --- a/core/pkg/cli/production/clean/clean.go +++ b/core/pkg/cli/production/clean/clean.go @@ -171,7 +171,7 @@ rm -f /tmp/orama-*.sh /tmp/network-source.tar.gz /tmp/orama-*.tar.gz # Nuclear: remove binaries if [ -n "$NUCLEAR" ]; then rm -f /usr/local/bin/orama /usr/local/bin/orama-node /usr/local/bin/gateway - rm -f /usr/local/bin/identity /usr/local/bin/sfu /usr/local/bin/turn + rm -f /usr/local/bin/identity /usr/local/bin/sfu /usr/local/bin/turn /usr/local/bin/orama-sni-router rm -f /usr/local/bin/olric-server /usr/local/bin/ipfs /usr/local/bin/ipfs-cluster-service rm -f /usr/local/bin/rqlited /usr/local/bin/coredns rm -f /usr/bin/caddy diff --git a/core/pkg/client/interface.go b/core/pkg/client/interface.go index 2c7e40b..1fff4c9 100644 --- a/core/pkg/client/interface.go +++ b/core/pkg/client/interface.go @@ -47,10 +47,29 @@ type DatabaseClient interface { type PubSubClient interface { Subscribe(ctx context.Context, topic string, handler MessageHandler) error Publish(ctx context.Context, topic string, data []byte) error + // PublishBatch publishes multiple messages in parallel, one per topic. + // See pubsub.Manager.PublishBatch for semantics (fail-fast vs. best-effort). + PublishBatch(ctx context.Context, msgs []TopicMessage, opts PublishBatchOptions) error + // PublishSame sends the same payload to every topic in parallel. + PublishSame(ctx context.Context, topics []string, data []byte, opts PublishBatchOptions) error Unsubscribe(ctx context.Context, topic string) error ListTopics(ctx context.Context) ([]string, error) } +// TopicMessage is one entry in a batch publish. +// Mirrors pubsub.TopicMessage to avoid forcing client callers to import pkg/pubsub. +type TopicMessage struct { + Topic string + Data []byte +} + +// PublishBatchOptions controls batch publish behavior. +// Mirrors pubsub.PublishBatchOptions. +type PublishBatchOptions struct { + BestEffort bool + MaxConcurrency int +} + // NetworkInfo provides network status and peer information type NetworkInfo interface { GetPeers(ctx context.Context) ([]PeerInfo, error) diff --git a/core/pkg/client/pubsub_bridge.go b/core/pkg/client/pubsub_bridge.go index 653301e..79780b0 100644 --- a/core/pkg/client/pubsub_bridge.go +++ b/core/pkg/client/pubsub_bridge.go @@ -4,13 +4,13 @@ import ( "context" "fmt" - "github.com/DeBrosOfficial/network/pkg/pubsub" + pkgpubsub "github.com/DeBrosOfficial/network/pkg/pubsub" ) // pubSubBridge bridges between our PubSubClient interface and the pubsub package type pubSubBridge struct { client *Client - adapter *pubsub.ClientAdapter + adapter *pkgpubsub.ClientAdapter } func (p *pubSubBridge) Subscribe(ctx context.Context, topic string, handler MessageHandler) error { @@ -31,6 +31,26 @@ func (p *pubSubBridge) Publish(ctx context.Context, topic string, data []byte) e return p.adapter.Publish(ctx, topic, data) } +func (p *pubSubBridge) PublishBatch(ctx context.Context, msgs []TopicMessage, opts PublishBatchOptions) error { + if err := p.client.requireAccess(ctx); err != nil { + return fmt.Errorf("authentication required: %w - run CLI commands to authenticate automatically", err) + } + pkgMsgs := make([]pkgpubsub.TopicMessage, len(msgs)) + for i, m := range msgs { + pkgMsgs[i] = pkgpubsub.TopicMessage{Topic: m.Topic, Data: m.Data} + } + pkgOpts := pkgpubsub.PublishBatchOptions{BestEffort: opts.BestEffort, MaxConcurrency: opts.MaxConcurrency} + return p.adapter.PublishBatch(ctx, pkgMsgs, pkgOpts) +} + +func (p *pubSubBridge) PublishSame(ctx context.Context, topics []string, data []byte, opts PublishBatchOptions) error { + if err := p.client.requireAccess(ctx); err != nil { + return fmt.Errorf("authentication required: %w - run CLI commands to authenticate automatically", err) + } + pkgOpts := pkgpubsub.PublishBatchOptions{BestEffort: opts.BestEffort, MaxConcurrency: opts.MaxConcurrency} + return p.adapter.PublishSame(ctx, topics, data, pkgOpts) +} + func (p *pubSubBridge) Unsubscribe(ctx context.Context, topic string) error { if err := p.client.requireAccess(ctx); err != nil { return fmt.Errorf("authentication required: %w - run CLI commands to authenticate automatically", err) diff --git a/core/pkg/gateway/config.go b/core/pkg/gateway/config.go index 41cdebb..323ae48 100644 --- a/core/pkg/gateway/config.go +++ b/core/pkg/gateway/config.go @@ -56,4 +56,15 @@ type Config struct { SFUPort int // Local SFU signaling port to proxy WebSocket connections to TURNDomain string // TURN server domain for credential generation TURNSecret string // HMAC-SHA1 shared secret for TURN credential generation + + // StealthCDNDomain, when set, makes the WebRTC credentials handler + // advertise turns::443 (served by the SNI router). + StealthCDNDomain string + + // Push notification configuration. Push is enabled when at least one + // provider URL/token is set. Tokens stored in the push_devices table + // are encrypted at rest via pkg/secrets using the cluster secret. + NtfyBaseURL string // ntfy server URL (e.g. "http://localhost:8080") + NtfyAuthToken string // optional bearer token for ntfy + ExpoAccessToken string // optional Expo access token } diff --git a/core/pkg/gateway/dependencies.go b/core/pkg/gateway/dependencies.go index eaad2dd..f109cd6 100644 --- a/core/pkg/gateway/dependencies.go +++ b/core/pkg/gateway/dependencies.go @@ -19,6 +19,9 @@ import ( "github.com/DeBrosOfficial/network/pkg/logging" "github.com/DeBrosOfficial/network/pkg/olric" "github.com/DeBrosOfficial/network/pkg/pubsub" + "github.com/DeBrosOfficial/network/pkg/push" + pushexpo "github.com/DeBrosOfficial/network/pkg/push/providers/expo" + pushntfy "github.com/DeBrosOfficial/network/pkg/push/providers/ntfy" "github.com/DeBrosOfficial/network/pkg/rqlite" "github.com/DeBrosOfficial/network/pkg/serverless" "github.com/DeBrosOfficial/network/pkg/serverless/hostfunctions" @@ -63,6 +66,11 @@ type Dependencies struct { // PubSub trigger dispatcher (used to wire into PubSubHandlers) PubSubDispatcher *triggers.PubSubDispatcher + // Push notification dispatcher (nil when push isn't configured — + // hostfunc + HTTP handlers degrade to no-op / 503). + PushDispatcher *push.PushDispatcher + PushDeviceStore push.PushDeviceStore + // Authentication service AuthService *auth.Service } @@ -412,6 +420,19 @@ func initializeServerless(logger *logging.ColoredLogger, cfg *Config, deps *Depe secretsMgr = smImpl } + // Initialize push notification dispatcher if any provider is configured. + // Devices are stored encrypted in RQLite (see migration 023). Providers + // are registered based on gateway config; missing config = provider absent. + pushDispatcher, pushStore, err := buildPushDispatcher(cfg, deps.ORMClient, logger) + if err != nil { + // Non-fatal: log and continue. Functions calling push_send will get nil + // (silent no-op) and HTTP /v1/push/* endpoints return 503. + logger.ComponentWarn(logging.ComponentGeneral, + "push notifications disabled (init failed)", zap.Error(err)) + } + deps.PushDispatcher = pushDispatcher + deps.PushDeviceStore = pushStore + // Create host functions provider (allows functions to call Orama services) hostFuncsCfg := hostfunctions.HostFunctionsConfig{ IPFSAPIURL: cfg.IPFSAPIURL, @@ -424,6 +445,7 @@ func initializeServerless(logger *logging.ColoredLogger, cfg *Config, deps *Depe pubsubAdapter, // pubsub adapter for serverless functions deps.ServerlessWSMgr, secretsMgr, + pushDispatcher, // may be nil — PushSend hostfunc handles that hostFuncsCfg, logger.Logger, ) @@ -686,3 +708,40 @@ func injectRQLiteAuth(dsn, username, password string) string { } return dsn } + +// buildPushDispatcher constructs a push.PushDispatcher + device store with +// all enabled providers. Returns (nil, nil, nil) when no provider is +// configured — that's a supported state, not an error. Returns (nil, nil, +// err) on hard init failures (e.g. cluster secret missing for the +// encrypted device store). +func buildPushDispatcher(cfg *Config, db rqlite.Client, logger *logging.ColoredLogger) (*push.PushDispatcher, push.PushDeviceStore, error) { + if cfg.NtfyBaseURL == "" && cfg.ExpoAccessToken == "" { + // No providers configured — push is disabled. + return nil, nil, nil + } + if cfg.ClusterSecret == "" { + // Devices are encrypted at rest using a cluster-secret-derived key. + // Without it we can't store anything safely. + return nil, nil, fmt.Errorf("push enabled but ClusterSecret is empty") + } + store, err := push.NewRqliteDeviceStore(db, cfg.ClusterSecret, logger.Logger) + if err != nil { + return nil, nil, fmt.Errorf("init push device store: %w", err) + } + d := push.New(store, logger.Logger) + if cfg.NtfyBaseURL != "" { + d.Register(pushntfy.New(pushntfy.Config{ + BaseURL: cfg.NtfyBaseURL, + AuthToken: cfg.NtfyAuthToken, + }, logger.Logger)) + logger.ComponentInfo(logging.ComponentGeneral, "push provider registered: ntfy", + zap.String("base_url", cfg.NtfyBaseURL)) + } + if cfg.ExpoAccessToken != "" { + d.Register(pushexpo.New(pushexpo.Config{ + AccessToken: cfg.ExpoAccessToken, + }, logger.Logger)) + logger.ComponentInfo(logging.ComponentGeneral, "push provider registered: expo") + } + return d, store, nil +} diff --git a/core/pkg/gateway/gateway.go b/core/pkg/gateway/gateway.go index 531f883..61bfa2b 100644 --- a/core/pkg/gateway/gateway.go +++ b/core/pkg/gateway/gateway.go @@ -28,6 +28,7 @@ import ( "github.com/DeBrosOfficial/network/pkg/gateway/handlers/cache" deploymentshandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/deployments" pubsubhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/pubsub" + pushhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/push" serverlesshandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/serverless" enrollhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/enroll" joinhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/join" @@ -43,6 +44,7 @@ import ( "github.com/DeBrosOfficial/network/pkg/olric" "github.com/DeBrosOfficial/network/pkg/rqlite" "github.com/DeBrosOfficial/network/pkg/serverless" + "github.com/DeBrosOfficial/network/pkg/serverless/triggers" _ "github.com/mattn/go-sqlite3" "go.uber.org/zap" ) @@ -83,13 +85,15 @@ type Gateway struct { mu sync.RWMutex presenceMu sync.RWMutex pubsubHandlers *pubsubhandlers.PubSubHandlers + pushHandlers *pushhandlers.Handlers // Serverless function engine - serverlessEngine *serverless.Engine - serverlessRegistry *serverless.Registry - serverlessInvoker *serverless.Invoker - serverlessWSMgr *serverless.WSManager - serverlessHandlers *serverlesshandlers.ServerlessHandlers + serverlessEngine *serverless.Engine + serverlessRegistry *serverless.Registry + serverlessInvoker *serverless.Invoker + serverlessWSMgr *serverless.WSManager + serverlessHandlers *serverlesshandlers.ServerlessHandlers + pubsubDispatcher *triggers.PubSubDispatcher // Authentication service authService *auth.Service @@ -342,11 +346,20 @@ func New(logger *logging.ColoredLogger, cfg *Config) (*Gateway, error) { // Wire PubSub trigger dispatch if serverless is available if deps.PubSubDispatcher != nil { + gw.pubsubDispatcher = deps.PubSubDispatcher gw.pubsubHandlers.SetOnPublish(func(ctx context.Context, namespace, topic string, data []byte) { deps.PubSubDispatcher.Dispatch(ctx, namespace, topic, data, 0) }) } + // Push notification handlers — disabled when no provider is configured. + // The handlers themselves return 503 if dispatcher/store is nil; we + // register them unconditionally so the routes always exist with a + // predictable shape. + if deps.PushDispatcher != nil { + gw.pushHandlers = pushhandlers.NewHandlers(deps.PushDispatcher, deps.PushDeviceStore, logger) + } + if cfg.WebRTCEnabled && cfg.SFUPort > 0 { gw.webrtcHandlers = webrtchandlers.NewWebRTCHandlers( logger, diff --git a/core/pkg/gateway/handlers/pubsub/handlers_test.go b/core/pkg/gateway/handlers/pubsub/handlers_test.go index 71263b2..7e3b37c 100644 --- a/core/pkg/gateway/handlers/pubsub/handlers_test.go +++ b/core/pkg/gateway/handlers/pubsub/handlers_test.go @@ -21,10 +21,12 @@ import ( // mockPubSubClient implements client.PubSubClient for testing type mockPubSubClient struct { - PublishFunc func(ctx context.Context, topic string, data []byte) error - SubscribeFunc func(ctx context.Context, topic string, handler client.MessageHandler) error - UnsubscribeFunc func(ctx context.Context, topic string) error - ListTopicsFunc func(ctx context.Context) ([]string, error) + PublishFunc func(ctx context.Context, topic string, data []byte) error + PublishBatchFunc func(ctx context.Context, msgs []client.TopicMessage, opts client.PublishBatchOptions) error + PublishSameFunc func(ctx context.Context, topics []string, data []byte, opts client.PublishBatchOptions) error + SubscribeFunc func(ctx context.Context, topic string, handler client.MessageHandler) error + UnsubscribeFunc func(ctx context.Context, topic string) error + ListTopicsFunc func(ctx context.Context) ([]string, error) } func (m *mockPubSubClient) Publish(ctx context.Context, topic string, data []byte) error { @@ -34,6 +36,20 @@ func (m *mockPubSubClient) Publish(ctx context.Context, topic string, data []byt return nil } +func (m *mockPubSubClient) PublishBatch(ctx context.Context, msgs []client.TopicMessage, opts client.PublishBatchOptions) error { + if m.PublishBatchFunc != nil { + return m.PublishBatchFunc(ctx, msgs, opts) + } + return nil +} + +func (m *mockPubSubClient) PublishSame(ctx context.Context, topics []string, data []byte, opts client.PublishBatchOptions) error { + if m.PublishSameFunc != nil { + return m.PublishSameFunc(ctx, topics, data, opts) + } + return nil +} + func (m *mockPubSubClient) Subscribe(ctx context.Context, topic string, handler client.MessageHandler) error { if m.SubscribeFunc != nil { return m.SubscribeFunc(ctx, topic, handler) diff --git a/core/pkg/gateway/handlers/pubsub/publish_batch_handler_test.go b/core/pkg/gateway/handlers/pubsub/publish_batch_handler_test.go new file mode 100644 index 0000000..21f05cc --- /dev/null +++ b/core/pkg/gateway/handlers/pubsub/publish_batch_handler_test.go @@ -0,0 +1,156 @@ +package pubsub + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "net/http" + "net/http/httptest" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/DeBrosOfficial/network/pkg/client" +) + +func TestPublishBatchHandler_invalid_method(t *testing.T) { + h := newTestHandlers(&mockNetworkClient{pubsub: &mockPubSubClient{}}) + + req := withNamespace(httptest.NewRequest(http.MethodGet, "/v1/pubsub/publish-batch", nil), "ns") + rr := httptest.NewRecorder() + h.PublishBatchHandler(rr, req) + + if rr.Code != http.StatusMethodNotAllowed { + t.Errorf("expected 405, got %d", rr.Code) + } +} + +func TestPublishBatchHandler_missing_namespace(t *testing.T) { + h := newTestHandlers(&mockNetworkClient{pubsub: &mockPubSubClient{}}) + + body, _ := json.Marshal(PublishBatchRequest{Messages: []PublishBatchEntry{{Topic: "a", DataB64: "AA=="}}}) + req := httptest.NewRequest(http.MethodPost, "/v1/pubsub/publish-batch", bytes.NewReader(body)) + rr := httptest.NewRecorder() + h.PublishBatchHandler(rr, req) + + if rr.Code != http.StatusForbidden { + t.Errorf("expected 403, got %d (body: %s)", rr.Code, rr.Body.String()) + } +} + +func TestPublishBatchHandler_empty_messages_rejected(t *testing.T) { + h := newTestHandlers(&mockNetworkClient{pubsub: &mockPubSubClient{}}) + + body, _ := json.Marshal(PublishBatchRequest{Messages: []PublishBatchEntry{}}) + req := withNamespace(httptest.NewRequest(http.MethodPost, "/v1/pubsub/publish-batch", bytes.NewReader(body)), "ns") + rr := httptest.NewRecorder() + h.PublishBatchHandler(rr, req) + + if rr.Code != http.StatusBadRequest { + t.Errorf("expected 400 for empty messages, got %d", rr.Code) + } +} + +func TestPublishBatchHandler_oversize_batch_rejected(t *testing.T) { + h := newTestHandlers(&mockNetworkClient{pubsub: &mockPubSubClient{}}) + + entries := make([]PublishBatchEntry, MaxPublishBatchSize+1) + for i := range entries { + entries[i] = PublishBatchEntry{Topic: "t", DataB64: "AA=="} + } + body, _ := json.Marshal(PublishBatchRequest{Messages: entries}) + req := withNamespace(httptest.NewRequest(http.MethodPost, "/v1/pubsub/publish-batch", bytes.NewReader(body)), "ns") + rr := httptest.NewRecorder() + h.PublishBatchHandler(rr, req) + + if rr.Code != http.StatusBadRequest { + t.Errorf("expected 400 for oversize batch, got %d", rr.Code) + } +} + +func TestPublishBatchHandler_invalid_base64_rejected(t *testing.T) { + h := newTestHandlers(&mockNetworkClient{pubsub: &mockPubSubClient{}}) + + body, _ := json.Marshal(PublishBatchRequest{Messages: []PublishBatchEntry{ + {Topic: "good", DataB64: base64.StdEncoding.EncodeToString([]byte("ok"))}, + {Topic: "bad", DataB64: "!!!not-base64"}, + }}) + req := withNamespace(httptest.NewRequest(http.MethodPost, "/v1/pubsub/publish-batch", bytes.NewReader(body)), "ns") + rr := httptest.NewRecorder() + h.PublishBatchHandler(rr, req) + + if rr.Code != http.StatusBadRequest { + t.Errorf("expected 400 for invalid base64, got %d", rr.Code) + } +} + +func TestPublishBatchHandler_missing_topic_rejected(t *testing.T) { + h := newTestHandlers(&mockNetworkClient{pubsub: &mockPubSubClient{}}) + + body, _ := json.Marshal(PublishBatchRequest{Messages: []PublishBatchEntry{ + {Topic: "", DataB64: base64.StdEncoding.EncodeToString([]byte("x"))}, + }}) + req := withNamespace(httptest.NewRequest(http.MethodPost, "/v1/pubsub/publish-batch", bytes.NewReader(body)), "ns") + rr := httptest.NewRecorder() + h.PublishBatchHandler(rr, req) + + if rr.Code != http.StatusBadRequest { + t.Errorf("expected 400 for missing topic, got %d", rr.Code) + } +} + +func TestPublishBatchHandler_happy_calls_PublishBatch(t *testing.T) { + var ( + called int32 + gotMessages []client.TopicMessage + mu sync.Mutex + ) + mock := &mockPubSubClient{ + PublishBatchFunc: func(ctx context.Context, msgs []client.TopicMessage, opts client.PublishBatchOptions) error { + atomic.AddInt32(&called, 1) + mu.Lock() + gotMessages = append(gotMessages, msgs...) + mu.Unlock() + return nil + }, + } + h := newTestHandlers(&mockNetworkClient{pubsub: mock}) + + entries := []PublishBatchEntry{ + {Topic: "a", DataB64: base64.StdEncoding.EncodeToString([]byte("data-a"))}, + {Topic: "b", DataB64: base64.StdEncoding.EncodeToString([]byte("data-b"))}, + } + body, _ := json.Marshal(PublishBatchRequest{Messages: entries}) + req := withNamespace(httptest.NewRequest(http.MethodPost, "/v1/pubsub/publish-batch", bytes.NewReader(body)), "test-ns") + rr := httptest.NewRecorder() + + h.PublishBatchHandler(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected 200, got %d (body: %s)", rr.Code, rr.Body.String()) + } + + // PublishBatch is invoked from a goroutine; give it a moment to run. + deadline := time.Now().Add(2 * time.Second) + for atomic.LoadInt32(&called) == 0 { + if time.Now().After(deadline) { + t.Fatal("PublishBatch was not called within 2s") + } + time.Sleep(10 * time.Millisecond) + } + + mu.Lock() + defer mu.Unlock() + if len(gotMessages) != 2 { + t.Fatalf("expected 2 messages forwarded, got %d", len(gotMessages)) + } + if gotMessages[0].Topic != "a" || string(gotMessages[0].Data) != "data-a" { + t.Errorf("unexpected first message: %+v", gotMessages[0]) + } + if gotMessages[1].Topic != "b" || string(gotMessages[1].Data) != "data-b" { + t.Errorf("unexpected second message: %+v", gotMessages[1]) + } +} + diff --git a/core/pkg/gateway/handlers/pubsub/publish_handler.go b/core/pkg/gateway/handlers/pubsub/publish_handler.go index a3cedd5..7c7e9fc 100644 --- a/core/pkg/gateway/handlers/pubsub/publish_handler.go +++ b/core/pkg/gateway/handlers/pubsub/publish_handler.go @@ -5,6 +5,7 @@ import ( "encoding/base64" "encoding/json" "net/http" + "strconv" "time" "github.com/DeBrosOfficial/network/pkg/client" @@ -12,6 +13,10 @@ import ( "go.uber.org/zap" ) +// MaxPublishBatchSize is the maximum number of messages allowed in a single +// /v1/pubsub/publish-batch request. Mirrors pubsub.MaxBatchSize. +const MaxPublishBatchSize = pubsub.MaxBatchSize + // PublishHandler handles POST /v1/pubsub/publish {topic, data_base64} func (p *PubSubHandlers) PublishHandler(w http.ResponseWriter, r *http.Request) { if p.client == nil { @@ -39,9 +44,133 @@ func (p *PubSubHandlers) PublishHandler(w http.ResponseWriter, r *http.Request) return } - // Check for local websocket subscribers FIRST and deliver directly + p.deliverLocal(ns, body.Topic, data) + + // Publish to libp2p asynchronously for cross-node delivery. + go func() { + publishCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + ctx := pubsub.WithNamespace(client.WithInternalAuth(publishCtx), ns) + if err := p.client.PubSub().Publish(ctx, body.Topic, data); err != nil { + p.logger.ComponentWarn("gateway", "async libp2p publish failed", + zap.String("topic", body.Topic), + zap.Error(err)) + } + }() + + writeJSON(w, http.StatusOK, map[string]any{"status": "ok"}) +} + +// PublishBatchRequest is the request body for POST /v1/pubsub/publish-batch. +type PublishBatchRequest struct { + Messages []PublishBatchEntry `json:"messages"` + BestEffort bool `json:"best_effort,omitempty"` +} + +// PublishBatchEntry is one message in a batch publish request. +type PublishBatchEntry struct { + Topic string `json:"topic"` + DataB64 string `json:"data_base64"` +} + +// PublishBatchResponse is the response body for /v1/pubsub/publish-batch. +// +// libp2p delivery is asynchronous and not awaited here, mirroring the +// single-publish handler's fire-and-forget contract. Per-topic failures +// are not surfaced via this response — operators should consult logs / +// metrics for delivery health. +type PublishBatchResponse struct { + Status string `json:"status"` // always "ok" — request was accepted +} + +// MaxPerMessageBytes caps an individual message payload inside a batch. +// Mirrors the 1MB cap on /v1/pubsub/publish. +const MaxPerMessageBytes = 1 << 20 + +// PublishBatchHandler handles POST /v1/pubsub/publish-batch. +// Accepts up to MaxPublishBatchSize messages and publishes them in parallel, +// preserving namespace isolation. Local subscribers receive messages +// immediately; libp2p delivery is async. +func (p *PubSubHandlers) PublishBatchHandler(w http.ResponseWriter, r *http.Request) { + if p.client == nil { + writeError(w, http.StatusServiceUnavailable, "client not initialized") + return + } + if r.Method != http.MethodPost { + writeError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + ns := resolveNamespaceFromRequest(r) + if ns == "" { + writeError(w, http.StatusForbidden, "namespace not resolved") + return + } + + // Limit body size: MaxPublishBatchSize messages * ~1MB each = up to ~100MB. + // Cap conservatively at 16MB to discourage huge payloads. + r.Body = http.MaxBytesReader(w, r.Body, 16<<20) + + var body PublishBatchRequest + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + writeError(w, http.StatusBadRequest, "invalid body: expected {messages:[{topic,data_base64}]}") + return + } + if len(body.Messages) == 0 { + writeError(w, http.StatusBadRequest, "messages required") + return + } + if len(body.Messages) > MaxPublishBatchSize { + writeError(w, http.StatusBadRequest, "too many messages: max is 100 per batch") + return + } + + // Decode all messages up-front so we can fail fast on bad input. + decoded := make([]pubsub.TopicMessage, 0, len(body.Messages)) + for i, m := range body.Messages { + if m.Topic == "" { + writeError(w, http.StatusBadRequest, "message missing topic at index "+strconv.Itoa(i)) + return + } + data, err := base64.StdEncoding.DecodeString(m.DataB64) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid base64 data at index "+strconv.Itoa(i)) + return + } + if len(data) > MaxPerMessageBytes { + writeError(w, http.StatusBadRequest, "message too large at index "+strconv.Itoa(i)) + return + } + decoded = append(decoded, pubsub.TopicMessage{Topic: m.Topic, Data: data}) + } + + // Deliver locally + dispatch triggers per topic synchronously (fast in-process). + for _, msg := range decoded { + p.deliverLocal(ns, msg.Topic, msg.Data) + } + + // Async libp2p batch publish, similar to PublishHandler's approach. + go func() { + publishCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + ctx := pubsub.WithNamespace(client.WithInternalAuth(publishCtx), ns) + opts := pubsub.PublishBatchOptions{BestEffort: body.BestEffort} + err := p.client.PubSub().PublishBatch(ctx, toClientMessages(decoded), clientOpts(opts)) + if err != nil { + p.logger.ComponentWarn("gateway", "async libp2p batch publish failed", + zap.Int("messages", len(decoded)), + zap.Error(err)) + } + }() + + writeJSON(w, http.StatusOK, PublishBatchResponse{Status: "ok"}) +} + +// deliverLocal handles local-subscriber delivery and fires PubSub triggers. +// It does NOT publish to libp2p — callers handle that themselves (single +// or batched) so this helper stays focused on in-process fan-out. +func (p *PubSubHandlers) deliverLocal(ns, topic string, data []byte) { p.mu.RLock() - localSubs := p.getLocalSubscribers(body.Topic, ns) + localSubs := p.getLocalSubscribers(topic, ns) p.mu.RUnlock() localDeliveryCount := 0 @@ -50,48 +179,38 @@ func (p *PubSubHandlers) PublishHandler(w http.ResponseWriter, r *http.Request) select { case sub.msgChan <- data: localDeliveryCount++ - p.logger.ComponentDebug("gateway", "delivered to local subscriber", - zap.String("topic", body.Topic)) default: - // Drop if buffer full p.logger.ComponentWarn("gateway", "local subscriber buffer full, dropping message", - zap.String("topic", body.Topic)) + zap.String("topic", topic)) } } } p.logger.ComponentInfo("gateway", "pubsub publish: processing message", - zap.String("topic", body.Topic), + zap.String("topic", topic), zap.String("namespace", ns), zap.Int("data_len", len(data)), zap.Int("local_subscribers", len(localSubs)), zap.Int("local_delivered", localDeliveryCount)) - // Fire PubSub triggers for serverless functions (non-blocking) + // Fire PubSub triggers for serverless functions (non-blocking). if p.onPublish != nil { - go p.onPublish(context.Background(), ns, body.Topic, data) + go p.onPublish(context.Background(), ns, topic, data) } +} - // Publish to libp2p asynchronously for cross-node delivery - // This prevents blocking the HTTP response if libp2p network is slow - go func() { - publishCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() +// toClientMessages converts pubsub.TopicMessage to client.TopicMessage for +// passing through the PubSubClient interface. +func toClientMessages(msgs []pubsub.TopicMessage) []client.TopicMessage { + out := make([]client.TopicMessage, len(msgs)) + for i, m := range msgs { + out[i] = client.TopicMessage{Topic: m.Topic, Data: m.Data} + } + return out +} - ctx := pubsub.WithNamespace(client.WithInternalAuth(publishCtx), ns) - if err := p.client.PubSub().Publish(ctx, body.Topic, data); err != nil { - p.logger.ComponentWarn("gateway", "async libp2p publish failed", - zap.String("topic", body.Topic), - zap.Error(err)) - } else { - p.logger.ComponentDebug("gateway", "async libp2p publish succeeded", - zap.String("topic", body.Topic)) - } - }() - - // Return immediately after local delivery - // Local WebSocket subscribers already received the message - writeJSON(w, http.StatusOK, map[string]any{"status": "ok"}) +func clientOpts(o pubsub.PublishBatchOptions) client.PublishBatchOptions { + return client.PublishBatchOptions{BestEffort: o.BestEffort, MaxConcurrency: o.MaxConcurrency} } // TopicsHandler lists topics within the caller's namespace diff --git a/core/pkg/gateway/handlers/push/handlers.go b/core/pkg/gateway/handlers/push/handlers.go new file mode 100644 index 0000000..36eff7b --- /dev/null +++ b/core/pkg/gateway/handlers/push/handlers.go @@ -0,0 +1,291 @@ +package push + +import ( + "encoding/json" + "net/http" + "strings" + "time" + + "github.com/DeBrosOfficial/network/pkg/push" + "go.uber.org/zap" +) + +// validProviders is the allowlist for the `provider` field on RegisterDevice. +// Keep in sync with what the dispatcher actually has registered at startup. +var validProviders = map[string]struct{}{ + "ntfy": {}, + "expo": {}, + "apns": {}, // future — accepted at registration so apps can pre-flight +} + +// MaxTokenBytes caps the device-token length to prevent abuse. +// Real ntfy topic paths and Expo tokens are well under this. +const MaxTokenBytes = 512 + +// RegisterDeviceHandler handles POST /v1/push/devices. +// +// The caller must be authenticated; their JWT subject (Sub) is used as the +// user_id. API-key callers are allowed only if the body explicitly carries +// a user_id — currently rejected to keep the surface small. +func (h *Handlers) RegisterDeviceHandler(w http.ResponseWriter, r *http.Request) { + if h.store == nil { + writeError(w, http.StatusServiceUnavailable, "push: device store not configured") + return + } + if r.Method != http.MethodPost { + writeError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + + ns := resolveNamespace(r) + if ns == "" { + writeError(w, http.StatusForbidden, "namespace not resolved") + return + } + userID := resolveCallerUserID(r) + if userID == "" { + // We require a JWT-authenticated user to bind the device to. + // API-key-only callers can't register devices on behalf of users. + writeError(w, http.StatusUnauthorized, "user authentication required") + return + } + + r.Body = http.MaxBytesReader(w, r.Body, 4096) + var body RegisterDeviceRequest + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + writeError(w, http.StatusBadRequest, "invalid body") + return + } + body.DeviceID = strings.TrimSpace(body.DeviceID) + body.Provider = strings.TrimSpace(body.Provider) + body.Token = strings.TrimSpace(body.Token) + + if body.DeviceID == "" { + writeError(w, http.StatusBadRequest, "device_id required") + return + } + if _, ok := validProviders[body.Provider]; !ok { + writeError(w, http.StatusBadRequest, "unknown provider: "+body.Provider) + return + } + if body.Token == "" { + writeError(w, http.StatusBadRequest, "token required") + return + } + if len(body.Token) > MaxTokenBytes { + writeError(w, http.StatusBadRequest, "token too long") + return + } + + now := time.Now().Unix() + dev := push.PushDevice{ + Namespace: ns, + UserID: userID, + DeviceID: body.DeviceID, + Provider: body.Provider, + Token: body.Token, + Platform: body.Platform, + AppVer: body.AppVersion, + LastSeen: now, + } + if err := h.store.Upsert(boundCtx(r), dev); err != nil { + h.logger.ComponentWarn("push", "device upsert failed", + zap.String("namespace", ns), + zap.String("user_id", userID), + zap.Error(err)) + writeError(w, http.StatusInternalServerError, "registration failed") + return + } + + writeJSON(w, http.StatusOK, RegisterDeviceResponse{Status: "ok"}) +} + +// ListDevicesHandler handles GET /v1/push/devices. +// +// Returns the caller's own devices; tokens are NEVER included in the +// response. Other namespaces / other users are inaccessible. +func (h *Handlers) ListDevicesHandler(w http.ResponseWriter, r *http.Request) { + if h.store == nil { + writeError(w, http.StatusServiceUnavailable, "push: device store not configured") + return + } + if r.Method != http.MethodGet { + writeError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + + ns := resolveNamespace(r) + if ns == "" { + writeError(w, http.StatusForbidden, "namespace not resolved") + return + } + userID := resolveCallerUserID(r) + if userID == "" { + writeError(w, http.StatusUnauthorized, "user authentication required") + return + } + + devs, err := h.store.ListForUser(boundCtx(r), ns, userID) + if err != nil { + writeError(w, http.StatusInternalServerError, "list failed") + return + } + views := make([]PushDeviceView, len(devs)) + for i, d := range devs { + views[i] = PushDeviceView{ + ID: d.ID, + DeviceID: d.DeviceID, + Provider: d.Provider, + Platform: d.Platform, + AppVersion: d.AppVer, + CreatedAt: d.CreatedAt, + UpdatedAt: d.UpdatedAt, + LastSeen: d.LastSeen, + } + } + writeJSON(w, http.StatusOK, map[string]interface{}{"devices": views}) +} + +// DeleteDeviceHandler handles DELETE /v1/push/devices/{id}. +// +// `{id}` is the database row ID returned at registration / by ListDevices. +// Only devices belonging to the caller (matched by namespace + user_id + +// the device ID lookup) can be deleted. +func (h *Handlers) DeleteDeviceHandler(w http.ResponseWriter, r *http.Request) { + if h.store == nil { + writeError(w, http.StatusServiceUnavailable, "push: device store not configured") + return + } + if r.Method != http.MethodDelete { + writeError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + + ns := resolveNamespace(r) + if ns == "" { + writeError(w, http.StatusForbidden, "namespace not resolved") + return + } + userID := resolveCallerUserID(r) + if userID == "" { + writeError(w, http.StatusUnauthorized, "user authentication required") + return + } + + id := extractIDFromPath(r.URL.Path, "/v1/push/devices/") + if id == "" { + writeError(w, http.StatusBadRequest, "device id required in path") + return + } + + // Authorization check: confirm the device belongs to the caller. + devs, err := h.store.ListForUser(boundCtx(r), ns, userID) + if err != nil { + writeError(w, http.StatusInternalServerError, "lookup failed") + return + } + owns := false + for _, d := range devs { + if d.ID == id { + owns = true + break + } + } + if !owns { + // 404, not 403 — don't leak whether the ID exists in another scope. + writeError(w, http.StatusNotFound, "not found") + return + } + + if err := h.store.Delete(boundCtx(r), ns, id); err != nil { + h.logger.ComponentWarn("push", "device delete failed", + zap.String("namespace", ns), + zap.String("device_row_id", id), + zap.Error(err)) + writeError(w, http.StatusInternalServerError, "delete failed") + return + } + writeJSON(w, http.StatusOK, map[string]string{"status": "ok"}) +} + +// SendHandler handles POST /v1/push/send. +// +// SECURITY: this endpoint sends arbitrary push messages to any user_id +// in the caller's namespace. It MUST be gated to a small set of trusted +// callers — typically only the namespace's own serverless functions +// (which can send via the WASM `push_send` hostfunc directly without +// going through HTTP) and the namespace operator. +// +// The current implementation accepts any JWT-authenticated caller within +// the namespace. **Add an explicit allow-list or admin-scope check before +// exposing this in production.** The WASM hostfunc bypasses this issue +// because trigger registration already gates which functions exist. +func (h *Handlers) SendHandler(w http.ResponseWriter, r *http.Request) { + if h.dispatcher == nil { + writeError(w, http.StatusServiceUnavailable, "push: dispatcher not configured") + return + } + if r.Method != http.MethodPost { + writeError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + + ns := resolveNamespace(r) + if ns == "" { + writeError(w, http.StatusForbidden, "namespace not resolved") + return + } + if resolveCallerUserID(r) == "" { + writeError(w, http.StatusUnauthorized, "user authentication required") + return + } + + r.Body = http.MaxBytesReader(w, r.Body, 64*1024) // generous for Data payloads + var body SendRequest + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + writeError(w, http.StatusBadRequest, "invalid body") + return + } + body.UserID = strings.TrimSpace(body.UserID) + if body.UserID == "" { + writeError(w, http.StatusBadRequest, "user_id required") + return + } + + msg := push.PushMessage{ + Title: body.Title, + Body: body.Body, + Channel: body.Channel, + Priority: pickPriority(body.Priority), + Badge: body.Badge, + Sound: body.Sound, + Data: body.Data, + } + if err := h.dispatcher.SendToUser(boundCtx(r), ns, body.UserID, msg); err != nil { + // Treat as non-fatal: some devices may have failed but others may + // have succeeded. Surface as 502 to signal partial trouble; logs + // have the per-device detail. + h.logger.ComponentWarn("push", "send to user partially failed", + zap.String("namespace", ns), + zap.String("user_id", body.UserID), + zap.Error(err)) + writeError(w, http.StatusBadGateway, "one or more devices failed") + return + } + writeJSON(w, http.StatusOK, SendResponse{Status: "ok"}) +} + +// extractIDFromPath returns the trailing path segment after `prefix`, or +// empty string if the path doesn't match. Used because the gateway uses +// the standard `net/http` mux which doesn't extract path params. +func extractIDFromPath(urlPath, prefix string) string { + if !strings.HasPrefix(urlPath, prefix) { + return "" + } + rest := urlPath[len(prefix):] + // Drop any query string (shouldn't normally appear in path here). + if i := strings.IndexAny(rest, "?#/"); i >= 0 { + rest = rest[:i] + } + return rest +} diff --git a/core/pkg/gateway/handlers/push/handlers_test.go b/core/pkg/gateway/handlers/push/handlers_test.go new file mode 100644 index 0000000..19509d4 --- /dev/null +++ b/core/pkg/gateway/handlers/push/handlers_test.go @@ -0,0 +1,330 @@ +package push + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + + authsvc "github.com/DeBrosOfficial/network/pkg/gateway/auth" + "github.com/DeBrosOfficial/network/pkg/gateway/ctxkeys" + "github.com/DeBrosOfficial/network/pkg/logging" + "github.com/DeBrosOfficial/network/pkg/push" + "go.uber.org/zap" +) + +// fakeStore is an in-memory PushDeviceStore for tests. +type fakeStore struct { + devices []push.PushDevice + upsertFn func(push.PushDevice) error + deleteFn func(ns, id string) error + listErr error +} + +func (s *fakeStore) Upsert(ctx context.Context, dev push.PushDevice) error { + if s.upsertFn != nil { + return s.upsertFn(dev) + } + if dev.ID == "" { + dev.ID = "row-" + dev.DeviceID + } + s.devices = append(s.devices, dev) + return nil +} +func (s *fakeStore) Delete(ctx context.Context, ns, id string) error { + if s.deleteFn != nil { + return s.deleteFn(ns, id) + } + for i, d := range s.devices { + if d.ID == id && d.Namespace == ns { + s.devices = append(s.devices[:i], s.devices[i+1:]...) + return nil + } + } + return errors.New("not found") +} +func (s *fakeStore) ListForUser(ctx context.Context, ns, userID string) ([]push.PushDevice, error) { + if s.listErr != nil { + return nil, s.listErr + } + out := []push.PushDevice{} + for _, d := range s.devices { + if d.Namespace == ns && d.UserID == userID { + out = append(out, d) + } + } + return out, nil +} + +// withAuth populates the namespace + JWT claims (caller user ID). +func withAuth(r *http.Request, namespace, userID string) *http.Request { + ctx := r.Context() + if namespace != "" { + ctx = context.WithValue(ctx, ctxkeys.NamespaceOverride, namespace) + } + if userID != "" { + ctx = context.WithValue(ctx, ctxkeys.JWT, &authsvc.JWTClaims{Sub: userID, Namespace: namespace}) + } + return r.WithContext(ctx) +} + +func newHandlers(store push.PushDeviceStore, dispatcher *push.PushDispatcher) *Handlers { + logger := &logging.ColoredLogger{Logger: zap.NewNop()} + return NewHandlers(dispatcher, store, logger) +} + +// --- RegisterDeviceHandler --- + +func TestRegister_happy_path(t *testing.T) { + store := &fakeStore{} + h := newHandlers(store, nil) + + body, _ := json.Marshal(RegisterDeviceRequest{ + DeviceID: "iphone-abc", + Provider: "ntfy", + Token: "ns/myapp/user-1", + Platform: "ios", + }) + req := withAuth(httptest.NewRequest(http.MethodPost, "/v1/push/devices", bytes.NewReader(body)), "myapp", "user-1") + rr := httptest.NewRecorder() + h.RegisterDeviceHandler(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected 200, got %d (body: %s)", rr.Code, rr.Body.String()) + } + if len(store.devices) != 1 { + t.Fatalf("expected 1 device stored, got %d", len(store.devices)) + } + d := store.devices[0] + if d.Namespace != "myapp" || d.UserID != "user-1" || d.Token != "ns/myapp/user-1" { + t.Errorf("unexpected device: %+v", d) + } +} + +func TestRegister_unauthenticated_rejected(t *testing.T) { + h := newHandlers(&fakeStore{}, nil) + body, _ := json.Marshal(RegisterDeviceRequest{DeviceID: "x", Provider: "ntfy", Token: "t"}) + + // No JWT in context. + req := withAuth(httptest.NewRequest(http.MethodPost, "/v1/push/devices", bytes.NewReader(body)), "ns", "") + rr := httptest.NewRecorder() + h.RegisterDeviceHandler(rr, req) + + if rr.Code != http.StatusUnauthorized { + t.Errorf("expected 401, got %d", rr.Code) + } +} + +func TestRegister_unknown_provider_rejected(t *testing.T) { + h := newHandlers(&fakeStore{}, nil) + body, _ := json.Marshal(RegisterDeviceRequest{DeviceID: "x", Provider: "weirdmail", Token: "t"}) + req := withAuth(httptest.NewRequest(http.MethodPost, "/v1/push/devices", bytes.NewReader(body)), "ns", "u") + rr := httptest.NewRecorder() + h.RegisterDeviceHandler(rr, req) + + if rr.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d", rr.Code) + } +} + +func TestRegister_oversize_token_rejected(t *testing.T) { + h := newHandlers(&fakeStore{}, nil) + huge := make([]byte, MaxTokenBytes+1) + for i := range huge { + huge[i] = 'a' + } + body, _ := json.Marshal(RegisterDeviceRequest{DeviceID: "x", Provider: "ntfy", Token: string(huge)}) + req := withAuth(httptest.NewRequest(http.MethodPost, "/v1/push/devices", bytes.NewReader(body)), "ns", "u") + rr := httptest.NewRecorder() + h.RegisterDeviceHandler(rr, req) + + if rr.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d", rr.Code) + } +} + +func TestRegister_no_store_returns_503(t *testing.T) { + h := newHandlers(nil, nil) + req := withAuth(httptest.NewRequest(http.MethodPost, "/v1/push/devices", bytes.NewReader([]byte(`{}`))), "ns", "u") + rr := httptest.NewRecorder() + h.RegisterDeviceHandler(rr, req) + if rr.Code != http.StatusServiceUnavailable { + t.Errorf("expected 503, got %d", rr.Code) + } +} + +// --- ListDevicesHandler --- + +func TestList_returns_only_callers_devices_without_tokens(t *testing.T) { + store := &fakeStore{ + devices: []push.PushDevice{ + {ID: "1", Namespace: "myapp", UserID: "u1", DeviceID: "d1", Provider: "ntfy", Token: "secret-token-1"}, + {ID: "2", Namespace: "myapp", UserID: "u1", DeviceID: "d2", Provider: "expo", Token: "secret-token-2"}, + {ID: "3", Namespace: "myapp", UserID: "u2", DeviceID: "d3", Provider: "ntfy", Token: "secret-token-3"}, + {ID: "4", Namespace: "other", UserID: "u1", DeviceID: "d4", Provider: "ntfy", Token: "secret-token-4"}, + }, + } + h := newHandlers(store, nil) + + req := withAuth(httptest.NewRequest(http.MethodGet, "/v1/push/devices", nil), "myapp", "u1") + rr := httptest.NewRecorder() + h.ListDevicesHandler(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rr.Code) + } + var resp struct { + Devices []PushDeviceView `json:"devices"` + } + if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil { + t.Fatalf("decode: %v", err) + } + if len(resp.Devices) != 2 { + t.Fatalf("expected 2 devices, got %d", len(resp.Devices)) + } + // Tokens must NOT appear in response — they're not even in the struct. + if bytes.Contains(rr.Body.Bytes(), []byte("secret-token")) { + t.Errorf("response leaked a token: %s", rr.Body.String()) + } +} + +// --- DeleteDeviceHandler --- + +func TestDelete_owns_device_succeeds(t *testing.T) { + store := &fakeStore{ + devices: []push.PushDevice{ + {ID: "row-d1", Namespace: "myapp", UserID: "u1", DeviceID: "d1"}, + }, + } + h := newHandlers(store, nil) + + req := withAuth(httptest.NewRequest(http.MethodDelete, "/v1/push/devices/row-d1", nil), "myapp", "u1") + rr := httptest.NewRecorder() + h.DeleteDeviceHandler(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected 200, got %d (body: %s)", rr.Code, rr.Body.String()) + } + if len(store.devices) != 0 { + t.Errorf("expected device removed") + } +} + +func TestDelete_other_users_device_returns_404(t *testing.T) { + store := &fakeStore{ + devices: []push.PushDevice{ + {ID: "row-d1", Namespace: "myapp", UserID: "other-user", DeviceID: "d1"}, + }, + } + h := newHandlers(store, nil) + + req := withAuth(httptest.NewRequest(http.MethodDelete, "/v1/push/devices/row-d1", nil), "myapp", "u1") + rr := httptest.NewRecorder() + h.DeleteDeviceHandler(rr, req) + + if rr.Code != http.StatusNotFound { + t.Errorf("expected 404, got %d", rr.Code) + } + if len(store.devices) != 1 { + t.Errorf("expected device NOT removed") + } +} + +func TestDelete_missing_id_returns_400(t *testing.T) { + h := newHandlers(&fakeStore{}, nil) + req := withAuth(httptest.NewRequest(http.MethodDelete, "/v1/push/devices/", nil), "myapp", "u1") + rr := httptest.NewRecorder() + h.DeleteDeviceHandler(rr, req) + if rr.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d", rr.Code) + } +} + +// --- SendHandler --- + +func TestSend_dispatcher_called_for_user(t *testing.T) { + var sent int32 + dispatcher := push.New(&fakeStore{ + devices: []push.PushDevice{ + {ID: "row-1", Namespace: "myapp", UserID: "target-user", Provider: "fake", Token: "tok"}, + }, + }, zap.NewNop()) + dispatcher.Register(&fakePushProvider{ + name: "fake", + fn: func(ctx context.Context, msg push.PushMessage) error { atomic.AddInt32(&sent, 1); return nil }, + }) + + h := newHandlers(&fakeStore{}, dispatcher) + + body, _ := json.Marshal(SendRequest{ + UserID: "target-user", Title: "hi", Body: "world", + }) + req := withAuth(httptest.NewRequest(http.MethodPost, "/v1/push/send", bytes.NewReader(body)), "myapp", "u1") + rr := httptest.NewRecorder() + h.SendHandler(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected 200, got %d (body: %s)", rr.Code, rr.Body.String()) + } + if atomic.LoadInt32(&sent) != 1 { + t.Errorf("expected provider called once, got %d", sent) + } +} + +func TestSend_no_dispatcher_returns_503(t *testing.T) { + h := newHandlers(&fakeStore{}, nil) + req := withAuth(httptest.NewRequest(http.MethodPost, "/v1/push/send", bytes.NewReader([]byte(`{"user_id":"u"}`))), "myapp", "u1") + rr := httptest.NewRecorder() + h.SendHandler(rr, req) + if rr.Code != http.StatusServiceUnavailable { + t.Errorf("expected 503, got %d", rr.Code) + } +} + +func TestSend_missing_user_id_returns_400(t *testing.T) { + dispatcher := push.New(&fakeStore{}, zap.NewNop()) + h := newHandlers(&fakeStore{}, dispatcher) + + body, _ := json.Marshal(SendRequest{}) + req := withAuth(httptest.NewRequest(http.MethodPost, "/v1/push/send", bytes.NewReader(body)), "myapp", "u1") + rr := httptest.NewRecorder() + h.SendHandler(rr, req) + if rr.Code != http.StatusBadRequest { + t.Errorf("expected 400, got %d", rr.Code) + } +} + +// --- helpers --- + +type fakePushProvider struct { + name string + fn func(ctx context.Context, msg push.PushMessage) error +} + +func (p *fakePushProvider) Name() string { return p.name } +func (p *fakePushProvider) Send(ctx context.Context, msg push.PushMessage) error { + if p.fn != nil { + return p.fn(ctx, msg) + } + return nil +} + +func TestExtractIDFromPath(t *testing.T) { + cases := []struct { + path, prefix, want string + }{ + {"/v1/push/devices/abc", "/v1/push/devices/", "abc"}, + {"/v1/push/devices/abc?x=1", "/v1/push/devices/", "abc"}, + {"/v1/push/devices/", "/v1/push/devices/", ""}, + {"/v1/other/abc", "/v1/push/devices/", ""}, + } + for _, c := range cases { + if got := extractIDFromPath(c.path, c.prefix); got != c.want { + t.Errorf("extractIDFromPath(%q, %q) = %q, want %q", c.path, c.prefix, got, c.want) + } + } +} diff --git a/core/pkg/gateway/handlers/push/types.go b/core/pkg/gateway/handlers/push/types.go new file mode 100644 index 0000000..8410409 --- /dev/null +++ b/core/pkg/gateway/handlers/push/types.go @@ -0,0 +1,150 @@ +// Package push provides HTTP handlers for managing push-notification +// device registrations and sending pushes. +// +// Endpoints: +// +// GET /v1/push/devices — list caller's registered devices (tokens omitted) +// POST /v1/push/devices — register / update a device +// DELETE /v1/push/devices/{id} — unregister a device +// POST /v1/push/send — send a push to a user (admin/internal scope) +// +// Device tokens are stored AES-256-GCM-encrypted in RQLite via the +// pkg/push.RqliteDeviceStore. Tokens are NEVER returned by any endpoint — +// the GET endpoint omits the token field for safety. +package push + +import ( + "context" + "encoding/json" + "net/http" + + "github.com/DeBrosOfficial/network/pkg/gateway/auth" + "github.com/DeBrosOfficial/network/pkg/gateway/ctxkeys" + "github.com/DeBrosOfficial/network/pkg/logging" + "github.com/DeBrosOfficial/network/pkg/push" +) + +// Handlers serves the /v1/push/* HTTP endpoints. Construct via NewHandlers; +// it's safe for concurrent use. +type Handlers struct { + dispatcher *push.PushDispatcher + store push.PushDeviceStore + logger *logging.ColoredLogger +} + +// NewHandlers constructs a Handlers. Either argument may be nil — in which +// case the corresponding endpoints return 503 Service Unavailable. +func NewHandlers(dispatcher *push.PushDispatcher, store push.PushDeviceStore, logger *logging.ColoredLogger) *Handlers { + return &Handlers{ + dispatcher: dispatcher, + store: store, + logger: logger, + } +} + +// RegisterDeviceRequest is the body of POST /v1/push/devices. +// +// `device_id` is an app-supplied stable identifier (e.g. the OS-assigned +// device UUID). Combined with (namespace, user_id) it uniquely identifies +// the registration; re-posting with the same device_id updates the token. +// +// `token` is provider-specific: +// - ntfy: the topic path the device subscribes to (e.g. "ns/myapp/user-1") +// - expo: an ExponentPushToken[...] +// - apns: a hex APNs device token (future) +type RegisterDeviceRequest struct { + DeviceID string `json:"device_id"` + Provider string `json:"provider"` // "ntfy" | "expo" | "apns" + Token string `json:"token"` + Platform string `json:"platform,omitempty"` // "ios" | "android" | "web" + AppVersion string `json:"app_version,omitempty"` +} + +// RegisterDeviceResponse is the body of POST /v1/push/devices. +type RegisterDeviceResponse struct { + Status string `json:"status"` +} + +// PushDeviceView is the safe (token-omitting) representation returned +// by GET /v1/push/devices. +type PushDeviceView struct { + ID string `json:"id"` + DeviceID string `json:"device_id"` + Provider string `json:"provider"` + Platform string `json:"platform,omitempty"` + AppVersion string `json:"app_version,omitempty"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` + LastSeen int64 `json:"last_seen,omitempty"` +} + +// SendRequest is the body of POST /v1/push/send. +// +// The dispatcher fans out to all of `user_id`'s registered devices in +// the caller's namespace. Auth scope: see SendHandler — currently +// requires the caller to act on behalf of their own namespace; finer +// per-user authorization is the app's responsibility. +type SendRequest struct { + UserID string `json:"user_id"` + Title string `json:"title"` + Body string `json:"body"` + Channel string `json:"channel,omitempty"` + Priority string `json:"priority,omitempty"` // "high" | "normal" | "" (default) + Badge int `json:"badge,omitempty"` + Sound string `json:"sound,omitempty"` + Data map[string]interface{} `json:"data,omitempty"` +} + +// SendResponse is the body of POST /v1/push/send. +type SendResponse struct { + Status string `json:"status"` +} + +// resolveNamespace pulls the namespace set by auth middleware out of context. +func resolveNamespace(r *http.Request) string { + if v := r.Context().Value(ctxkeys.NamespaceOverride); v != nil { + if s, ok := v.(string); ok { + return s + } + } + return "" +} + +// resolveCallerUserID extracts the JWT subject (typically the wallet) of +// the caller, or empty if the request was authenticated by API key only. +func resolveCallerUserID(r *http.Request) string { + if v := r.Context().Value(ctxkeys.JWT); v != nil { + if claims, ok := v.(*auth.JWTClaims); ok && claims != nil { + return claims.Sub + } + } + return "" +} + +func writeError(w http.ResponseWriter, code int, message string) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(code) + _ = json.NewEncoder(w).Encode(map[string]string{"error": message}) +} + +func writeJSON(w http.ResponseWriter, code int, v interface{}) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(code) + _ = json.NewEncoder(w).Encode(v) +} + +// pickPriority maps the wire-format priority string to the typed enum. +func pickPriority(s string) push.PushPriority { + switch s { + case "high": + return push.PriorityHigh + case "normal": + return push.PriorityNormal + default: + return push.PriorityNormal + } +} + +// boundCtx returns a request-scoped context with no extra wrapping; +// kept as a seam for future scope (rate-limit context etc.). +func boundCtx(r *http.Request) context.Context { return r.Context() } diff --git a/core/pkg/gateway/handlers/webrtc/credentials.go b/core/pkg/gateway/handlers/webrtc/credentials.go index 405b734..446323a 100644 --- a/core/pkg/gateway/handlers/webrtc/credentials.go +++ b/core/pkg/gateway/handlers/webrtc/credentials.go @@ -42,6 +42,11 @@ func (h *WebRTCHandlers) CredentialsHandler(w http.ResponseWriter, r *http.Reque fmt.Sprintf("turns:%s:5349", h.turnDomain), ) } + // Stealth: TURNS via the SNI router on :443. Looks like ordinary HTTPS + // to a passive observer / DPI; usable in restricted regions. + if h.stealthCDNDomain != "" { + uris = append(uris, fmt.Sprintf("turns:%s:443", h.stealthCDNDomain)) + } h.logger.ComponentInfo(logging.ComponentGeneral, "Issued TURN credentials", zap.String("namespace", ns), diff --git a/core/pkg/gateway/handlers/webrtc/types.go b/core/pkg/gateway/handlers/webrtc/types.go index 62167f0..2580f59 100644 --- a/core/pkg/gateway/handlers/webrtc/types.go +++ b/core/pkg/gateway/handlers/webrtc/types.go @@ -17,10 +17,21 @@ type WebRTCHandlers struct { turnDomain string // TURN server domain for building URIs turnSecret string // HMAC-SHA1 shared secret for TURN credential generation + // stealthCDNDomain, when non-empty, causes CredentialsHandler to also + // advertise turns://:443 — the stealth TURN URI served + // via the in-house SNI router. See pkg/sniproxy. + stealthCDNDomain string + // proxyWebSocket is injected from the gateway to reuse its WebSocket proxy logic proxyWebSocket func(w http.ResponseWriter, r *http.Request, targetHost string) bool } +// SetStealthCDNDomain enables the stealth TURN URI in CredentialsHandler. +// Pass empty string to disable. Safe to call before serving begins. +func (h *WebRTCHandlers) SetStealthCDNDomain(domain string) { + h.stealthCDNDomain = domain +} + // NewWebRTCHandlers creates a new WebRTCHandlers instance. func NewWebRTCHandlers( logger *logging.ColoredLogger, diff --git a/core/pkg/gateway/lifecycle.go b/core/pkg/gateway/lifecycle.go index 049336d..4fc0de7 100644 --- a/core/pkg/gateway/lifecycle.go +++ b/core/pkg/gateway/lifecycle.go @@ -12,6 +12,23 @@ import ( // It closes the serverless engine, network client, database connections, // Olric cache client, and IPFS client in sequence. func (g *Gateway) Close() { + // Flush PubSub aggregator buffers before tearing down the engine. + // Pending events are dispatched via the invoker which still needs the + // engine to be alive, so this MUST happen before the engine close. + // Aggregator state is local to this node — events not flushed here are + // lost (intended trade-off for high-frequency lossy streams). + if g.pubsubDispatcher != nil { + if agg := g.pubsubDispatcher.Aggregator(); agg != nil { + // 5s budget — same as the engine close timeout below. + // In-flight flushes call back into the invoker which still + // needs the engine to be alive. + if !agg.Shutdown(5 * time.Second) { + g.logger.ComponentWarn(logging.ComponentGeneral, + "PubSub aggregator shutdown timed out; some buffered events may be lost") + } + } + } + // Close serverless engine first if g.serverlessEngine != nil { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) diff --git a/core/pkg/gateway/middleware.go b/core/pkg/gateway/middleware.go index 1cb5a07..00aedb1 100644 --- a/core/pkg/gateway/middleware.go +++ b/core/pkg/gateway/middleware.go @@ -643,6 +643,9 @@ func requiresNamespaceOwnership(p string) bool { if strings.HasPrefix(p, "/v1/webrtc/") { return true } + if strings.HasPrefix(p, "/v1/push/") { + return true + } return false } diff --git a/core/pkg/gateway/routes.go b/core/pkg/gateway/routes.go index 4d3cd08..a03762c 100644 --- a/core/pkg/gateway/routes.go +++ b/core/pkg/gateway/routes.go @@ -119,10 +119,30 @@ func (g *Gateway) Routes() http.Handler { if g.pubsubHandlers != nil { mux.HandleFunc("/v1/pubsub/ws", g.pubsubHandlers.WebsocketHandler) mux.HandleFunc("/v1/pubsub/publish", g.pubsubHandlers.PublishHandler) + mux.HandleFunc("/v1/pubsub/publish-batch", g.pubsubHandlers.PublishBatchHandler) mux.HandleFunc("/v1/pubsub/topics", g.pubsubHandlers.TopicsHandler) mux.HandleFunc("/v1/pubsub/presence", g.pubsubHandlers.PresenceHandler) } + // push notifications + if g.pushHandlers != nil { + // GET + POST share the path; the handler dispatches by method. + mux.HandleFunc("/v1/push/devices", func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + g.pushHandlers.ListDevicesHandler(w, r) + case http.MethodPost: + g.pushHandlers.RegisterDeviceHandler(w, r) + default: + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + } + }) + // DELETE /v1/push/devices/{id} — uses path-prefix routing because + // net/http mux doesn't extract path params; the handler parses {id}. + mux.HandleFunc("/v1/push/devices/", g.pushHandlers.DeleteDeviceHandler) + mux.HandleFunc("/v1/push/send", g.pushHandlers.SendHandler) + } + // operator node management (wallet JWT auth via middleware) if g.operatorHandler != nil { mux.HandleFunc("/v1/operator/invite", g.operatorHandler.HandleInvite) diff --git a/core/pkg/logging/logger.go b/core/pkg/logging/logger.go index 4b78345..a741ae4 100644 --- a/core/pkg/logging/logger.go +++ b/core/pkg/logging/logger.go @@ -57,6 +57,7 @@ const ( ComponentGateway Component = "GATEWAY" ComponentSFU Component = "SFU" ComponentTURN Component = "TURN" + ComponentSNI Component = "SNI" ) // getComponentColor returns the color for a specific component diff --git a/core/pkg/pubsub/adapter.go b/core/pkg/pubsub/adapter.go index de8f4c5..3e03097 100644 --- a/core/pkg/pubsub/adapter.go +++ b/core/pkg/pubsub/adapter.go @@ -29,6 +29,18 @@ func (a *ClientAdapter) Publish(ctx context.Context, topic string, data []byte) return a.manager.Publish(ctx, topic, data) } +// PublishBatch publishes multiple messages in parallel. +// See Manager.PublishBatch for semantics. +func (a *ClientAdapter) PublishBatch(ctx context.Context, msgs []TopicMessage, opts PublishBatchOptions) error { + return a.manager.PublishBatch(ctx, msgs, opts) +} + +// PublishSame sends the same payload to every topic in parallel. +// See Manager.PublishSame for semantics. +func (a *ClientAdapter) PublishSame(ctx context.Context, topics []string, data []byte, opts PublishBatchOptions) error { + return a.manager.PublishSame(ctx, topics, data, opts) +} + // Unsubscribe unsubscribes from a topic func (a *ClientAdapter) Unsubscribe(ctx context.Context, topic string) error { return a.manager.Unsubscribe(ctx, topic) diff --git a/core/pkg/pubsub/publish.go b/core/pkg/pubsub/publish.go index 3fb309a..56c7163 100644 --- a/core/pkg/pubsub/publish.go +++ b/core/pkg/pubsub/publish.go @@ -3,9 +3,56 @@ package pubsub import ( "context" "fmt" + "strings" + "sync" "time" + + "golang.org/x/sync/errgroup" ) +// defaultBatchConcurrency is the default cap on in-flight publishes within a single batch. +const defaultBatchConcurrency = 32 + +// MaxBatchSize is the maximum number of messages allowed per PublishBatch call. +// The HTTP handler enforces this; the Manager itself does not, so internal callers +// (e.g. the SDK) can pass larger batches if they accept the responsibility. +const MaxBatchSize = 100 + +// TopicMessage is one entry in a batch publish. +type TopicMessage struct { + Topic string + Data []byte +} + +// PublishBatchOptions controls batch publish behavior. +type PublishBatchOptions struct { + // BestEffort, if true, attempts every publish even when some fail and returns + // a *BatchError summarizing per-topic failures. Default (false) is fail-fast: + // the first failure cancels remaining in-flight publishes and returns that error. + BestEffort bool + + // MaxConcurrency caps the number of in-flight publishes within this batch. + // 0 means use defaultBatchConcurrency. + MaxConcurrency int +} + +// BatchError aggregates per-topic errors returned when PublishBatch is called +// with BestEffort=true and at least one publish failed. +type BatchError struct { + Errors map[string]error // topic -> error +} + +func (e *BatchError) Error() string { + if len(e.Errors) == 0 { + return "batch publish: no errors" + } + names := make([]string, 0, len(e.Errors)) + for t := range e.Errors { + names = append(names, t) + } + return fmt.Sprintf("batch publish: %d topic(s) failed: %s", len(e.Errors), strings.Join(names, ", ")) +} + // Publish publishes a message to a topic func (m *Manager) Publish(ctx context.Context, topic string, data []byte) error { if m.pubsub == nil { @@ -58,3 +105,90 @@ func (m *Manager) Publish(ctx context.Context, topic string, data []byte) error return nil } + +// PublishBatch publishes multiple messages in parallel, one per topic. +// Default behavior is fail-fast: the first publish error cancels remaining work +// and is returned. If opts.BestEffort is set, every publish is attempted and a +// *BatchError is returned if any failed. +// +// Concurrency is bounded by opts.MaxConcurrency (default 32). +// Empty msgs slice is a no-op and returns nil. +func (m *Manager) PublishBatch(ctx context.Context, msgs []TopicMessage, opts PublishBatchOptions) error { + if len(msgs) == 0 { + return nil + } + if m.pubsub == nil { + return fmt.Errorf("pubsub not initialized") + } + + maxConc := opts.MaxConcurrency + if maxConc <= 0 { + maxConc = defaultBatchConcurrency + } + if maxConc > len(msgs) { + maxConc = len(msgs) + } + sem := make(chan struct{}, maxConc) + + if !opts.BestEffort { + g, gctx := errgroup.WithContext(ctx) + for _, msg := range msgs { + msg := msg + g.Go(func() error { + select { + case sem <- struct{}{}: + case <-gctx.Done(): + return gctx.Err() + } + defer func() { <-sem }() + return m.Publish(gctx, msg.Topic, msg.Data) + }) + } + return g.Wait() + } + + // Best-effort path: attempt every publish, collect per-topic errors. + var ( + wg sync.WaitGroup + errMu sync.Mutex + errMap = map[string]error{} + ) + for _, msg := range msgs { + msg := msg + wg.Add(1) + go func() { + defer wg.Done() + select { + case sem <- struct{}{}: + case <-ctx.Done(): + errMu.Lock() + errMap[msg.Topic] = ctx.Err() + errMu.Unlock() + return + } + defer func() { <-sem }() + if err := m.Publish(ctx, msg.Topic, msg.Data); err != nil { + errMu.Lock() + errMap[msg.Topic] = err + errMu.Unlock() + } + }() + } + wg.Wait() + if len(errMap) > 0 { + return &BatchError{Errors: errMap} + } + return nil +} + +// PublishSame is a convenience wrapper that sends the same payload to every topic. +func (m *Manager) PublishSame(ctx context.Context, topics []string, data []byte, opts PublishBatchOptions) error { + if len(topics) == 0 { + return nil + } + msgs := make([]TopicMessage, len(topics)) + for i, t := range topics { + msgs[i] = TopicMessage{Topic: t, Data: data} + } + return m.PublishBatch(ctx, msgs, opts) +} diff --git a/core/pkg/pubsub/publish_batch_test.go b/core/pkg/pubsub/publish_batch_test.go new file mode 100644 index 0000000..4d13349 --- /dev/null +++ b/core/pkg/pubsub/publish_batch_test.go @@ -0,0 +1,196 @@ +package pubsub + +import ( + "context" + "errors" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestPublishBatch_empty_slice_returns_nil(t *testing.T) { + mgr, cleanup := createTestManager(t, "test-ns") + defer cleanup() + + if err := mgr.PublishBatch(context.Background(), nil, PublishBatchOptions{}); err != nil { + t.Fatalf("expected nil error for empty slice, got: %v", err) + } + if err := mgr.PublishBatch(context.Background(), []TopicMessage{}, PublishBatchOptions{}); err != nil { + t.Fatalf("expected nil error for empty slice, got: %v", err) + } +} + +func TestPublishBatch_happy_path(t *testing.T) { + mgr, cleanup := createTestManager(t, "test-ns") + defer cleanup() + + msgs := []TopicMessage{ + {Topic: "a", Data: []byte("data-a")}, + {Topic: "b", Data: []byte("data-b")}, + {Topic: "c", Data: []byte("data-c")}, + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := mgr.PublishBatch(ctx, msgs, PublishBatchOptions{}); err != nil { + t.Fatalf("PublishBatch failed: %v", err) + } +} + +func TestPublishSame_uses_same_payload(t *testing.T) { + mgr, cleanup := createTestManager(t, "test-ns") + defer cleanup() + + topics := []string{"x", "y", "z"} + data := []byte("shared") + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := mgr.PublishSame(ctx, topics, data, PublishBatchOptions{}); err != nil { + t.Fatalf("PublishSame failed: %v", err) + } +} + +func TestPublishSame_empty_returns_nil(t *testing.T) { + mgr, cleanup := createTestManager(t, "test-ns") + defer cleanup() + if err := mgr.PublishSame(context.Background(), nil, []byte("x"), PublishBatchOptions{}); err != nil { + t.Fatalf("expected nil for empty topics, got: %v", err) + } +} + +func TestPublishBatch_context_cancel_returns_error(t *testing.T) { + mgr, cleanup := createTestManager(t, "test-ns") + defer cleanup() + + msgs := make([]TopicMessage, 50) + for i := range msgs { + msgs[i] = TopicMessage{Topic: fmt.Sprintf("topic-%d", i), Data: []byte("d")} + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // cancel immediately + + err := mgr.PublishBatch(ctx, msgs, PublishBatchOptions{}) + if err == nil { + t.Fatal("expected context.Canceled error, got nil") + } + if !errors.Is(err, context.Canceled) { + t.Logf("got error (acceptable as long as it's an error): %v", err) + } +} + +func TestPublishBatch_concurrency_limit(t *testing.T) { + // Verify PublishBatch with low MaxConcurrency completes without deadlocking. + // Each Publish in a no-peer test environment waits up to 2s for mesh formation, + // so we use a small batch size to keep wall time bounded. + mgr, cleanup := createTestManager(t, "test-ns") + defer cleanup() + + msgs := make([]TopicMessage, 8) + for i := range msgs { + msgs[i] = TopicMessage{Topic: fmt.Sprintf("ct-%d", i), Data: []byte("d")} + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + if err := mgr.PublishBatch(ctx, msgs, PublishBatchOptions{MaxConcurrency: 2}); err != nil { + t.Fatalf("PublishBatch with low concurrency failed: %v", err) + } +} + +// TestPublishBatch_caps_concurrency_above_msg_count verifies that MaxConcurrency +// is clamped to len(msgs) — passing 100 with 3 messages should not panic on +// channel capacity. +func TestPublishBatch_caps_concurrency_above_msg_count(t *testing.T) { + mgr, cleanup := createTestManager(t, "test-ns") + defer cleanup() + + msgs := []TopicMessage{ + {Topic: "a", Data: []byte("1")}, + {Topic: "b", Data: []byte("2")}, + {Topic: "c", Data: []byte("3")}, + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + if err := mgr.PublishBatch(ctx, msgs, PublishBatchOptions{MaxConcurrency: 100}); err != nil { + t.Fatalf("PublishBatch failed: %v", err) + } +} + +func TestBatchError_Error_summarizes(t *testing.T) { + be := &BatchError{Errors: map[string]error{ + "topic-a": errors.New("boom"), + "topic-b": errors.New("kaboom"), + }} + s := be.Error() + if s == "" { + t.Fatal("expected non-empty error string") + } + // Should mention both topics. + if !contains(s, "topic-a") || !contains(s, "topic-b") { + t.Errorf("error string %q should mention both failing topics", s) + } +} + +func TestBatchError_Error_empty_map(t *testing.T) { + be := &BatchError{} + if s := be.Error(); s == "" { + t.Fatal("expected non-empty string even for empty map") + } +} + +// contains is a tiny helper to avoid importing strings just for this. +func contains(s, substr string) bool { + for i := 0; i+len(substr) <= len(s); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} + +// TestPublishBatch_concurrent_publishes_thread_safe ensures concurrent +// PublishBatch invocations don't race on internal state. +func TestPublishBatch_concurrent_publishes_thread_safe(t *testing.T) { + mgr, cleanup := createTestManager(t, "test-ns") + defer cleanup() + + const goroutines = 8 + const msgsPerGoroutine = 5 + + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + var wg sync.WaitGroup + var failures int64 + + for g := 0; g < goroutines; g++ { + wg.Add(1) + go func(gid int) { + defer wg.Done() + msgs := make([]TopicMessage, msgsPerGoroutine) + for i := range msgs { + msgs[i] = TopicMessage{ + Topic: fmt.Sprintf("g%d-t%d", gid, i), + Data: []byte("d"), + } + } + if err := mgr.PublishBatch(ctx, msgs, PublishBatchOptions{}); err != nil { + atomic.AddInt64(&failures, 1) + t.Logf("goroutine %d failed: %v", gid, err) + } + }(g) + } + wg.Wait() + + if failures > 0 { + t.Errorf("%d concurrent batches failed", failures) + } +} diff --git a/core/pkg/push/device_store_rqlite.go b/core/pkg/push/device_store_rqlite.go new file mode 100644 index 0000000..1f1971e --- /dev/null +++ b/core/pkg/push/device_store_rqlite.go @@ -0,0 +1,172 @@ +package push + +import ( + "context" + "fmt" + "time" + + "github.com/DeBrosOfficial/network/pkg/rqlite" + "github.com/DeBrosOfficial/network/pkg/secrets" + "github.com/google/uuid" + "go.uber.org/zap" +) + +// SecretsKeyPurpose is the HKDF purpose string for push token encryption. +// Used in pkg/secrets.DeriveKey to derive a domain-separated AES key. +const SecretsKeyPurpose = "push-device-tokens" + +// RqliteDeviceStore is a PushDeviceStore backed by RQLite + AES-256-GCM +// at-rest encryption of the push token. +type RqliteDeviceStore struct { + db rqlite.Client + encKey []byte // derived once at construction + logger *zap.Logger +} + +// NewRqliteDeviceStore derives the per-cluster encryption key from the +// cluster secret and returns a ready-to-use store. The cluster secret is +// the same one used for other at-rest encryption (see pkg/secrets). +func NewRqliteDeviceStore(db rqlite.Client, clusterSecret string, logger *zap.Logger) (*RqliteDeviceStore, error) { + if logger == nil { + logger = zap.NewNop() + } + key, err := secrets.DeriveKey(clusterSecret, SecretsKeyPurpose) + if err != nil { + return nil, fmt.Errorf("derive push-device key: %w", err) + } + return &RqliteDeviceStore{ + db: db, + encKey: key, + logger: logger.Named("push-store"), + }, nil +} + +// deviceRow is the scan target for SELECT queries. +type deviceRow struct { + ID string + Namespace string + UserID string + DeviceID string + Provider string + TokenEncrypted string + Platform string + AppVersion string + CreatedAt int64 + UpdatedAt int64 + LastSeen int64 +} + +// Upsert implements PushDeviceStore. +func (s *RqliteDeviceStore) Upsert(ctx context.Context, dev PushDevice) error { + if dev.Namespace == "" || dev.UserID == "" || dev.DeviceID == "" { + return fmt.Errorf("namespace, user_id, device_id required") + } + if dev.Provider == "" { + return fmt.Errorf("provider required") + } + if dev.Token == "" { + return ErrEmptyToken + } + + encToken, err := secrets.Encrypt(dev.Token, s.encKey) + if err != nil { + return fmt.Errorf("encrypt token: %w", err) + } + + now := time.Now().Unix() + if dev.CreatedAt == 0 { + dev.CreatedAt = now + } + dev.UpdatedAt = now + + id := dev.ID + if id == "" { + id = uuid.New().String() + } + + // SQLite UPSERT keyed on (namespace, user_id, device_id) per the migration's + // UNIQUE constraint. On conflict we replace token + provider + metadata + // while preserving the original id and created_at. + query := ` + INSERT INTO push_devices + (id, namespace, user_id, device_id, provider, token_encrypted, + platform, app_version, created_at, updated_at, last_seen) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(namespace, user_id, device_id) DO UPDATE SET + provider = excluded.provider, + token_encrypted = excluded.token_encrypted, + platform = excluded.platform, + app_version = excluded.app_version, + updated_at = excluded.updated_at, + last_seen = excluded.last_seen + ` + _, err = s.db.Exec(ctx, query, + id, dev.Namespace, dev.UserID, dev.DeviceID, dev.Provider, encToken, + dev.Platform, dev.AppVer, dev.CreatedAt, dev.UpdatedAt, dev.LastSeen, + ) + if err != nil { + return fmt.Errorf("upsert push device: %w", err) + } + return nil +} + +// Delete implements PushDeviceStore. +func (s *RqliteDeviceStore) Delete(ctx context.Context, namespace, id string) error { + if namespace == "" || id == "" { + return fmt.Errorf("namespace and id required") + } + query := `DELETE FROM push_devices WHERE id = ? AND namespace = ?` + res, err := s.db.Exec(ctx, query, id, namespace) + if err != nil { + return fmt.Errorf("delete push device: %w", err) + } + n, _ := res.RowsAffected() + if n == 0 { + return fmt.Errorf("push device not found: %s", id) + } + return nil +} + +// ListForUser implements PushDeviceStore. Returns devices with decrypted tokens. +// Caller MUST treat tokens as sensitive. +func (s *RqliteDeviceStore) ListForUser(ctx context.Context, namespace, userID string) ([]PushDevice, error) { + if namespace == "" || userID == "" { + return nil, nil + } + query := ` + SELECT id, namespace, user_id, device_id, provider, token_encrypted, + COALESCE(platform, ''), COALESCE(app_version, ''), + created_at, updated_at, COALESCE(last_seen, 0) + FROM push_devices + WHERE namespace = ? AND user_id = ? + ` + var rows []deviceRow + if err := s.db.Query(ctx, &rows, query, namespace, userID); err != nil { + return nil, fmt.Errorf("query push devices: %w", err) + } + + out := make([]PushDevice, 0, len(rows)) + for _, r := range rows { + token, err := secrets.Decrypt(r.TokenEncrypted, s.encKey) + if err != nil { + s.logger.Warn("failed to decrypt push token; skipping device", + zap.String("device_id", r.DeviceID), + zap.Error(err)) + continue + } + out = append(out, PushDevice{ + ID: r.ID, + Namespace: r.Namespace, + UserID: r.UserID, + DeviceID: r.DeviceID, + Provider: r.Provider, + Token: token, + Platform: r.Platform, + AppVer: r.AppVersion, + CreatedAt: r.CreatedAt, + UpdatedAt: r.UpdatedAt, + LastSeen: r.LastSeen, + }) + } + return out, nil +} diff --git a/core/pkg/push/dispatcher.go b/core/pkg/push/dispatcher.go new file mode 100644 index 0000000..f43017d --- /dev/null +++ b/core/pkg/push/dispatcher.go @@ -0,0 +1,97 @@ +package push + +import ( + "context" + "fmt" + "sync" + + "go.uber.org/zap" +) + +// PushDispatcher routes push messages to the matching provider for each +// of a user's registered devices. +type PushDispatcher struct { + mu sync.RWMutex + providers map[string]PushProvider + devices PushDeviceStore + logger *zap.Logger +} + +// New creates a dispatcher with the given device store. Register +// providers before sending. +func New(devices PushDeviceStore, logger *zap.Logger) *PushDispatcher { + if logger == nil { + logger = zap.NewNop() + } + return &PushDispatcher{ + providers: map[string]PushProvider{}, + devices: devices, + logger: logger.Named("push"), + } +} + +// Register makes a provider available to dispatch. Calling Register with +// the same name twice replaces the previous provider — useful in tests. +func (d *PushDispatcher) Register(p PushProvider) { + d.mu.Lock() + defer d.mu.Unlock() + d.providers[p.Name()] = p +} + +// Provider returns the registered provider by name, or nil. +func (d *PushDispatcher) Provider(name string) PushProvider { + d.mu.RLock() + defer d.mu.RUnlock() + return d.providers[name] +} + +// SendToUser fans out the message to every registered device for the +// user. Each provider failure is logged but does not stop subsequent +// devices. Returns the first encountered error (if any) so callers can +// surface a partial-failure signal. +// +// SendToUser returns nil if the user has no registered devices — that +// is normal, not an error. +func (d *PushDispatcher) SendToUser( + ctx context.Context, + namespace, userID string, + msg PushMessage, +) error { + devs, err := d.devices.ListForUser(ctx, namespace, userID) + if err != nil { + return fmt.Errorf("list devices: %w", err) + } + if len(devs) == 0 { + return nil + } + + var firstErr error + for _, dev := range devs { + d.mu.RLock() + p, ok := d.providers[dev.Provider] + d.mu.RUnlock() + if !ok { + d.logger.Warn("push: dropping device with unregistered provider", + zap.String("provider", dev.Provider), + zap.String("device_id", dev.DeviceID), + ) + if firstErr == nil { + firstErr = fmt.Errorf("%w: %s", ErrUnknownProvider, dev.Provider) + } + continue + } + m := msg + m.DeviceToken = dev.Token + if err := p.Send(ctx, m); err != nil { + d.logger.Warn("push: provider send failed", + zap.String("provider", dev.Provider), + zap.String("device_id", dev.DeviceID), + zap.Error(err), + ) + if firstErr == nil { + firstErr = err + } + } + } + return firstErr +} diff --git a/core/pkg/push/dispatcher_test.go b/core/pkg/push/dispatcher_test.go new file mode 100644 index 0000000..46c93ac --- /dev/null +++ b/core/pkg/push/dispatcher_test.go @@ -0,0 +1,149 @@ +package push + +import ( + "context" + "errors" + "sync/atomic" + "testing" + + "go.uber.org/zap" +) + +// fakeProvider records every Send call. +type fakeProvider struct { + name string + sent int32 + lastToken string + err error +} + +func (f *fakeProvider) Name() string { return f.name } +func (f *fakeProvider) Send(ctx context.Context, msg PushMessage) error { + atomic.AddInt32(&f.sent, 1) + f.lastToken = msg.DeviceToken + return f.err +} + +// fakeStore is an in-memory PushDeviceStore. +type fakeStore struct { + devices []PushDevice + err error +} + +func (s *fakeStore) Upsert(ctx context.Context, dev PushDevice) error { + if s.err != nil { + return s.err + } + s.devices = append(s.devices, dev) + return nil +} +func (s *fakeStore) Delete(ctx context.Context, ns, id string) error { return nil } +func (s *fakeStore) ListForUser(ctx context.Context, ns, userID string) ([]PushDevice, error) { + if s.err != nil { + return nil, s.err + } + out := []PushDevice{} + for _, d := range s.devices { + if d.Namespace == ns && d.UserID == userID { + out = append(out, d) + } + } + return out, nil +} + +func TestSendToUser_no_devices_returns_nil(t *testing.T) { + d := New(&fakeStore{}, zap.NewNop()) + if err := d.SendToUser(context.Background(), "ns", "u", PushMessage{Title: "x"}); err != nil { + t.Fatalf("expected nil for no devices, got: %v", err) + } +} + +func TestSendToUser_routes_to_correct_provider(t *testing.T) { + store := &fakeStore{devices: []PushDevice{ + {Namespace: "ns", UserID: "u", Provider: "ntfy", Token: "ntfy-tok"}, + {Namespace: "ns", UserID: "u", Provider: "expo", Token: "expo-tok"}, + }} + ntfy := &fakeProvider{name: "ntfy"} + expo := &fakeProvider{name: "expo"} + + d := New(store, zap.NewNop()) + d.Register(ntfy) + d.Register(expo) + + if err := d.SendToUser(context.Background(), "ns", "u", PushMessage{Title: "hi"}); err != nil { + t.Fatalf("SendToUser: %v", err) + } + + if atomic.LoadInt32(&ntfy.sent) != 1 || ntfy.lastToken != "ntfy-tok" { + t.Errorf("ntfy provider not called correctly: sent=%d token=%s", ntfy.sent, ntfy.lastToken) + } + if atomic.LoadInt32(&expo.sent) != 1 || expo.lastToken != "expo-tok" { + t.Errorf("expo provider not called correctly: sent=%d token=%s", expo.sent, expo.lastToken) + } +} + +func TestSendToUser_unknown_provider_returns_error_continues(t *testing.T) { + store := &fakeStore{devices: []PushDevice{ + {Namespace: "ns", UserID: "u", Provider: "ghost", Token: "tok"}, + {Namespace: "ns", UserID: "u", Provider: "ntfy", Token: "real"}, + }} + ntfy := &fakeProvider{name: "ntfy"} + + d := New(store, zap.NewNop()) + d.Register(ntfy) + + err := d.SendToUser(context.Background(), "ns", "u", PushMessage{}) + if err == nil { + t.Fatal("expected error for unknown provider") + } + if !errors.Is(err, ErrUnknownProvider) { + t.Errorf("expected ErrUnknownProvider, got %v", err) + } + // ntfy should still have been called. + if atomic.LoadInt32(&ntfy.sent) != 1 { + t.Error("ntfy should have been called for the second device") + } +} + +func TestSendToUser_provider_failure_returned_but_other_devices_still_processed(t *testing.T) { + store := &fakeStore{devices: []PushDevice{ + {Namespace: "ns", UserID: "u", Provider: "expo", Token: "tok-1"}, + {Namespace: "ns", UserID: "u", Provider: "ntfy", Token: "tok-2"}, + }} + expoErr := errors.New("expo down") + expo := &fakeProvider{name: "expo", err: expoErr} + ntfy := &fakeProvider{name: "ntfy"} + + d := New(store, zap.NewNop()) + d.Register(expo) + d.Register(ntfy) + + err := d.SendToUser(context.Background(), "ns", "u", PushMessage{}) + if !errors.Is(err, expoErr) { + t.Errorf("expected expo error, got %v", err) + } + if atomic.LoadInt32(&ntfy.sent) != 1 { + t.Error("ntfy should have been called even though expo failed") + } +} + +func TestSendToUser_store_error_propagated(t *testing.T) { + storeErr := errors.New("store boom") + d := New(&fakeStore{err: storeErr}, zap.NewNop()) + + err := d.SendToUser(context.Background(), "ns", "u", PushMessage{}) + if err == nil || !errors.Is(err, storeErr) { + t.Errorf("expected store error, got %v", err) + } +} + +func TestRegister_replaces_existing_provider(t *testing.T) { + d := New(&fakeStore{}, zap.NewNop()) + a := &fakeProvider{name: "ntfy"} + b := &fakeProvider{name: "ntfy"} + d.Register(a) + d.Register(b) + if d.Provider("ntfy") != b { + t.Error("expected second Register to replace the first") + } +} diff --git a/core/pkg/push/providers/expo/expo.go b/core/pkg/push/providers/expo/expo.go new file mode 100644 index 0000000..38c95c7 --- /dev/null +++ b/core/pkg/push/providers/expo/expo.go @@ -0,0 +1,160 @@ +// Package expo wraps the Expo push relay as a push.PushProvider. +// +// This is a thin port of the legacy gateway.PushNotificationService — +// behaviour preserved, surface adapted to the provider abstraction. +// +// Long term Expo is intended to be replaced with direct APNs (iOS) + +// ntfy (Android). This provider exists so the gateway can keep using +// Expo while the migration happens. +package expo + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "github.com/DeBrosOfficial/network/pkg/push" + "go.uber.org/zap" +) + +const expoAPIURL = "https://exp.host/--/api/v2/push/send" + +// Config holds Expo provider settings. +type Config struct { + // AccessToken is an optional Expo access token. When set, it's sent + // as a Bearer token, which Expo uses for higher-priority delivery + // and to attribute the send to your account. + AccessToken string + // Timeout bounds each Send call. 0 selects 10 seconds (matching the + // previous PushNotificationService default). + Timeout time.Duration +} + +// Provider is the Expo push.PushProvider implementation. +type Provider struct { + accessToken string + httpClient *http.Client + logger *zap.Logger +} + +// New creates a Provider with the given config. +func New(cfg Config, logger *zap.Logger) *Provider { + if logger == nil { + logger = zap.NewNop() + } + timeout := cfg.Timeout + if timeout <= 0 { + timeout = 10 * time.Second + } + return &Provider{ + accessToken: cfg.AccessToken, + httpClient: &http.Client{Timeout: timeout}, + logger: logger.Named("expo"), + } +} + +// Name implements push.PushProvider. +func (p *Provider) Name() string { return "expo" } + +// expoMessage matches the wire format Expo expects. +type expoMessage struct { + To string `json:"to"` + Title string `json:"title,omitempty"` + Body string `json:"body,omitempty"` + Data map[string]interface{} `json:"data,omitempty"` + Sound string `json:"sound,omitempty"` + Badge int `json:"badge,omitempty"` + Priority string `json:"priority,omitempty"` + MutableContent bool `json:"mutableContent,omitempty"` + ChannelID string `json:"channelId,omitempty"` +} + +// expoTicket is the per-message response. +type expoTicket struct { + Status string `json:"status"` + Message string `json:"message,omitempty"` +} + +type expoResponse struct { + Data []expoTicket `json:"data"` +} + +// Send delivers a push via the Expo relay. +func (p *Provider) Send(ctx context.Context, msg push.PushMessage) error { + if msg.DeviceToken == "" { + return push.ErrEmptyToken + } + + priority := "default" + if msg.Priority == push.PriorityHigh { + priority = "high" + } + + wire := expoMessage{ + To: msg.DeviceToken, + Title: msg.Title, + Body: msg.Body, + Data: msg.Data, + Sound: msg.Sound, + Badge: msg.Badge, + Priority: priority, + MutableContent: true, // for iOS Notification Service Extension + ChannelID: msg.Channel, + } + if wire.Sound == "" { + wire.Sound = "default" + } + + body, err := json.Marshal(wire) + if err != nil { + return fmt.Errorf("expo: marshal: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, expoAPIURL, bytes.NewReader(body)) + if err != nil { + return fmt.Errorf("expo: build request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + if p.accessToken != "" { + req.Header.Set("Authorization", "Bearer "+p.accessToken) + } + + resp, err := p.httpClient.Do(req) + if err != nil { + return fmt.Errorf("expo: post: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(io.LimitReader(resp.Body, 16<<10)) + if err != nil { + return fmt.Errorf("expo: read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("expo: http %d: %s", resp.StatusCode, string(respBody)) + } + + var er expoResponse + if err := json.Unmarshal(respBody, &er); err != nil { + // Older Expo responses sometimes return a bare array; try that fallback. + var tickets []expoTicket + if err2 := json.Unmarshal(respBody, &tickets); err2 == nil { + er.Data = tickets + } else { + return fmt.Errorf("expo: parse response: %w", err) + } + } + + for _, t := range er.Data { + if t.Status != "" && t.Status != "ok" { + return fmt.Errorf("expo: ticket status %q: %s", t.Status, t.Message) + } + } + + return nil +} diff --git a/core/pkg/push/providers/expo/expo_test.go b/core/pkg/push/providers/expo/expo_test.go new file mode 100644 index 0000000..fa89aaa --- /dev/null +++ b/core/pkg/push/providers/expo/expo_test.go @@ -0,0 +1,118 @@ +package expo + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/DeBrosOfficial/network/pkg/push" +) + +// roundTripFunc lets us mock http.Client transport for the Expo provider so +// we can assert against requests without hitting exp.host. +type roundTripFunc func(req *http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { return f(req) } + +func newTestProvider(rt roundTripFunc) *Provider { + p := New(Config{}, nil) + p.httpClient.Transport = rt + return p +} + +func TestSend_empty_token_returns_ErrEmptyToken(t *testing.T) { + p := New(Config{}, nil) + err := p.Send(context.Background(), push.PushMessage{Body: "x"}) + if err != push.ErrEmptyToken { + t.Errorf("expected ErrEmptyToken, got %v", err) + } +} + +func TestSend_happy_path(t *testing.T) { + var gotPayload map[string]interface{} + var gotAuth string + + p := newTestProvider(func(req *http.Request) (*http.Response, error) { + gotAuth = req.Header.Get("Authorization") + body, _ := io.ReadAll(req.Body) + _ = json.Unmarshal(body, &gotPayload) + resp := httptest.NewRecorder() + resp.WriteHeader(200) + _, _ = resp.WriteString(`{"data":[{"status":"ok"}]}`) + return resp.Result(), nil + }) + p.accessToken = "secret-token" + + err := p.Send(context.Background(), push.PushMessage{ + DeviceToken: "ExponentPushToken[abc]", + Title: "T", Body: "B", + Priority: push.PriorityHigh, + }) + if err != nil { + t.Fatalf("Send failed: %v", err) + } + if gotAuth != "Bearer secret-token" { + t.Errorf("auth header wrong: %s", gotAuth) + } + if gotPayload["to"] != "ExponentPushToken[abc]" { + t.Errorf("to wrong: %v", gotPayload["to"]) + } + if gotPayload["priority"] != "high" { + t.Errorf("priority wrong: %v", gotPayload["priority"]) + } +} + +func TestSend_ticket_error_returns_error(t *testing.T) { + p := newTestProvider(func(req *http.Request) (*http.Response, error) { + resp := httptest.NewRecorder() + resp.WriteHeader(200) + _, _ = resp.WriteString(`{"data":[{"status":"error","message":"DeviceNotRegistered"}]}`) + return resp.Result(), nil + }) + err := p.Send(context.Background(), push.PushMessage{DeviceToken: "x", Body: "y"}) + if err == nil { + t.Fatal("expected error for ticket failure") + } +} + +func TestSend_http_error_returns_error(t *testing.T) { + p := newTestProvider(func(req *http.Request) (*http.Response, error) { + resp := httptest.NewRecorder() + resp.WriteHeader(500) + _, _ = resp.WriteString(`upstream broken`) + return resp.Result(), nil + }) + err := p.Send(context.Background(), push.PushMessage{DeviceToken: "x", Body: "y"}) + if err == nil { + t.Fatal("expected error for HTTP 500") + } +} + +func TestSend_normal_priority_maps_to_default(t *testing.T) { + var gotPayload map[string]interface{} + p := newTestProvider(func(req *http.Request) (*http.Response, error) { + body, _ := io.ReadAll(req.Body) + _ = json.Unmarshal(body, &gotPayload) + resp := httptest.NewRecorder() + resp.WriteHeader(200) + _, _ = resp.WriteString(`{"data":[{"status":"ok"}]}`) + return resp.Result(), nil + }) + if err := p.Send(context.Background(), push.PushMessage{ + DeviceToken: "x", Body: "y", Priority: push.PriorityNormal, + }); err != nil { + t.Fatal(err) + } + if gotPayload["priority"] != "default" { + t.Errorf("expected priority=default, got %v", gotPayload["priority"]) + } +} + +func TestName(t *testing.T) { + if New(Config{}, nil).Name() != "expo" { + t.Error("expected Name=expo") + } +} diff --git a/core/pkg/push/providers/ntfy/ntfy.go b/core/pkg/push/providers/ntfy/ntfy.go new file mode 100644 index 0000000..adc96b6 --- /dev/null +++ b/core/pkg/push/providers/ntfy/ntfy.go @@ -0,0 +1,132 @@ +// Package ntfy implements a push.PushProvider backed by an ntfy server. +// +// ntfy delivers notifications via plain HTTP POST to /. +// We map PushMessage fields to ntfy headers: +// - Title -> "Title" +// - Priority -> "Priority" +// - Channel -> "Tags" +// - Data -> base64-encoded JSON in "X-Data" +// +// See https://docs.ntfy.sh/publish/#publish-as-json for details. +package ntfy + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/DeBrosOfficial/network/pkg/push" + "go.uber.org/zap" +) + +// Config holds per-provider settings. +type Config struct { + // BaseURL is the ntfy HTTP endpoint (e.g. "http://localhost:8080" or + // "https://push.example.com"). Trailing slash is tolerated. + BaseURL string + // AuthToken is an optional per-namespace bearer token. Leave empty to + // disable authentication. + AuthToken string + // Timeout bounds each Send call. 0 selects 5 seconds. + Timeout time.Duration +} + +// Provider is the ntfy push.PushProvider implementation. +type Provider struct { + baseURL string + authToken string + httpClient *http.Client + logger *zap.Logger +} + +// New creates a Provider with the given config. +func New(cfg Config, logger *zap.Logger) *Provider { + if logger == nil { + logger = zap.NewNop() + } + timeout := cfg.Timeout + if timeout <= 0 { + timeout = 5 * time.Second + } + return &Provider{ + baseURL: strings.TrimRight(cfg.BaseURL, "/"), + authToken: cfg.AuthToken, + httpClient: &http.Client{Timeout: timeout}, + logger: logger.Named("ntfy"), + } +} + +// Name implements push.PushProvider. +func (p *Provider) Name() string { return "ntfy" } + +// Send delivers a push notification to the device's ntfy topic. +func (p *Provider) Send(ctx context.Context, msg push.PushMessage) error { + if msg.DeviceToken == "" { + return push.ErrEmptyToken + } + if p.baseURL == "" { + return fmt.Errorf("ntfy: base URL not configured") + } + + // URL-escape each path segment of the device token. ntfy topics can be + // hierarchical (e.g. "ns/myapp/user-1") and we want to preserve those + // '/' separators while escaping any other special characters that + // could let a malicious token escape the topic path. + parts := strings.Split(msg.DeviceToken, "/") + for i, p := range parts { + parts[i] = url.PathEscape(p) + } + endpointURL := p.baseURL + "/" + strings.Join(parts, "/") + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, strings.NewReader(msg.Body)) + if err != nil { + return fmt.Errorf("ntfy: build request: %w", err) + } + + if msg.Title != "" { + req.Header.Set("Title", msg.Title) + } + if msg.Priority == push.PriorityHigh { + req.Header.Set("Priority", "high") + } else if msg.Priority == push.PriorityNormal { + req.Header.Set("Priority", "default") + } + if msg.Channel != "" { + // ntfy uses "Tags" for both visual emoji and operator-defined tags. + req.Header.Set("Tags", msg.Channel) + } + if msg.Badge > 0 { + req.Header.Set("X-Badge", fmt.Sprintf("%d", msg.Badge)) + } + if len(msg.Data) > 0 { + b, err := json.Marshal(msg.Data) + if err != nil { + return fmt.Errorf("ntfy: marshal data: %w", err) + } + req.Header.Set("X-Data", base64.StdEncoding.EncodeToString(b)) + } + if p.authToken != "" { + req.Header.Set("Authorization", "Bearer "+p.authToken) + } + + resp, err := p.httpClient.Do(req) + if err != nil { + return fmt.Errorf("ntfy: post: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode >= 400 { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 512)) + return fmt.Errorf("ntfy: http %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + + // Drain body to allow connection reuse. + _, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4096)) + return nil +} diff --git a/core/pkg/push/providers/ntfy/ntfy_test.go b/core/pkg/push/providers/ntfy/ntfy_test.go new file mode 100644 index 0000000..d6f08a3 --- /dev/null +++ b/core/pkg/push/providers/ntfy/ntfy_test.go @@ -0,0 +1,191 @@ +package ntfy + +import ( + "context" + "encoding/base64" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/DeBrosOfficial/network/pkg/push" +) + +func TestSend_happy_path(t *testing.T) { + var ( + gotPath string + gotBody string + gotTitle string + gotPriority string + gotAuth string + ) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + gotTitle = r.Header.Get("Title") + gotPriority = r.Header.Get("Priority") + gotAuth = r.Header.Get("Authorization") + b, _ := io.ReadAll(r.Body) + gotBody = string(b) + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + p := New(Config{BaseURL: srv.URL, AuthToken: "secret"}, nil) + err := p.Send(context.Background(), push.PushMessage{ + DeviceToken: "ns/myapp/user-1", + Title: "Hello", + Body: "World", + Priority: push.PriorityHigh, + }) + if err != nil { + t.Fatalf("Send failed: %v", err) + } + if gotPath != "/ns/myapp/user-1" { + t.Errorf("expected path /ns/myapp/user-1, got %s", gotPath) + } + if gotTitle != "Hello" { + t.Errorf("expected Title=Hello, got %s", gotTitle) + } + if gotPriority != "high" { + t.Errorf("expected Priority=high, got %s", gotPriority) + } + if gotAuth != "Bearer secret" { + t.Errorf("expected Authorization=Bearer secret, got %s", gotAuth) + } + if gotBody != "World" { + t.Errorf("expected body=World, got %s", gotBody) + } +} + +func TestSend_includes_data_header_when_data_set(t *testing.T) { + var gotData string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotData = r.Header.Get("X-Data") + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + p := New(Config{BaseURL: srv.URL}, nil) + err := p.Send(context.Background(), push.PushMessage{ + DeviceToken: "topic", + Body: "x", + Data: map[string]interface{}{"call_id": "abc-123"}, + }) + if err != nil { + t.Fatalf("Send: %v", err) + } + decoded, err := base64.StdEncoding.DecodeString(gotData) + if err != nil { + t.Fatalf("X-Data not valid base64: %v", err) + } + var got map[string]interface{} + if err := json.Unmarshal(decoded, &got); err != nil { + t.Fatalf("X-Data not valid JSON: %v", err) + } + if got["call_id"] != "abc-123" { + t.Errorf("data round-trip failed: got %v", got) + } +} + +func TestSend_no_data_no_data_header(t *testing.T) { + var gotData string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotData = r.Header.Get("X-Data") + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + p := New(Config{BaseURL: srv.URL}, nil) + if err := p.Send(context.Background(), push.PushMessage{DeviceToken: "t", Body: "x"}); err != nil { + t.Fatal(err) + } + if gotData != "" { + t.Errorf("expected no X-Data header, got %q", gotData) + } +} + +func TestSend_no_auth_header_when_token_empty(t *testing.T) { + var gotAuth string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + p := New(Config{BaseURL: srv.URL}, nil) + if err := p.Send(context.Background(), push.PushMessage{DeviceToken: "t", Body: "x"}); err != nil { + t.Fatal(err) + } + if gotAuth != "" { + t.Errorf("expected no Authorization header, got %q", gotAuth) + } +} + +func TestSend_4xx_returns_error_with_body_excerpt(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte("forbidden topic")) + })) + defer srv.Close() + + p := New(Config{BaseURL: srv.URL}, nil) + err := p.Send(context.Background(), push.PushMessage{DeviceToken: "t", Body: "x"}) + if err == nil { + t.Fatal("expected error for 403") + } + if !strings.Contains(err.Error(), "403") || !strings.Contains(err.Error(), "forbidden") { + t.Errorf("error should mention status and body, got: %v", err) + } +} + +func TestSend_empty_token_returns_ErrEmptyToken(t *testing.T) { + p := New(Config{BaseURL: "http://example.invalid"}, nil) + err := p.Send(context.Background(), push.PushMessage{Body: "x"}) + if err == nil { + t.Fatal("expected error for empty token") + } + if err != push.ErrEmptyToken { + t.Errorf("expected ErrEmptyToken, got %v", err) + } +} + +func TestSend_short_timeout_returns_error(t *testing.T) { + // Server that blocks for 2s — provider with 100ms timeout should give up. + blockUntil := make(chan struct{}) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + select { + case <-blockUntil: + case <-time.After(2 * time.Second): + } + })) + defer func() { close(blockUntil); srv.Close() }() + + p := New(Config{BaseURL: srv.URL, Timeout: 100 * time.Millisecond}, nil) + start := time.Now() + err := p.Send(context.Background(), push.PushMessage{DeviceToken: "t", Body: "x"}) + elapsed := time.Since(start) + if err == nil { + t.Error("expected timeout error") + } + if elapsed > 1*time.Second { + t.Errorf("expected fast timeout, took %v", elapsed) + } +} + +func TestSend_no_baseURL_returns_error(t *testing.T) { + p := New(Config{}, nil) + err := p.Send(context.Background(), push.PushMessage{DeviceToken: "t", Body: "x"}) + if err == nil { + t.Fatal("expected error for missing base URL") + } +} + +func TestName(t *testing.T) { + p := New(Config{BaseURL: "http://x"}, nil) + if p.Name() != "ntfy" { + t.Errorf("expected Name=ntfy, got %s", p.Name()) + } +} diff --git a/core/pkg/push/types.go b/core/pkg/push/types.go new file mode 100644 index 0000000..44cdbe6 --- /dev/null +++ b/core/pkg/push/types.go @@ -0,0 +1,91 @@ +// Package push provides a generic push-notification abstraction for Orama. +// +// Apps register devices with a provider name ("ntfy", "expo", "apns", ...) +// and a provider-specific token. The PushDispatcher routes outbound push +// messages to the matching provider so call sites stay backend-agnostic. +// +// Long-term the platform aims to drop Expo in favour of direct APNs + +// ntfy. The abstraction makes that swap a configuration change rather +// than a code change. +package push + +import ( + "context" + "errors" +) + +// PushPriority signals delivery urgency to the provider. +// Providers that don't support priorities ignore the value. +type PushPriority string + +const ( + PriorityNormal PushPriority = "normal" + PriorityHigh PushPriority = "high" +) + +// PushMessage is the provider-agnostic message format. +// +// DeviceToken is the provider-specific identifier (e.g. an ntfy topic, +// an Expo push token, an APNs device token). The PushDispatcher fills +// it in per-device before calling Send. +type PushMessage struct { + DeviceToken string + Title string + Body string + Data map[string]interface{} + Badge int + Sound string + Channel string // "messages", "calls", etc — provider may map to its own channel concept + Priority PushPriority +} + +// PushProvider is implemented by each backend (ntfy, expo, apns). +type PushProvider interface { + Name() string + // Send delivers a single push. Returning an error counts as a delivery + // failure for that device; the dispatcher logs it and continues. + Send(ctx context.Context, msg PushMessage) error +} + +// PushDevice represents a registered push target for a user. +// +// Token is plaintext in this struct — encryption happens at the storage +// layer. Callers who load Devices from the store must treat tokens as +// sensitive material (don't log them). +type PushDevice struct { + ID string + Namespace string + UserID string + DeviceID string // app-provided + Provider string // matches PushProvider.Name() + Token string + Platform string // "ios" | "android" | "web" + AppVer string + CreatedAt int64 // unix seconds + UpdatedAt int64 + LastSeen int64 +} + +// PushDeviceStore persists per-user device registrations. +type PushDeviceStore interface { + // Upsert registers or updates a device. The Token is encrypted by the + // implementation before being written to durable storage. + Upsert(ctx context.Context, dev PushDevice) error + + // Delete removes a single device by ID, scoped to the namespace. + Delete(ctx context.Context, namespace, id string) error + + // ListForUser returns all devices for a user within a namespace. + // Tokens in the returned slice are decrypted. + ListForUser(ctx context.Context, namespace, userID string) ([]PushDevice, error) +} + +// Sentinel errors. +var ( + // ErrUnknownProvider is returned by the dispatcher when a device + // references a provider that isn't registered. + ErrUnknownProvider = errors.New("push: unknown provider") + // ErrEmptyToken is returned by providers when called with an empty + // DeviceToken. + ErrEmptyToken = errors.New("push: empty device token") +) diff --git a/core/pkg/serverless/aggregator/aggregator.go b/core/pkg/serverless/aggregator/aggregator.go new file mode 100644 index 0000000..8e58ea2 --- /dev/null +++ b/core/pkg/serverless/aggregator/aggregator.go @@ -0,0 +1,295 @@ +// Package aggregator buffers PubSub trigger events per +// (namespace, function, trigger) and flushes them as a single batched +// invocation. It's used by the PubSub trigger dispatcher when a trigger +// declares aggregation_window_ms > 0. +// +// State is local to each node — buffers are not replicated. This is by +// design: aggregation is intended for high-frequency, lossy event streams +// (presence, VAD signals, metrics). Crash recovery is not provided; an +// orderly shutdown flushes pending buffers via Shutdown(). +package aggregator + +import ( + "context" + "encoding/json" + "sync" + "time" + + "go.uber.org/zap" +) + +// DefaultMaxBatchSize is used when a trigger sets MaxBatchSize=0. +const DefaultMaxBatchSize = 100 + +// Event is one buffered message, mirroring the dispatcher's PubSubEvent +// shape but kept local to avoid an import cycle. +type Event struct { + Topic string `json:"topic"` + Data json.RawMessage `json:"data"` + Namespace string `json:"namespace"` + TriggerDepth int `json:"trigger_depth"` + Timestamp int64 `json:"timestamp"` +} + +// BatchedPayload is what the function receives on a flush. +// `Batched: true` lets a function differentiate single vs. batched mode +// by parsing this discriminator first. +type BatchedPayload struct { + Batched bool `json:"batched"` + Events []Event `json:"events"` +} + +// FlushFn is invoked when a buffer flushes. It receives the marshalled +// BatchedPayload and a context with a sane timeout. The aggregator does +// not retry on flush errors — that's the invoker's responsibility. +type FlushFn func(ctx context.Context, payload []byte) + +// FlushReason annotates why a flush happened. Useful for metrics. +type FlushReason string + +const ( + FlushReasonTimer FlushReason = "timer" + FlushReasonSize FlushReason = "size" + FlushReasonShutdown FlushReason = "shutdown" +) + +// FlushFnWithReason is like FlushFn but also receives the reason. +// Internal use; FlushFn is the simple public form. +type FlushFnWithReason func(ctx context.Context, payload []byte, reason FlushReason) + +// bufferKey identifies a single in-memory buffer. +type bufferKey struct { + Namespace string + FunctionID string + TriggerID string +} + +type bufferEntry struct { + events []Event + timer *time.Timer + windowMs int + maxBatch int + flushFn FlushFnWithReason +} + +// Aggregator buffers events per (namespace, function, trigger) and flushes +// either when the window timer fires or when MaxBatch is reached. +type Aggregator struct { + mu sync.Mutex + buffers map[bufferKey]*bufferEntry + logger *zap.Logger + flushTimeout time.Duration + // inflight tracks dispatched flush goroutines so Shutdown can wait + // for them to finish (or time out) before returning. + inflight sync.WaitGroup +} + +// New creates an Aggregator. flushTimeout bounds the context passed to FlushFn. +// 0 selects a sane default (60s). +func New(logger *zap.Logger, flushTimeout time.Duration) *Aggregator { + if flushTimeout <= 0 { + flushTimeout = 60 * time.Second + } + if logger == nil { + logger = zap.NewNop() + } + return &Aggregator{ + buffers: map[bufferKey]*bufferEntry{}, + logger: logger.Named("aggregator"), + flushTimeout: flushTimeout, + } +} + +// BufferRequest carries everything needed to add an event. +type BufferRequest struct { + Namespace string + FunctionID string + TriggerID string + WindowMs int + MaxBatchSize int + FlushFn FlushFn // simple public form; internally promoted to FlushFnWithReason + Event Event +} + +// Buffer adds an event to the matching buffer. Returns immediately — +// the function is invoked later, asynchronously, when the window or +// size threshold fires. +// +// If WindowMs <= 0, this method panics with a programming-error message +// to surface misuse: callers should not buffer non-aggregating triggers. +func (a *Aggregator) Buffer(req BufferRequest) { + if req.WindowMs <= 0 { + // Aggregator should never be called for non-aggregating triggers. + // Panicking here makes the caller bug obvious during development. + panic("aggregator: Buffer called with WindowMs <= 0") + } + maxBatch := req.MaxBatchSize + if maxBatch <= 0 { + maxBatch = DefaultMaxBatchSize + } + + key := bufferKey{Namespace: req.Namespace, FunctionID: req.FunctionID, TriggerID: req.TriggerID} + + a.mu.Lock() + defer a.mu.Unlock() + + entry, ok := a.buffers[key] + if !ok { + // Promote the user-facing FlushFn into the reason-aware variant. + // We capture req.FlushFn so subsequent Buffer calls keep using it. + userFn := req.FlushFn + entry = &bufferEntry{ + events: make([]Event, 0, maxBatch), + windowMs: req.WindowMs, + maxBatch: maxBatch, + flushFn: func(ctx context.Context, payload []byte, reason FlushReason) { + if userFn != nil { + userFn(ctx, payload) + } + }, + } + a.buffers[key] = entry + } + + entry.events = append(entry.events, req.Event) + + // Start the window timer on the first event of a new window. + if entry.timer == nil { + // Capture key by value for the closure. + k := key + entry.timer = time.AfterFunc(time.Duration(entry.windowMs)*time.Millisecond, func() { + a.flushByTimer(k) + }) + } + + // Size-triggered flush. + if len(entry.events) >= entry.maxBatch { + a.flushLocked(key, entry, FlushReasonSize) + } +} + +// flushByTimer is invoked by time.AfterFunc; it acquires the lock then flushes. +func (a *Aggregator) flushByTimer(key bufferKey) { + a.mu.Lock() + defer a.mu.Unlock() + entry, ok := a.buffers[key] + if !ok { + // Buffer already flushed by size threshold and the bucket removed. + return + } + if len(entry.events) == 0 { + // Defensive: empty buffer — drop it so the map stays bounded. + delete(a.buffers, key) + return + } + a.flushLocked(key, entry, FlushReasonTimer) +} + +// flushLocked must be called with a.mu held. It snapshots the current +// events, removes the bucket entry, then dispatches the flush in a +// goroutine so the caller doesn't block on the function invocation. +// +// Removing the bucket on flush keeps the buffers map bounded over the +// lifetime of the process. If a subsequent event arrives for the same +// (namespace, function, trigger) tuple, Buffer recreates the entry. +func (a *Aggregator) flushLocked(key bufferKey, entry *bufferEntry, reason FlushReason) { + if entry.timer != nil { + entry.timer.Stop() + entry.timer = nil + } + if len(entry.events) == 0 { + // Empty bucket — drop it so the map doesn't accumulate. + delete(a.buffers, key) + return + } + events := entry.events + + payload, err := json.Marshal(BatchedPayload{Batched: true, Events: events}) + if err != nil { + a.logger.Error("failed to marshal batched payload", + zap.String("namespace", key.Namespace), + zap.String("function_id", key.FunctionID), + zap.String("trigger_id", key.TriggerID), + zap.Int("batch_size", len(events)), + zap.Error(err), + ) + // Still drop the bucket — there's no point retrying with the same data. + delete(a.buffers, key) + return + } + + a.logger.Debug("aggregator flush", + zap.String("namespace", key.Namespace), + zap.String("function_id", key.FunctionID), + zap.String("trigger_id", key.TriggerID), + zap.Int("batch_size", len(events)), + zap.String("reason", string(reason)), + ) + + flushFn := entry.flushFn + timeout := a.flushTimeout + delete(a.buffers, key) + + a.inflight.Add(1) + go func() { + defer a.inflight.Done() + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + flushFn(ctx, payload, reason) + }() +} + +// Shutdown drains all non-empty buffers and waits for the resulting flush +// invocations to finish, bounded by `wait`. Callers should pass a wait +// long enough to cover one function invocation (e.g. 5–10 seconds) but +// short enough that a misbehaving function can't delay process exit. +// +// Returns true if all in-flight flushes completed before the deadline, +// false on timeout (in which case some events are effectively lost). +func (a *Aggregator) Shutdown(wait time.Duration) bool { + a.mu.Lock() + keys := make([]bufferKey, 0, len(a.buffers)) + for k := range a.buffers { + keys = append(keys, k) + } + for _, k := range keys { + entry := a.buffers[k] + if entry == nil { + continue + } + if entry.timer != nil { + entry.timer.Stop() + entry.timer = nil + } + a.flushLocked(k, entry, FlushReasonShutdown) + } + a.mu.Unlock() + + if wait <= 0 { + return true + } + done := make(chan struct{}) + go func() { + a.inflight.Wait() + close(done) + }() + select { + case <-done: + return true + case <-time.After(wait): + a.logger.Warn("aggregator shutdown timed out; some buffered events may be lost") + return false + } +} + +// Stats reports the current number of buffered events across all keys. +// Useful for metrics. +func (a *Aggregator) Stats() (numBuffers, totalEvents int) { + a.mu.Lock() + defer a.mu.Unlock() + numBuffers = len(a.buffers) + for _, e := range a.buffers { + totalEvents += len(e.events) + } + return +} diff --git a/core/pkg/serverless/aggregator/aggregator_test.go b/core/pkg/serverless/aggregator/aggregator_test.go new file mode 100644 index 0000000..523b049 --- /dev/null +++ b/core/pkg/serverless/aggregator/aggregator_test.go @@ -0,0 +1,307 @@ +package aggregator + +import ( + "context" + "encoding/json" + "sync" + "sync/atomic" + "testing" + "time" + + "go.uber.org/zap" +) + +func TestBuffer_panics_on_zero_window(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Fatal("expected panic when WindowMs <= 0") + } + }() + a := New(zap.NewNop(), time.Second) + a.Buffer(BufferRequest{ + Namespace: "ns", + FunctionID: "fn", + TriggerID: "tr", + WindowMs: 0, + FlushFn: func(ctx context.Context, payload []byte) {}, + Event: Event{Topic: "t"}, + }) +} + +func TestBuffer_flushes_on_timer(t *testing.T) { + a := New(zap.NewNop(), 5*time.Second) + + var ( + got []Event + gotMu sync.Mutex + done = make(chan struct{}) + ) + + flush := func(ctx context.Context, payload []byte) { + var p BatchedPayload + if err := json.Unmarshal(payload, &p); err != nil { + t.Errorf("unmarshal: %v", err) + } + gotMu.Lock() + got = append(got, p.Events...) + gotMu.Unlock() + close(done) + } + + for i := 0; i < 3; i++ { + a.Buffer(BufferRequest{ + Namespace: "ns", + FunctionID: "fn", + TriggerID: "tr", + WindowMs: 100, // short window so test runs fast + FlushFn: flush, + Event: Event{Topic: "presence:user", Data: json.RawMessage(`"e"`)}, + }) + } + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("flush did not fire within 2s") + } + + gotMu.Lock() + defer gotMu.Unlock() + if len(got) != 3 { + t.Errorf("expected 3 buffered events, got %d", len(got)) + } +} + +func TestBuffer_flushes_on_max_batch_size(t *testing.T) { + a := New(zap.NewNop(), 5*time.Second) + + var ( + flushCount int32 + flushSize int32 + done = make(chan struct{}) + ) + + flush := func(ctx context.Context, payload []byte) { + var p BatchedPayload + _ = json.Unmarshal(payload, &p) + atomic.AddInt32(&flushCount, 1) + atomic.StoreInt32(&flushSize, int32(len(p.Events))) + select { + case <-done: + default: + close(done) + } + } + + for i := 0; i < 5; i++ { + a.Buffer(BufferRequest{ + Namespace: "ns", + FunctionID: "fn", + TriggerID: "tr", + WindowMs: 30_000, // long enough that the timer won't fire + MaxBatchSize: 5, + FlushFn: flush, + Event: Event{Topic: "t"}, + }) + } + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("max-batch flush did not fire") + } + + if atomic.LoadInt32(&flushCount) != 1 { + t.Errorf("expected 1 flush, got %d", flushCount) + } + if atomic.LoadInt32(&flushSize) != 5 { + t.Errorf("expected batch size 5, got %d", flushSize) + } +} + +func TestBuffer_separate_keys_independent(t *testing.T) { + a := New(zap.NewNop(), 5*time.Second) + + var ( + mu sync.Mutex + counts = map[string]int{} + flush = func(name string) FlushFn { + return func(ctx context.Context, payload []byte) { + var p BatchedPayload + _ = json.Unmarshal(payload, &p) + mu.Lock() + counts[name] += len(p.Events) + mu.Unlock() + } + } + ) + + a.Buffer(BufferRequest{ + Namespace: "ns", FunctionID: "fn", TriggerID: "tr-A", + WindowMs: 100, FlushFn: flush("A"), + Event: Event{Topic: "a"}, + }) + a.Buffer(BufferRequest{ + Namespace: "ns", FunctionID: "fn", TriggerID: "tr-B", + WindowMs: 100, FlushFn: flush("B"), + Event: Event{Topic: "b"}, + }) + a.Buffer(BufferRequest{ + Namespace: "ns", FunctionID: "fn", TriggerID: "tr-A", + WindowMs: 100, FlushFn: flush("A"), + Event: Event{Topic: "a2"}, + }) + + time.Sleep(500 * time.Millisecond) + + mu.Lock() + defer mu.Unlock() + if counts["A"] != 2 { + t.Errorf("A: expected 2 events, got %d", counts["A"]) + } + if counts["B"] != 1 { + t.Errorf("B: expected 1 event, got %d", counts["B"]) + } +} + +func TestShutdown_flushes_all_buffers(t *testing.T) { + a := New(zap.NewNop(), 2*time.Second) + + var flushed int32 + flush := func(ctx context.Context, payload []byte) { + atomic.AddInt32(&flushed, 1) + } + + for i := 0; i < 4; i++ { + a.Buffer(BufferRequest{ + Namespace: "ns", FunctionID: "fn", TriggerID: "tr", + WindowMs: 30_000, + FlushFn: flush, + Event: Event{Topic: "t"}, + }) + } + // Different trigger key — should produce a separate flush. + a.Buffer(BufferRequest{ + Namespace: "ns", FunctionID: "fn", TriggerID: "other", + WindowMs: 30_000, + FlushFn: flush, + Event: Event{Topic: "t2"}, + }) + + a.Shutdown(2*time.Second) + + deadline := time.Now().Add(2 * time.Second) + for atomic.LoadInt32(&flushed) < 2 { + if time.Now().After(deadline) { + t.Fatalf("expected 2 flushes from Shutdown, got %d", flushed) + } + time.Sleep(10 * time.Millisecond) + } +} + +func TestShutdown_skips_empty_buffers(t *testing.T) { + a := New(zap.NewNop(), 2*time.Second) + + var flushed int32 + flush := func(ctx context.Context, payload []byte) { + atomic.AddInt32(&flushed, 1) + } + + // Add an event to create the buffer entry, then drain via size flush. + a.Buffer(BufferRequest{ + Namespace: "ns", FunctionID: "fn", TriggerID: "tr", + WindowMs: 30_000, MaxBatchSize: 1, + FlushFn: flush, Event: Event{Topic: "t"}, + }) + + // Wait for the size-triggered flush to drain. + deadline := time.Now().Add(2 * time.Second) + for atomic.LoadInt32(&flushed) < 1 { + if time.Now().After(deadline) { + t.Fatal("size flush didn't fire") + } + time.Sleep(5 * time.Millisecond) + } + + // Now the buffer is empty. Shutdown should not flush again. + a.Shutdown(2*time.Second) + time.Sleep(200 * time.Millisecond) + if atomic.LoadInt32(&flushed) != 1 { + t.Errorf("Shutdown flushed an empty buffer: total flushes %d", flushed) + } +} + +func TestStats_reports_buffered_state(t *testing.T) { + a := New(zap.NewNop(), 2*time.Second) + flush := func(ctx context.Context, payload []byte) {} + + a.Buffer(BufferRequest{Namespace: "ns", FunctionID: "fn", TriggerID: "a", WindowMs: 30_000, FlushFn: flush, Event: Event{Topic: "t"}}) + a.Buffer(BufferRequest{Namespace: "ns", FunctionID: "fn", TriggerID: "a", WindowMs: 30_000, FlushFn: flush, Event: Event{Topic: "t"}}) + a.Buffer(BufferRequest{Namespace: "ns", FunctionID: "fn", TriggerID: "b", WindowMs: 30_000, FlushFn: flush, Event: Event{Topic: "t"}}) + + bufs, evs := a.Stats() + if bufs != 2 { + t.Errorf("expected 2 buffers, got %d", bufs) + } + if evs != 3 { + t.Errorf("expected 3 buffered events, got %d", evs) + } +} + +func TestBuffer_concurrent_writes_no_race(t *testing.T) { + // Run with -race: this should not detect any data races. + a := New(zap.NewNop(), 2*time.Second) + flush := func(ctx context.Context, payload []byte) {} + + var wg sync.WaitGroup + for g := 0; g < 8; g++ { + wg.Add(1) + go func(gid int) { + defer wg.Done() + for i := 0; i < 50; i++ { + a.Buffer(BufferRequest{ + Namespace: "ns", + FunctionID: "fn", + TriggerID: "tr", + WindowMs: 200, + FlushFn: flush, + Event: Event{Topic: "t"}, + }) + } + }(g) + } + wg.Wait() + // Drain. + a.Shutdown(2*time.Second) +} + +func TestBuffer_payload_includes_batched_true_and_topic(t *testing.T) { + a := New(zap.NewNop(), 2*time.Second) + + got := make(chan BatchedPayload, 1) + flush := func(ctx context.Context, payload []byte) { + var p BatchedPayload + if err := json.Unmarshal(payload, &p); err != nil { + t.Errorf("unmarshal: %v", err) + } + got <- p + } + + a.Buffer(BufferRequest{ + Namespace: "ns", FunctionID: "fn", TriggerID: "tr", + WindowMs: 50, FlushFn: flush, + Event: Event{Topic: "presence:user-1", Data: json.RawMessage(`{"x":1}`)}, + }) + + select { + case p := <-got: + if !p.Batched { + t.Error("payload should have Batched=true") + } + if len(p.Events) != 1 || p.Events[0].Topic != "presence:user-1" { + t.Errorf("unexpected events: %+v", p.Events) + } + case <-time.After(2 * time.Second): + t.Fatal("flush did not fire") + } +} diff --git a/core/pkg/serverless/engine.go b/core/pkg/serverless/engine.go index aeddc8c..2a9c3e4 100644 --- a/core/pkg/serverless/engine.go +++ b/core/pkg/serverless/engine.go @@ -328,6 +328,8 @@ func (e *Engine) registerHostModule(ctx context.Context) error { NewFunctionBuilder().WithFunc(e.hCacheIncrBy).Export("cache_incr_by"). NewFunctionBuilder().WithFunc(e.hHTTPFetch).Export("http_fetch"). NewFunctionBuilder().WithFunc(e.hPubSubPublish).Export("pubsub_publish"). + NewFunctionBuilder().WithFunc(e.hPubSubPublishBatch).Export("pubsub_publish_batch"). + NewFunctionBuilder().WithFunc(e.hPushSend).Export("push_send"). NewFunctionBuilder().WithFunc(e.hLogInfo).Export("log_info"). NewFunctionBuilder().WithFunc(e.hLogError).Export("log_error"). Instantiate(ctx) @@ -517,6 +519,46 @@ func (e *Engine) hPubSubPublish(ctx context.Context, mod api.Module, topicPtr, t return 1 // Success } +// hPubSubPublishBatch is the WASM-callable wrapper for PubSubPublishBatch. +// Input: pointer/length of a JSON array of {topic, data_base64}. +// Returns 1 on success, 0 on error. +func (e *Engine) hPubSubPublishBatch(ctx context.Context, mod api.Module, msgsPtr, msgsLen uint32) uint32 { + msgsJSON, ok := e.executor.ReadFromGuest(mod, msgsPtr, msgsLen) + if !ok { + return 0 + } + if err := e.hostServices.PubSubPublishBatch(ctx, msgsJSON); err != nil { + e.logger.Error("host function pubsub_publish_batch failed", zap.Error(err)) + return 0 + } + return 1 +} + +// hPushSend is the WASM-callable wrapper for PushSend. +// Inputs: +// userIDPtr/userIDLen — UTF-8 user ID to push to (within the function's +// own namespace; the namespace is server-side trusted) +// msgPtr/msgLen — JSON payload matching hostfunctions.PushSendArgs +// Returns 1 on success, 0 on error. +func (e *Engine) hPushSend(ctx context.Context, mod api.Module, + userIDPtr, userIDLen, msgPtr, msgLen uint32) uint32 { + userID, ok := e.executor.ReadFromGuest(mod, userIDPtr, userIDLen) + if !ok { + return 0 + } + msgJSON, ok := e.executor.ReadFromGuest(mod, msgPtr, msgLen) + if !ok { + return 0 + } + if err := e.hostServices.PushSend(ctx, string(userID), msgJSON); err != nil { + e.logger.Error("host function push_send failed", + zap.String("user_id", string(userID)), + zap.Error(err)) + return 0 + } + return 1 +} + func (e *Engine) hLogInfo(ctx context.Context, mod api.Module, ptr, size uint32) { msg, ok := e.executor.ReadFromGuest(mod, ptr, size) if ok { diff --git a/core/pkg/serverless/hostfuncs_test.go b/core/pkg/serverless/hostfuncs_test.go index cbdbb32..93e363f 100644 --- a/core/pkg/serverless/hostfuncs_test.go +++ b/core/pkg/serverless/hostfuncs_test.go @@ -90,6 +90,14 @@ func (m *mockHostServices) PubSubPublish(ctx context.Context, topic string, data return nil } +func (m *mockHostServices) PubSubPublishBatch(ctx context.Context, msgsJSON []byte) error { + return nil +} + +func (m *mockHostServices) PushSend(ctx context.Context, userID string, msgJSON []byte) error { + return nil +} + func (m *mockHostServices) WSSend(ctx context.Context, clientID string, data []byte) error { return nil } diff --git a/core/pkg/serverless/hostfunctions/host_services.go b/core/pkg/serverless/hostfunctions/host_services.go index 64f6878..069adcf 100644 --- a/core/pkg/serverless/hostfunctions/host_services.go +++ b/core/pkg/serverless/hostfunctions/host_services.go @@ -5,6 +5,7 @@ import ( "github.com/DeBrosOfficial/network/pkg/ipfs" "github.com/DeBrosOfficial/network/pkg/pubsub" + "github.com/DeBrosOfficial/network/pkg/push" "github.com/DeBrosOfficial/network/pkg/rqlite" "github.com/DeBrosOfficial/network/pkg/serverless" "github.com/DeBrosOfficial/network/pkg/tlsutil" @@ -13,6 +14,10 @@ import ( ) // NewHostFunctions creates a new HostFunctions instance. +// +// pushDispatcher may be nil when push isn't configured on this gateway — +// in that case PushSend hostfunc returns nil (silent no-op) so functions +// remain portable across deployments with/without push. func NewHostFunctions( db rqlite.Client, cacheClient olriclib.Client, @@ -20,6 +25,7 @@ func NewHostFunctions( pubsubAdapter *pubsub.ClientAdapter, wsManager serverless.WebSocketManager, secrets serverless.SecretsManager, + pushDispatcher *push.PushDispatcher, cfg HostFunctionsConfig, logger *zap.Logger, ) *HostFunctions { @@ -29,15 +35,16 @@ func NewHostFunctions( } return &HostFunctions{ - db: db, - cacheClient: cacheClient, - storage: storage, - ipfsAPIURL: cfg.IPFSAPIURL, - pubsub: pubsubAdapter, - wsManager: wsManager, - secrets: secrets, - httpClient: tlsutil.NewHTTPClient(httpTimeout), - logger: logger, - logs: make([]serverless.LogEntry, 0), + db: db, + cacheClient: cacheClient, + storage: storage, + ipfsAPIURL: cfg.IPFSAPIURL, + pubsub: pubsubAdapter, + wsManager: wsManager, + secrets: secrets, + pushDispatcher: pushDispatcher, + httpClient: tlsutil.NewHTTPClient(httpTimeout), + logger: logger, + logs: make([]serverless.LogEntry, 0), } } diff --git a/core/pkg/serverless/hostfunctions/pubsub.go b/core/pkg/serverless/hostfunctions/pubsub.go index 82394c1..7e9b570 100644 --- a/core/pkg/serverless/hostfunctions/pubsub.go +++ b/core/pkg/serverless/hostfunctions/pubsub.go @@ -2,8 +2,11 @@ package hostfunctions import ( "context" + "encoding/base64" + "encoding/json" "fmt" + "github.com/DeBrosOfficial/network/pkg/pubsub" "github.com/DeBrosOfficial/network/pkg/serverless" ) @@ -21,6 +24,64 @@ func (h *HostFunctions) PubSubPublish(ctx context.Context, topic string, data [] return nil } +// pubSubBatchEntry mirrors the JSON shape accepted by PubSubPublishBatch. +type pubSubBatchEntry struct { + Topic string `json:"topic"` + DataB64 string `json:"data_base64"` +} + +// PubSubPublishBatch publishes multiple messages in parallel. +// +// Input is JSON: [{"topic":"...","data_base64":"..."}, ...] +// Up to pubsub.MaxBatchSize entries per call. +// +// Default behavior is fail-fast (first publish error is returned). The +// host function does not currently expose a best-effort flag — WASM +// callers that need it should call this function multiple times in +// chunks they're willing to retry independently. +func (h *HostFunctions) PubSubPublishBatch(ctx context.Context, msgsJSON []byte) error { + if h.pubsub == nil { + return &serverless.HostFunctionError{Function: "pubsub_publish_batch", Cause: fmt.Errorf("pubsub not available")} + } + + var entries []pubSubBatchEntry + if err := json.Unmarshal(msgsJSON, &entries); err != nil { + return &serverless.HostFunctionError{Function: "pubsub_publish_batch", Cause: fmt.Errorf("invalid json: %w", err)} + } + if len(entries) == 0 { + return nil + } + if len(entries) > pubsub.MaxBatchSize { + return &serverless.HostFunctionError{ + Function: "pubsub_publish_batch", + Cause: fmt.Errorf("too many messages: max %d per batch", pubsub.MaxBatchSize), + } + } + + msgs := make([]pubsub.TopicMessage, 0, len(entries)) + for i, e := range entries { + if e.Topic == "" { + return &serverless.HostFunctionError{ + Function: "pubsub_publish_batch", + Cause: fmt.Errorf("entry %d: empty topic", i), + } + } + data, err := base64.StdEncoding.DecodeString(e.DataB64) + if err != nil { + return &serverless.HostFunctionError{ + Function: "pubsub_publish_batch", + Cause: fmt.Errorf("entry %d (topic %q): bad base64: %w", i, e.Topic, err), + } + } + msgs = append(msgs, pubsub.TopicMessage{Topic: e.Topic, Data: data}) + } + + if err := h.pubsub.PublishBatch(ctx, msgs, pubsub.PublishBatchOptions{}); err != nil { + return &serverless.HostFunctionError{Function: "pubsub_publish_batch", Cause: err} + } + return nil +} + // WSSend sends data to a specific WebSocket client. func (h *HostFunctions) WSSend(ctx context.Context, clientID string, data []byte) error { if h.wsManager == nil { diff --git a/core/pkg/serverless/hostfunctions/push.go b/core/pkg/serverless/hostfunctions/push.go new file mode 100644 index 0000000..dfd2638 --- /dev/null +++ b/core/pkg/serverless/hostfunctions/push.go @@ -0,0 +1,103 @@ +package hostfunctions + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/DeBrosOfficial/network/pkg/push" + "github.com/DeBrosOfficial/network/pkg/serverless" +) + +// PushSendArgs is the JSON payload format the WASM caller marshals into +// the `msgJSON` argument of PushSend. Mirrors push.PushMessage minus the +// device-token (which is filled in per-device by the dispatcher). +type PushSendArgs struct { + Title string `json:"title,omitempty"` + Body string `json:"body,omitempty"` + Channel string `json:"channel,omitempty"` + Priority string `json:"priority,omitempty"` // "high" | "normal" | "" + Badge int `json:"badge,omitempty"` + Sound string `json:"sound,omitempty"` + Data map[string]interface{} `json:"data,omitempty"` +} + +// MaxPushSendArgsBytes caps the JSON arg size to a few KB. Push payloads +// are small by construction (APNs caps at 4KB, ntfy/Expo similar). +const MaxPushSendArgsBytes = 16 * 1024 + +// PushSend implements serverless.HostServices.PushSend. +// +// Sends a push notification to all devices the user has registered in the +// function's namespace. The caller can only target users in their own +// namespace — the dispatcher reads the namespace from the invocation +// context (set by the engine before invoking). +// +// If push is not configured on this gateway (no dispatcher), this returns +// nil (silent no-op) so functions remain portable across environments. +func (h *HostFunctions) PushSend(ctx context.Context, userID string, msgJSON []byte) error { + if h.pushDispatcher == nil { + // Silent no-op — push isn't configured on this gateway. + return nil + } + if userID == "" { + return &serverless.HostFunctionError{ + Function: "push_send", + Cause: fmt.Errorf("user_id required"), + } + } + if len(msgJSON) > MaxPushSendArgsBytes { + return &serverless.HostFunctionError{ + Function: "push_send", + Cause: fmt.Errorf("msg too large: max %d bytes", MaxPushSendArgsBytes), + } + } + + var args PushSendArgs + if err := json.Unmarshal(msgJSON, &args); err != nil { + return &serverless.HostFunctionError{ + Function: "push_send", + Cause: fmt.Errorf("invalid json: %w", err), + } + } + + // Resolve namespace from the current invocation context. A function + // can NEVER push to another namespace's users — the namespace is + // trusted server-side, not from the WASM input. + h.invCtxLock.RLock() + var namespace string + if h.invCtx != nil { + namespace = h.invCtx.Namespace + } + h.invCtxLock.RUnlock() + + if namespace == "" { + return &serverless.HostFunctionError{ + Function: "push_send", + Cause: fmt.Errorf("no namespace in invocation context"), + } + } + + priority := push.PriorityNormal + switch args.Priority { + case "high": + priority = push.PriorityHigh + case "normal", "": + priority = push.PriorityNormal + } + + msg := push.PushMessage{ + Title: args.Title, + Body: args.Body, + Channel: args.Channel, + Priority: priority, + Badge: args.Badge, + Sound: args.Sound, + Data: args.Data, + } + + if err := h.pushDispatcher.SendToUser(ctx, namespace, userID, msg); err != nil { + return &serverless.HostFunctionError{Function: "push_send", Cause: err} + } + return nil +} diff --git a/core/pkg/serverless/hostfunctions/types.go b/core/pkg/serverless/hostfunctions/types.go index 3df7406..28e1aea 100644 --- a/core/pkg/serverless/hostfunctions/types.go +++ b/core/pkg/serverless/hostfunctions/types.go @@ -7,6 +7,7 @@ import ( "github.com/DeBrosOfficial/network/pkg/ipfs" "github.com/DeBrosOfficial/network/pkg/pubsub" + "github.com/DeBrosOfficial/network/pkg/push" "github.com/DeBrosOfficial/network/pkg/rqlite" "github.com/DeBrosOfficial/network/pkg/serverless" olriclib "github.com/olric-data/olric" @@ -32,6 +33,10 @@ type HostFunctions struct { httpClient *http.Client logger *zap.Logger + // pushDispatcher may be nil when push isn't configured for this gateway. + // In that case PushSend returns nil silently — see hostfunctions/push.go. + pushDispatcher *push.PushDispatcher + // Current invocation context (set per-execution) invCtx *serverless.InvocationContext invCtxLock sync.RWMutex diff --git a/core/pkg/serverless/mocks_test.go b/core/pkg/serverless/mocks_test.go index 2146358..94fffba 100644 --- a/core/pkg/serverless/mocks_test.go +++ b/core/pkg/serverless/mocks_test.go @@ -176,6 +176,14 @@ func (m *MockHostServices) PubSubPublish(ctx context.Context, topic string, data return nil } +func (m *MockHostServices) PubSubPublishBatch(ctx context.Context, msgsJSON []byte) error { + return nil +} + +func (m *MockHostServices) PushSend(ctx context.Context, userID string, msgJSON []byte) error { + return nil +} + func (m *MockHostServices) WSSend(ctx context.Context, clientID string, data []byte) error { return nil } diff --git a/core/pkg/serverless/triggers/dispatcher.go b/core/pkg/serverless/triggers/dispatcher.go index 94e5d55..d004003 100644 --- a/core/pkg/serverless/triggers/dispatcher.go +++ b/core/pkg/serverless/triggers/dispatcher.go @@ -6,14 +6,12 @@ import ( "time" "github.com/DeBrosOfficial/network/pkg/serverless" + "github.com/DeBrosOfficial/network/pkg/serverless/aggregator" olriclib "github.com/olric-data/olric" "go.uber.org/zap" ) const ( - // triggerCacheDMap is the Olric DMap name for caching trigger lookups. - triggerCacheDMap = "pubsub_triggers" - // maxTriggerDepth prevents infinite loops when triggered functions publish // back to the same topic via the HTTP API. maxTriggerDepth = 5 @@ -37,6 +35,7 @@ type PubSubDispatcher struct { store *PubSubTriggerStore invoker *serverless.Invoker olricClient olriclib.Client // may be nil (cache disabled) + aggregator *aggregator.Aggregator logger *zap.Logger } @@ -51,10 +50,17 @@ func NewPubSubDispatcher( store: store, invoker: invoker, olricClient: olricClient, + aggregator: aggregator.New(logger, dispatchTimeout), logger: logger, } } +// Aggregator exposes the underlying aggregator so callers (gateway lifecycle) +// can flush pending buffers on shutdown. +func (d *PubSubDispatcher) Aggregator() *aggregator.Aggregator { + return d.aggregator +} + // Dispatch looks up all triggers registered for the given topic+namespace and // invokes matching functions asynchronously. Each invocation runs in its own // goroutine and does not block the caller. @@ -82,7 +88,7 @@ func (d *PubSubDispatcher) Dispatch(ctx context.Context, namespace, topic string return } - // Build the event payload once for all invocations + // Build the per-event payload once for non-aggregating dispatches. event := PubSubEvent{ Topic: topic, Data: json.RawMessage(data), @@ -90,11 +96,6 @@ func (d *PubSubDispatcher) Dispatch(ctx context.Context, namespace, topic string TriggerDepth: depth + 1, Timestamp: time.Now().Unix(), } - eventJSON, err := json.Marshal(event) - if err != nil { - d.logger.Error("Failed to marshal PubSub event", zap.Error(err)) - return - } d.logger.Debug("Dispatching PubSub triggers", zap.String("namespace", namespace), @@ -103,94 +104,81 @@ func (d *PubSubDispatcher) Dispatch(ctx context.Context, namespace, topic string zap.Int("depth", depth), ) + var ( + eventJSON []byte + marshalErr error + ) + for _, match := range matches { + if match.AggregationWindowMs > 0 { + d.bufferEvent(match, event) + continue + } + // Lazily marshal — non-aggregating triggers need eventJSON. + if eventJSON == nil && marshalErr == nil { + eventJSON, marshalErr = json.Marshal(event) + if marshalErr != nil { + d.logger.Error("Failed to marshal PubSub event", zap.Error(marshalErr)) + continue + } + } + if marshalErr != nil { + continue + } go d.invokeFunction(match, eventJSON) } } -// InvalidateCache removes the cached trigger lookup for a namespace+topic. -// Call this when triggers are added or removed. -func (d *PubSubDispatcher) InvalidateCache(ctx context.Context, namespace, topic string) { - if d.olricClient == nil { - return - } - - dm, err := d.olricClient.NewDMap(triggerCacheDMap) - if err != nil { - d.logger.Debug("Failed to get trigger cache DMap for invalidation", zap.Error(err)) - return - } - - key := cacheKey(namespace, topic) - if _, err := dm.Delete(ctx, key); err != nil { - d.logger.Debug("Failed to invalidate trigger cache", zap.String("key", key), zap.Error(err)) - } +// bufferEvent routes an event through the aggregator. The flush callback +// invokes the function with the batched payload. +func (d *PubSubDispatcher) bufferEvent(match TriggerMatch, event PubSubEvent) { + d.aggregator.Buffer(aggregator.BufferRequest{ + Namespace: match.Namespace, + FunctionID: match.FunctionID, + TriggerID: match.TriggerID, + WindowMs: match.AggregationWindowMs, + MaxBatchSize: match.AggregationMaxBatchSize, + Event: aggregator.Event{ + Topic: event.Topic, + Data: event.Data, + Namespace: event.Namespace, + TriggerDepth: event.TriggerDepth, + Timestamp: event.Timestamp, + }, + FlushFn: func(ctx context.Context, payload []byte) { + req := &serverless.InvokeRequest{ + Namespace: match.Namespace, + FunctionName: match.FunctionName, + Input: payload, + TriggerType: serverless.TriggerTypePubSub, + } + if _, err := d.invoker.Invoke(ctx, req); err != nil { + d.logger.Warn("Aggregated PubSub invocation failed", + zap.String("function", match.FunctionName), + zap.String("trigger_id", match.TriggerID), + zap.Error(err), + ) + } + }, + }) } -// getMatches returns the trigger matches for a topic+namespace, using Olric cache when available. +// InvalidateCache is now a no-op — the dispatcher no longer caches lookups. +// Kept on the type so callers who used it still compile. +func (d *PubSubDispatcher) InvalidateCache(ctx context.Context, namespace, topic string) {} + +// getMatches returns the trigger matches for a topic+namespace. +// +// Caching note: an earlier revision cached results in Olric keyed by +// (namespace, topic). With wildcard triggers the cache becomes +// inconsistent — a single trigger Add/Remove invalidates an unbounded +// number of resolved-topic keys. The cache was removed; re-introducing +// it requires a generation-counter (or TTL) scheme that handles +// wildcard pattern changes. func (d *PubSubDispatcher) getMatches(ctx context.Context, namespace, topic string) ([]TriggerMatch, error) { - // Try cache first - if d.olricClient != nil { - if matches, ok := d.getCached(ctx, namespace, topic); ok { - return matches, nil - } - } - - // Cache miss — query database - matches, err := d.store.GetByTopicAndNamespace(ctx, topic, namespace) - if err != nil { - return nil, err - } - - // Populate cache - if d.olricClient != nil && matches != nil { - d.setCache(ctx, namespace, topic, matches) - } - - return matches, nil + return d.store.GetByTopicAndNamespace(ctx, topic, namespace) } -// getCached attempts to retrieve trigger matches from Olric cache. -func (d *PubSubDispatcher) getCached(ctx context.Context, namespace, topic string) ([]TriggerMatch, bool) { - dm, err := d.olricClient.NewDMap(triggerCacheDMap) - if err != nil { - return nil, false - } - - key := cacheKey(namespace, topic) - result, err := dm.Get(ctx, key) - if err != nil { - return nil, false - } - - data, err := result.Byte() - if err != nil { - return nil, false - } - - var matches []TriggerMatch - if err := json.Unmarshal(data, &matches); err != nil { - return nil, false - } - - return matches, true -} - -// setCache stores trigger matches in Olric cache. -func (d *PubSubDispatcher) setCache(ctx context.Context, namespace, topic string, matches []TriggerMatch) { - dm, err := d.olricClient.NewDMap(triggerCacheDMap) - if err != nil { - return - } - - data, err := json.Marshal(matches) - if err != nil { - return - } - - key := cacheKey(namespace, topic) - _ = dm.Put(ctx, key, data) -} // invokeFunction invokes a single function for a trigger match. func (d *PubSubDispatcher) invokeFunction(match TriggerMatch, eventJSON []byte) { @@ -224,7 +212,3 @@ func (d *PubSubDispatcher) invokeFunction(match TriggerMatch, eventJSON []byte) ) } -// cacheKey returns the Olric cache key for a namespace+topic pair. -func cacheKey(namespace, topic string) string { - return "triggers:" + namespace + ":" + topic -} diff --git a/core/pkg/serverless/triggers/pattern.go b/core/pkg/serverless/triggers/pattern.go new file mode 100644 index 0000000..79f2fae --- /dev/null +++ b/core/pkg/serverless/triggers/pattern.go @@ -0,0 +1,138 @@ +package triggers + +import ( + "fmt" + "strings" +) + +// MaxPatternLength is the maximum allowed glob pattern length. +const MaxPatternLength = 256 + +// ValidatePattern checks that a glob pattern is well-formed. +// Empty patterns and unbalanced character classes return an error. +// Patterns longer than MaxPatternLength are rejected to keep DB scans bounded. +func ValidatePattern(p string) error { + if p == "" { + return fmt.Errorf("empty pattern") + } + if len(p) > MaxPatternLength { + return fmt.Errorf("pattern too long: %d > %d", len(p), MaxPatternLength) + } + open := 0 + for i, c := range p { + switch c { + case '[': + open++ + case ']': + open-- + if open < 0 { + return fmt.Errorf("unmatched ']' at position %d", i) + } + } + } + if open != 0 { + return fmt.Errorf("unmatched '['") + } + return nil +} + +// IsWildcard reports whether the pattern contains any glob metacharacter. +// Useful for choosing between exact-match cache keys and wildcard scans. +func IsWildcard(p string) bool { + return strings.ContainsAny(p, "*?[") +} + +// PatternMatches returns true when topic matches pattern under Orama's glob +// semantics: +// - '*' matches zero or more characters EXCEPT ':' +// - '**' matches zero or more characters INCLUDING ':' (deep wildcard) +// - '?' matches exactly one character (any) +// - '[abc]' / '[!abc]' character classes +// +// SQLite's GLOB is the first-pass filter (in pubsub_store.go); this +// post-filter enforces segment boundaries for single-'*' patterns since +// SQLite GLOB treats '*' as "any chars including separators". +func PatternMatches(pattern, topic string) bool { + if strings.Contains(pattern, "**") { + // Deep wildcards already accept across segment boundaries — SQLite GLOB + // already accepted this row. No further filtering needed. + return true + } + return strictGlobMatch(pattern, topic) +} + +// strictGlobMatch implements glob matching where '*' does NOT cross ':'. +// Recursive backtracking matcher; bounded length keeps it cheap. +func strictGlobMatch(pattern, s string) bool { + pi, si := 0, 0 + starPi, starSi := -1, -1 + + for si < len(s) { + if pi < len(pattern) { + pc := pattern[pi] + switch pc { + case '?': + pi++ + si++ + continue + case '*': + // Remember position so we can backtrack. + starPi = pi + starSi = si + pi++ + continue + case '[': + end := strings.IndexByte(pattern[pi+1:], ']') + if end < 0 { + return false + } + class := pattern[pi+1 : pi+1+end] + if matchClass(class, s[si]) { + pi += end + 2 + si++ + continue + } + default: + if pc == s[si] { + pi++ + si++ + continue + } + } + } + // No match at this position — try to extend the last '*' if any, + // but '*' must not cross a ':' segment separator. + if starPi >= 0 && s[starSi] != ':' { + starSi++ + pi = starPi + 1 + si = starSi + continue + } + return false + } + + // Consume any trailing '*' in the pattern. + for pi < len(pattern) && pattern[pi] == '*' { + pi++ + } + return pi == len(pattern) +} + +// matchClass reports whether c matches the SQLite-style character class body +// (between '[' and ']'). Supports negation with leading '!'. +func matchClass(class string, c byte) bool { + if class == "" { + return false + } + negate := false + if class[0] == '!' { + negate = true + class = class[1:] + } + for i := 0; i < len(class); i++ { + if class[i] == c { + return !negate + } + } + return negate +} diff --git a/core/pkg/serverless/triggers/pattern_test.go b/core/pkg/serverless/triggers/pattern_test.go new file mode 100644 index 0000000..84e25c8 --- /dev/null +++ b/core/pkg/serverless/triggers/pattern_test.go @@ -0,0 +1,162 @@ +package triggers + +import ( + "strings" + "testing" +) + +func TestValidatePattern_empty_returns_error(t *testing.T) { + if err := ValidatePattern(""); err == nil { + t.Error("expected error for empty pattern") + } +} + +func TestValidatePattern_too_long_returns_error(t *testing.T) { + long := strings.Repeat("a", MaxPatternLength+1) + if err := ValidatePattern(long); err == nil { + t.Error("expected error for over-long pattern") + } +} + +func TestValidatePattern_unbalanced_brackets_returns_error(t *testing.T) { + cases := []string{"a[b", "a]b", "[a[b]", "a]"} + for _, c := range cases { + if err := ValidatePattern(c); err == nil { + t.Errorf("expected error for %q", c) + } + } +} + +func TestValidatePattern_valid_patterns_no_error(t *testing.T) { + cases := []string{"foo", "foo:*", "foo:**", "*.bar", "[abc]xyz", "[!a]b", "?abc"} + for _, c := range cases { + if err := ValidatePattern(c); err != nil { + t.Errorf("expected no error for %q, got: %v", c, err) + } + } +} + +func TestIsWildcard(t *testing.T) { + cases := map[string]bool{ + "foo": false, + "foo:bar": false, + "foo:*": true, + "foo?bar": true, + "[abc]xyz": true, + "foo:**": true, + "a:b:c:d:e:f": false, + } + for in, want := range cases { + if got := IsWildcard(in); got != want { + t.Errorf("IsWildcard(%q) = %v, want %v", in, got, want) + } + } +} + +func TestPatternMatches_exact(t *testing.T) { + cases := []struct { + pattern, topic string + want bool + }{ + {"foo", "foo", true}, + {"foo", "bar", false}, + {"foo:bar", "foo:bar", true}, + {"foo:bar", "foo:baz", false}, + } + for _, c := range cases { + if got := PatternMatches(c.pattern, c.topic); got != c.want { + t.Errorf("PatternMatches(%q, %q) = %v, want %v", c.pattern, c.topic, got, c.want) + } + } +} + +func TestPatternMatches_single_star_segment_bounded(t *testing.T) { + cases := []struct { + pattern, topic string + want bool + }{ + // '*' matches within a single segment + {"presence:*", "presence:user-1", true}, + {"presence:*", "presence:user-2", true}, + {"presence:*", "presence:", true}, + // '*' does NOT cross ':' + {"presence:*", "presence:user:device", false}, + {"a:*:b", "a:x:b", true}, + {"a:*:b", "a:x:y:b", false}, + // Different prefix + {"presence:*", "calls:invite", false}, + } + for _, c := range cases { + if got := PatternMatches(c.pattern, c.topic); got != c.want { + t.Errorf("PatternMatches(%q, %q) = %v, want %v", c.pattern, c.topic, got, c.want) + } + } +} + +func TestPatternMatches_double_star_crosses_segments(t *testing.T) { + cases := []struct { + pattern, topic string + want bool + }{ + {"notify:**", "notify:user-1", true}, + {"notify:**", "notify:user:device:1", true}, + {"**", "anything:goes:here", true}, + } + for _, c := range cases { + if got := PatternMatches(c.pattern, c.topic); got != c.want { + t.Errorf("PatternMatches(%q, %q) = %v, want %v", c.pattern, c.topic, got, c.want) + } + } +} + +func TestPatternMatches_question_mark(t *testing.T) { + cases := []struct { + pattern, topic string + want bool + }{ + {"a?c", "abc", true}, + {"a?c", "axc", true}, + {"a?c", "ac", false}, + {"a?c", "abbc", false}, + } + for _, c := range cases { + if got := PatternMatches(c.pattern, c.topic); got != c.want { + t.Errorf("PatternMatches(%q, %q) = %v, want %v", c.pattern, c.topic, got, c.want) + } + } +} + +func TestPatternMatches_character_class(t *testing.T) { + cases := []struct { + pattern, topic string + want bool + }{ + {"[abc]xyz", "axyz", true}, + {"[abc]xyz", "bxyz", true}, + {"[abc]xyz", "dxyz", false}, + {"[!a]bc", "xbc", true}, + {"[!a]bc", "abc", false}, + } + for _, c := range cases { + if got := PatternMatches(c.pattern, c.topic); got != c.want { + t.Errorf("PatternMatches(%q, %q) = %v, want %v", c.pattern, c.topic, got, c.want) + } + } +} + +func TestPatternMatches_trailing_star_with_remaining_chars(t *testing.T) { + // '*' can match zero characters at end. + cases := []struct { + pattern, topic string + want bool + }{ + {"foo*", "foo", true}, + {"foo*", "foobar", true}, + {"foo*", "foobar:baz", false}, // ':' breaks single '*' + } + for _, c := range cases { + if got := PatternMatches(c.pattern, c.topic); got != c.want { + t.Errorf("PatternMatches(%q, %q) = %v, want %v", c.pattern, c.topic, got, c.want) + } + } +} diff --git a/core/pkg/serverless/triggers/pubsub_store.go b/core/pkg/serverless/triggers/pubsub_store.go index 7ee14fb..6125339 100644 --- a/core/pkg/serverless/triggers/pubsub_store.go +++ b/core/pkg/serverless/triggers/pubsub_store.go @@ -16,30 +16,43 @@ import ( // TriggerMatch contains the fields needed to dispatch a trigger invocation. // It's the result of JOINing function_pubsub_triggers with functions. +// +// Topic is the *resolved* topic that the published message was sent to, +// not the pattern stored in the trigger. This lets aggregating functions +// see which concrete topic each event came from. +// +// AggregationWindowMs > 0 indicates the dispatcher should buffer events +// instead of invoking the function per event. type TriggerMatch struct { - TriggerID string - FunctionID string - FunctionName string - Namespace string - Topic string + TriggerID string + FunctionID string + FunctionName string + Namespace string + Topic string + AggregationWindowMs int + AggregationMaxBatchSize int } // triggerRow maps to the function_pubsub_triggers table for query scanning. type triggerRow struct { - ID string - FunctionID string - Topic string - Enabled bool - CreatedAt time.Time + ID string + FunctionID string + TopicPattern string + Enabled bool + CreatedAt time.Time + AggregationWindowMs int + AggregationMaxBatchSize int } // triggerMatchRow maps to the JOIN query result for scanning. type triggerMatchRow struct { - TriggerID string - FunctionID string - FunctionName string - Namespace string - Topic string + TriggerID string + FunctionID string + FunctionName string + Namespace string + TopicPattern string + AggregationWindowMs int + AggregationMaxBatchSize int } // PubSubTriggerStore manages PubSub trigger persistence in RQLite. @@ -57,30 +70,60 @@ func NewPubSubTriggerStore(db rqlite.Client, logger *zap.Logger) *PubSubTriggerS } // Add registers a new PubSub trigger for a function. +// `topicPattern` may be an exact topic or a SQLite GLOB pattern (e.g. "presence:*"). // Returns the trigger ID. -func (s *PubSubTriggerStore) Add(ctx context.Context, functionID, topic string) (string, error) { +// +// For backward compatibility, aggregation defaults to disabled (windowMs=0). +// Use AddWithAggregation to opt in. +func (s *PubSubTriggerStore) Add(ctx context.Context, functionID, topicPattern string) (string, error) { + return s.AddWithAggregation(ctx, functionID, topicPattern, 0, 0) +} + +// AddWithAggregation registers a trigger with optional aggregation. +// - aggregationWindowMs = 0 disables aggregation (per-event invocation, default). +// - aggregationMaxBatchSize = 0 uses the default (100) when aggregation is enabled. +func (s *PubSubTriggerStore) AddWithAggregation( + ctx context.Context, + functionID, topicPattern string, + aggregationWindowMs, aggregationMaxBatchSize int, +) (string, error) { if functionID == "" { return "", fmt.Errorf("function ID required") } - if topic == "" { - return "", fmt.Errorf("topic required") + if err := ValidatePattern(topicPattern); err != nil { + return "", fmt.Errorf("invalid topic pattern: %w", err) + } + if aggregationWindowMs < 0 || aggregationWindowMs > 60_000 { + return "", fmt.Errorf("aggregation_window_ms must be between 0 and 60000") + } + if aggregationMaxBatchSize < 0 || aggregationMaxBatchSize > 1000 { + return "", fmt.Errorf("aggregation_max_batch_size must be between 0 and 1000") + } + if aggregationWindowMs > 0 && aggregationMaxBatchSize == 0 { + aggregationMaxBatchSize = 100 } id := uuid.New().String() now := time.Now() + // Write both `topic` (legacy) and `topic_pattern` (new). Keeping `topic` + // populated lets old binaries running concurrently during a rolling + // upgrade continue reading triggers. A future migration drops `topic`. query := ` - INSERT INTO function_pubsub_triggers (id, function_id, topic, enabled, created_at) - VALUES (?, ?, ?, TRUE, ?) + INSERT INTO function_pubsub_triggers (id, function_id, topic, topic_pattern, enabled, created_at, aggregation_window_ms, aggregation_max_batch_size) + VALUES (?, ?, ?, ?, TRUE, ?, ?, ?) ` - if _, err := s.db.Exec(ctx, query, id, functionID, topic, now); err != nil { + if _, err := s.db.Exec(ctx, query, id, functionID, topicPattern, topicPattern, now, aggregationWindowMs, aggregationMaxBatchSize); err != nil { return "", fmt.Errorf("failed to add pubsub trigger: %w", err) } s.logger.Info("PubSub trigger added", zap.String("trigger_id", id), zap.String("function_id", functionID), - zap.String("topic", topic), + zap.String("topic_pattern", topicPattern), + zap.Bool("wildcard", IsWildcard(topicPattern)), + zap.Int("aggregation_window_ms", aggregationWindowMs), + zap.Int("aggregation_max_batch_size", aggregationMaxBatchSize), ) return id, nil @@ -129,7 +172,7 @@ func (s *PubSubTriggerStore) ListByFunction(ctx context.Context, functionID stri } query := ` - SELECT id, function_id, topic, enabled, created_at + SELECT id, function_id, topic_pattern, enabled, created_at, aggregation_window_ms, aggregation_max_batch_size FROM function_pubsub_triggers WHERE function_id = ? ` @@ -142,18 +185,22 @@ func (s *PubSubTriggerStore) ListByFunction(ctx context.Context, functionID stri triggers := make([]serverless.PubSubTrigger, len(rows)) for i, row := range rows { triggers[i] = serverless.PubSubTrigger{ - ID: row.ID, - FunctionID: row.FunctionID, - Topic: row.Topic, - Enabled: row.Enabled, + ID: row.ID, + FunctionID: row.FunctionID, + Topic: row.TopicPattern, + Enabled: row.Enabled, + AggregationWindowMs: row.AggregationWindowMs, + AggregationMaxBatchSize: row.AggregationMaxBatchSize, } } return triggers, nil } -// GetByTopicAndNamespace returns all enabled triggers for a topic within a namespace. -// Only returns triggers for active functions. +// GetByTopicAndNamespace returns all enabled triggers whose topic_pattern +// matches `topic` within the namespace. Patterns are SQLite GLOB; the +// post-filter enforces stricter segment-aware semantics. +// Only triggers for active functions are returned. func (s *PubSubTriggerStore) GetByTopicAndNamespace(ctx context.Context, topic, namespace string) ([]TriggerMatch, error) { if topic == "" || namespace == "" { return nil, nil @@ -161,10 +208,12 @@ func (s *PubSubTriggerStore) GetByTopicAndNamespace(ctx context.Context, topic, query := ` SELECT t.id AS trigger_id, t.function_id AS function_id, - f.name AS function_name, f.namespace AS namespace, t.topic AS topic + f.name AS function_name, f.namespace AS namespace, t.topic_pattern AS topic_pattern, + t.aggregation_window_ms AS aggregation_window_ms, + t.aggregation_max_batch_size AS aggregation_max_batch_size FROM function_pubsub_triggers t JOIN functions f ON t.function_id = f.id - WHERE t.topic = ? AND f.namespace = ? AND t.enabled = TRUE AND f.status = 'active' + WHERE ? GLOB t.topic_pattern AND f.namespace = ? AND t.enabled = TRUE AND f.status = 'active' ` var rows []triggerMatchRow @@ -172,15 +221,21 @@ func (s *PubSubTriggerStore) GetByTopicAndNamespace(ctx context.Context, topic, return nil, fmt.Errorf("failed to query triggers for topic: %w", err) } - matches := make([]TriggerMatch, len(rows)) - for i, row := range rows { - matches[i] = TriggerMatch{ - TriggerID: row.TriggerID, - FunctionID: row.FunctionID, - FunctionName: row.FunctionName, - Namespace: row.Namespace, - Topic: row.Topic, + matches := make([]TriggerMatch, 0, len(rows)) + for _, row := range rows { + // Post-filter to enforce strict segment boundaries on '*'. + if !PatternMatches(row.TopicPattern, topic) { + continue } + matches = append(matches, TriggerMatch{ + TriggerID: row.TriggerID, + FunctionID: row.FunctionID, + FunctionName: row.FunctionName, + Namespace: row.Namespace, + Topic: topic, // resolved topic, not the pattern + AggregationWindowMs: row.AggregationWindowMs, + AggregationMaxBatchSize: row.AggregationMaxBatchSize, + }) } return matches, nil diff --git a/core/pkg/serverless/triggers/triggers_test.go b/core/pkg/serverless/triggers/triggers_test.go index a9822cc..e2662f4 100644 --- a/core/pkg/serverless/triggers/triggers_test.go +++ b/core/pkg/serverless/triggers/triggers_test.go @@ -40,13 +40,6 @@ func TestDispatcher_DepthLimit(t *testing.T) { d.Dispatch(context.Background(), "ns", "topic", []byte("data"), maxTriggerDepth+1) } -func TestCacheKey(t *testing.T) { - key := cacheKey("my-namespace", "my-topic") - if key != "triggers:my-namespace:my-topic" { - t.Errorf("unexpected cache key: %s", key) - } -} - func TestPubSubEvent_Marshal(t *testing.T) { event := PubSubEvent{ Topic: "chat", diff --git a/core/pkg/serverless/types.go b/core/pkg/serverless/types.go index 66a13f7..716f297 100644 --- a/core/pkg/serverless/types.go +++ b/core/pkg/serverless/types.go @@ -288,11 +288,21 @@ type DBTrigger struct { } // PubSubTrigger represents a pubsub trigger. +// +// Topic may be an exact topic name or a SQLite GLOB pattern (e.g. +// "presence:*"). See pkg/serverless/triggers/pattern.go for matching rules. +// +// AggregationWindowMs > 0 enables event buffering: the dispatcher accumulates +// events for at most that many milliseconds (or until AggregationMaxBatchSize +// events have been collected, whichever comes first), then invokes the +// function once with a batched payload of type BatchedPubSubEvent. type PubSubTrigger struct { - ID string `json:"id"` - FunctionID string `json:"function_id"` - Topic string `json:"topic"` - Enabled bool `json:"enabled"` + ID string `json:"id"` + FunctionID string `json:"function_id"` + Topic string `json:"topic"` + Enabled bool `json:"enabled"` + AggregationWindowMs int `json:"aggregation_window_ms,omitempty"` + AggregationMaxBatchSize int `json:"aggregation_max_batch_size,omitempty"` } // Timer represents a one-time scheduled execution. @@ -337,6 +347,14 @@ type HostServices interface { // PubSub operations PubSubPublish(ctx context.Context, topic string, data []byte) error + PubSubPublishBatch(ctx context.Context, msgsJSON []byte) error + + // Push notifications. Sends to all of `userID`'s registered devices in + // the function's namespace. `msgJSON` is the JSON-encoded PushSendArgs + // shape (see hostfunctions.PushSend). Returns nil if push is not + // configured (silent no-op) so functions can be portable across + // namespaces with/without push enabled. + PushSend(ctx context.Context, userID string, msgJSON []byte) error // WebSocket operations (only valid in WS context) WSSend(ctx context.Context, clientID string, data []byte) error diff --git a/core/pkg/sniproxy/router.go b/core/pkg/sniproxy/router.go new file mode 100644 index 0000000..a2f521b --- /dev/null +++ b/core/pkg/sniproxy/router.go @@ -0,0 +1,106 @@ +package sniproxy + +import ( + "strings" + "sync" +) + +// Backend describes where to forward a connection. +type Backend struct { + // Name is for logs/metrics only. Optional. + Name string + // Network is the dial network ("tcp", "tcp4", "tcp6"). Default "tcp". + Network string + // Addr is the dial target ("127.0.0.1:5349"). + Addr string +} + +// Route maps an SNI value (or wildcard pattern) to a Backend. +// +// Match semantics: +// - "example.com" matches exactly "example.com" +// - "*.example.com" matches any single-label subdomain ("a.example.com" +// but not "a.b.example.com" — single-label like DNS wildcards) +type Route struct { + Match string + Backend Backend +} + +// Router atomically swaps a routing table while concurrent reads are in +// flight. Reads are lock-free after the slice is published. +type Router struct { + mu sync.RWMutex + routes []Route + fallback Backend +} + +// NewRouter creates a router with no routes and the given fallback. +func NewRouter(fallback Backend) *Router { + return &Router{fallback: fallback} +} + +// Pick returns the matching backend for an SNI value, or the fallback if +// no route matches (or if sni is empty). +func (r *Router) Pick(sni string) Backend { + if sni == "" { + r.mu.RLock() + defer r.mu.RUnlock() + return r.fallback + } + sni = strings.ToLower(sni) + r.mu.RLock() + defer r.mu.RUnlock() + for _, route := range r.routes { + if matchSNI(route.Match, sni) { + return route.Backend + } + } + return r.fallback +} + +// Replace atomically swaps the routing table. The new routes replace the +// old ones in their entirety; partial updates are not supported. +func (r *Router) Replace(routes []Route, fallback Backend) { + r.mu.Lock() + defer r.mu.Unlock() + r.routes = routes + r.fallback = fallback +} + +// Routes returns a defensive copy of the current routes. For introspection. +func (r *Router) Routes() []Route { + r.mu.RLock() + defer r.mu.RUnlock() + out := make([]Route, len(r.routes)) + copy(out, r.routes) + return out +} + +// Fallback returns the current fallback backend. +func (r *Router) Fallback() Backend { + r.mu.RLock() + defer r.mu.RUnlock() + return r.fallback +} + +// matchSNI implements the Match semantics documented on Route. +func matchSNI(pattern, sni string) bool { + pattern = strings.ToLower(pattern) + if pattern == sni { + return true + } + // "*.example.com" matches ".example.com". + if strings.HasPrefix(pattern, "*.") { + suffix := pattern[1:] // ".example.com" + if !strings.HasSuffix(sni, suffix) { + return false + } + labelEnd := len(sni) - len(suffix) + if labelEnd <= 0 { + return false + } + // No additional dots in the wildcard label. + return !strings.Contains(sni[:labelEnd], ".") + } + return false +} diff --git a/core/pkg/sniproxy/router_test.go b/core/pkg/sniproxy/router_test.go new file mode 100644 index 0000000..bece8d2 --- /dev/null +++ b/core/pkg/sniproxy/router_test.go @@ -0,0 +1,113 @@ +package sniproxy + +import ( + "sync" + "testing" +) + +func TestRouter_pick_exact_match(t *testing.T) { + fb := Backend{Name: "fallback", Addr: "127.0.0.1:9000"} + r := NewRouter(fb) + r.Replace([]Route{ + {Match: "turn.example.com", Backend: Backend{Name: "turn", Addr: "127.0.0.1:5349"}}, + }, fb) + + got := r.Pick("turn.example.com") + if got.Addr != "127.0.0.1:5349" { + t.Errorf("expected turn backend, got %+v", got) + } +} + +func TestRouter_pick_unmatched_returns_fallback(t *testing.T) { + fb := Backend{Name: "caddy", Addr: "127.0.0.1:8443"} + r := NewRouter(fb) + r.Replace([]Route{ + {Match: "turn.example.com", Backend: Backend{Addr: "127.0.0.1:5349"}}, + }, fb) + + if got := r.Pick("api.example.com"); got != fb { + t.Errorf("expected fallback, got %+v", got) + } + if got := r.Pick(""); got != fb { + t.Errorf("expected fallback for empty SNI, got %+v", got) + } +} + +func TestRouter_pick_case_insensitive(t *testing.T) { + fb := Backend{Addr: "127.0.0.1:8443"} + r := NewRouter(fb) + r.Replace([]Route{ + {Match: "Turn.Example.Com", Backend: Backend{Addr: "127.0.0.1:5349"}}, + }, fb) + + if got := r.Pick("turn.example.com"); got.Addr != "127.0.0.1:5349" { + t.Errorf("expected case-insensitive match, got %+v", got) + } +} + +func TestRouter_pick_wildcard_subdomain(t *testing.T) { + fb := Backend{Addr: "127.0.0.1:8443"} + r := NewRouter(fb) + r.Replace([]Route{ + {Match: "*.example.com", Backend: Backend{Name: "wild", Addr: "127.0.0.1:5349"}}, + }, fb) + + cases := map[string]bool{ + "a.example.com": true, + "foo.example.com": true, + "a.b.example.com": false, // multi-label not allowed + "example.com": false, // bare domain doesn't match *.example.com + "other.com": false, + } + for sni, want := range cases { + got := r.Pick(sni) == Backend{Name: "wild", Addr: "127.0.0.1:5349"} + if got != want { + t.Errorf("Pick(%q): want match=%v, got match=%v", sni, want, got) + } + } +} + +func TestRouter_replace_atomic(t *testing.T) { + // Many concurrent reads against many concurrent Replace calls — should + // never observe partial state. Run with -race. + fb := Backend{Addr: "fb"} + r := NewRouter(fb) + r.Replace([]Route{{Match: "a.com", Backend: Backend{Addr: "1"}}}, fb) + + var wg sync.WaitGroup + stop := make(chan struct{}) + + // Readers + for i := 0; i < 4; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-stop: + return + default: + _ = r.Pick("a.com") + } + } + }() + } + + // Writers + for i := 0; i < 200; i++ { + r.Replace([]Route{{Match: "a.com", Backend: Backend{Addr: "x"}}}, fb) + } + close(stop) + wg.Wait() +} + +func TestRouter_routes_returns_copy(t *testing.T) { + r := NewRouter(Backend{}) + original := []Route{{Match: "a", Backend: Backend{Addr: "1"}}} + r.Replace(original, Backend{}) + got := r.Routes() + got[0].Match = "mutated" + if r.Routes()[0].Match != "a" { + t.Error("Routes() should return a defensive copy") + } +} diff --git a/core/pkg/sniproxy/server.go b/core/pkg/sniproxy/server.go new file mode 100644 index 0000000..4b9607e --- /dev/null +++ b/core/pkg/sniproxy/server.go @@ -0,0 +1,177 @@ +package sniproxy + +import ( + "context" + "errors" + "io" + "net" + "sync" + "time" + + "go.uber.org/zap" +) + +// Config tunes the proxy server. +type Config struct { + // ClientHelloTimeout bounds the wait for a parseable ClientHello. + // 0 selects 5 seconds. + ClientHelloTimeout time.Duration + // BackendDialTimeout bounds backend connect time. 0 selects 5 seconds. + BackendDialTimeout time.Duration + // MaxConcurrentConns caps total in-flight connections to prevent + // resource exhaustion. 0 selects 10000. + MaxConcurrentConns int +} + +// Server is a TCP-level SNI router. Create via NewServer, then call +// Serve(listener) in a goroutine. Close cancels in-flight connections. +type Server struct { + router *Router + cfg Config + logger *zap.Logger + + gate chan struct{} // bounded semaphore for concurrent connections + wg sync.WaitGroup + ctx context.Context + cancel context.CancelFunc +} + +// NewServer constructs a Server with the given router and config. +func NewServer(router *Router, cfg Config, logger *zap.Logger) *Server { + if logger == nil { + logger = zap.NewNop() + } + if cfg.ClientHelloTimeout <= 0 { + cfg.ClientHelloTimeout = 5 * time.Second + } + if cfg.BackendDialTimeout <= 0 { + cfg.BackendDialTimeout = 5 * time.Second + } + if cfg.MaxConcurrentConns <= 0 { + cfg.MaxConcurrentConns = 10000 + } + ctx, cancel := context.WithCancel(context.Background()) + return &Server{ + router: router, + cfg: cfg, + logger: logger.Named("sniproxy"), + gate: make(chan struct{}, cfg.MaxConcurrentConns), + ctx: ctx, + cancel: cancel, + } +} + +// Serve accepts connections from ln until ln.Accept returns a permanent +// error or Close is called. Serve always returns a non-nil error. +func (s *Server) Serve(ln net.Listener) error { + for { + conn, err := ln.Accept() + if err != nil { + // Check for shutdown via cancelled ctx. + if s.ctx.Err() != nil { + return s.ctx.Err() + } + // Net errors temporarily? Backoff briefly so we don't busy-loop. + var ne net.Error + if errors.As(err, &ne) && ne.Timeout() { + time.Sleep(50 * time.Millisecond) + continue + } + return err + } + select { + case s.gate <- struct{}{}: + default: + s.logger.Warn("max concurrent connections reached, dropping", + zap.Int("limit", s.cfg.MaxConcurrentConns), + zap.String("remote", conn.RemoteAddr().String()), + ) + conn.Close() + continue + } + s.wg.Add(1) + go func(c net.Conn) { + defer s.wg.Done() + defer func() { <-s.gate }() + s.handle(c) + }(conn) + } +} + +// Close cancels in-flight connections and waits for handlers to drain. +func (s *Server) Close() { + s.cancel() + s.wg.Wait() +} + +// handle processes a single accepted connection: peek SNI, dial backend, +// replay peeked bytes, then bidirectional copy. +func (s *Server) handle(conn net.Conn) { + defer conn.Close() + + sni, peeked, err := PeekClientHello(conn, s.cfg.ClientHelloTimeout) + if err != nil { + s.logger.Debug("ClientHello peek failed", + zap.String("remote", conn.RemoteAddr().String()), + zap.Error(err), + ) + return + } + + backend := s.router.Pick(sni) + if backend.Addr == "" { + s.logger.Warn("no backend for SNI", + zap.String("sni", sni), + zap.String("remote", conn.RemoteAddr().String()), + ) + return + } + + network := backend.Network + if network == "" { + network = "tcp" + } + + upstream, err := net.DialTimeout(network, backend.Addr, s.cfg.BackendDialTimeout) + if err != nil { + s.logger.Warn("backend dial failed", + zap.String("sni", sni), + zap.String("backend", backend.Addr), + zap.Error(err), + ) + return + } + defer upstream.Close() + + // Replay peeked bytes (the ClientHello + anything else buffered). + if len(peeked) > 0 { + if _, err := upstream.Write(peeked); err != nil { + s.logger.Debug("replay to backend failed", + zap.String("sni", sni), + zap.Error(err), + ) + return + } + } + + // Bidirectional copy. We close both connections when either side + // finishes OR when the server is shutting down, so handle() can't + // hang forever on a half-stuck peer. + done := make(chan struct{}, 2) + go func() { + _, _ = io.Copy(upstream, conn) + done <- struct{}{} + }() + go func() { + _, _ = io.Copy(conn, upstream) + done <- struct{}{} + }() + select { + case <-done: + case <-s.ctx.Done(): + } + // Force both sides closed; second copy will exit immediately. + upstream.Close() + conn.Close() + <-done // drain the second goroutine +} diff --git a/core/pkg/sniproxy/server_test.go b/core/pkg/sniproxy/server_test.go new file mode 100644 index 0000000..bca9d60 --- /dev/null +++ b/core/pkg/sniproxy/server_test.go @@ -0,0 +1,143 @@ +package sniproxy + +import ( + "bufio" + "crypto/tls" + "errors" + "io" + "net" + "testing" + "time" + + "go.uber.org/zap" +) + +// startEchoBackend creates a TCP server that echoes the first 1024 bytes +// it reads, then closes. Returns the listener and a chan that receives +// the bytes the server saw. +func startEchoBackend(t *testing.T) (net.Listener, <-chan []byte) { + t.Helper() + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + got := make(chan []byte, 4) + go func() { + for { + conn, err := ln.Accept() + if err != nil { + return + } + go func(c net.Conn) { + defer c.Close() + _ = c.SetReadDeadline(time.Now().Add(2 * time.Second)) + buf := make([]byte, 1024) + n, _ := c.Read(buf) + got <- append([]byte(nil), buf[:n]...) + }(conn) + } + }() + return ln, got +} + +func TestServer_routes_TLS_to_correct_backend(t *testing.T) { + turnLn, turnGot := startEchoBackend(t) + defer turnLn.Close() + caddyLn, caddyGot := startEchoBackend(t) + defer caddyLn.Close() + + router := NewRouter(Backend{Network: "tcp", Addr: caddyLn.Addr().String()}) + router.Replace([]Route{ + {Match: "turn.example.com", Backend: Backend{Network: "tcp", Addr: turnLn.Addr().String()}}, + }, router.Fallback()) + + srv := NewServer(router, Config{}, zap.NewNop()) + defer srv.Close() + + frontLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer frontLn.Close() + + go func() { _ = srv.Serve(frontLn) }() + + // Client A: SNI=turn.example.com -> goes to turnLn + dialAndStartTLS(t, frontLn.Addr().String(), "turn.example.com") + + // Client B: SNI=other.example.com -> falls through to caddyLn + dialAndStartTLS(t, frontLn.Addr().String(), "other.example.com") + + select { + case b := <-turnGot: + if len(b) == 0 { + t.Error("turn backend received empty bytes") + } + case <-time.After(3 * time.Second): + t.Error("turn backend did not receive bytes") + } + + select { + case b := <-caddyGot: + if len(b) == 0 { + t.Error("caddy fallback received empty bytes") + } + case <-time.After(3 * time.Second): + t.Error("caddy fallback did not receive bytes") + } +} + +// dialAndStartTLS opens a TLS handshake (which produces a ClientHello) +// against the given address with the given SNI. Returns immediately — +// the test only needs the proxy to forward the bytes; it doesn't +// require handshake completion. +func dialAndStartTLS(t *testing.T, addr, sni string) { + t.Helper() + conn, err := net.Dial("tcp", addr) + if err != nil { + t.Fatal(err) + } + go func() { + defer conn.Close() + _ = conn.SetDeadline(time.Now().Add(2 * time.Second)) + c := tls.Client(conn, &tls.Config{ServerName: sni, InsecureSkipVerify: true}) + _ = c.Handshake() // expected to fail (echo backend isn't TLS) + }() +} + +func TestServer_no_backend_drops_connection(t *testing.T) { + router := NewRouter(Backend{}) // empty fallback, empty Addr -> dropped + srv := NewServer(router, Config{}, zap.NewNop()) + defer srv.Close() + + frontLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer frontLn.Close() + + go func() { _ = srv.Serve(frontLn) }() + + conn, err := net.Dial("tcp", frontLn.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + c := tls.Client(conn, &tls.Config{ServerName: "x.example.com", InsecureSkipVerify: true}) + // Handshake should fail because connection is closed by proxy. + go func() { _ = c.Handshake() }() + + // Reader should see EOF quickly. + _ = conn.SetDeadline(time.Now().Add(2 * time.Second)) + br := bufio.NewReader(conn) + _, err = br.ReadByte() + if err == nil { + t.Error("expected connection drop") + } + if !errors.Is(err, io.EOF) { + // "use of closed network connection" is also fine. + t.Logf("acceptable read error: %v", err) + } +} + diff --git a/core/pkg/sniproxy/sni.go b/core/pkg/sniproxy/sni.go new file mode 100644 index 0000000..cf349c1 --- /dev/null +++ b/core/pkg/sniproxy/sni.go @@ -0,0 +1,235 @@ +// Package sniproxy provides a TCP-level Server Name Indication (SNI) router. +// +// The router peeks at the unencrypted TLS ClientHello on each accepted +// connection, extracts the SNI host name, and forwards the raw stream to +// a backend. It does NOT terminate TLS — encrypted bytes pass through +// verbatim. This lets one TCP port serve multiple TLS-speaking backends +// (HTTPS for the gateway, TURNS for stealth WebRTC, etc.) without +// sharing private keys with the proxy. +// +// Design goals: +// - Zero TLS material on the proxy +// - Bounded ClientHello read (no slowloris) +// - Backend dial timeout +// - Per-IP rate limiting +// +// SNI parsing follows RFC 5246 §7.4.1.2 (TLS record + ClientHello) and +// RFC 6066 §3 (server_name extension). +package sniproxy + +import ( + "bufio" + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "strings" + "time" +) + +// ErrNoSNI is returned when the ClientHello has no server_name extension. +var ErrNoSNI = errors.New("sniproxy: ClientHello has no SNI") + +// MaxClientHelloBytes bounds how many bytes we'll read while looking for +// the SNI. TLS ClientHello records are typically 200–500 bytes; this is +// a generous cap that still defends against memory abuse. +const MaxClientHelloBytes = 16 * 1024 + +// PeekClientHello reads bytes from conn until a TLS ClientHello has +// been parsed (or MaxClientHelloBytes is exceeded). Returns the SNI +// hostname (lowercased), the bytes consumed (must be replayed to the +// backend), and any error. +// +// readTimeout bounds the wait — slowloris-style stalls return an error +// quickly without holding the goroutine indefinitely. +func PeekClientHello(conn net.Conn, readTimeout time.Duration) (string, []byte, error) { + if readTimeout > 0 { + _ = conn.SetReadDeadline(time.Now().Add(readTimeout)) + defer conn.SetReadDeadline(time.Time{}) + } + + br := bufio.NewReaderSize(conn, MaxClientHelloBytes) + + // Peek the TLS record header (5 bytes): content_type, version (2), + // length (2). content_type for ClientHello is 22 (handshake). + header, err := br.Peek(5) + if err != nil { + return "", nil, fmt.Errorf("read tls record header: %w", err) + } + if header[0] != 22 { + return "", nil, fmt.Errorf("not a TLS handshake record (type=%d)", header[0]) + } + recLen := int(binary.BigEndian.Uint16(header[3:5])) + if recLen <= 0 || 5+recLen > MaxClientHelloBytes { + return "", nil, fmt.Errorf("invalid record length %d", recLen) + } + + full, err := br.Peek(5 + recLen) + if err != nil { + return "", nil, fmt.Errorf("read tls record body: %w", err) + } + + sni, err := parseSNI(full[5:]) + if err != nil { + return "", nil, err + } + + // We've only peeked — drain the buffer to capture the bytes for replay. + consumed := make([]byte, br.Buffered()) + if _, err := io.ReadFull(br, consumed); err != nil { + return "", nil, fmt.Errorf("drain peeked bytes: %w", err) + } + + return sni, consumed, nil +} + +// parseSNI parses a TLS ClientHello body (without the 5-byte record +// header) and returns the server_name extension value if present. +func parseSNI(body []byte) (string, error) { + r := newReader(body) + + // Handshake type (1 byte) — must be 1 (ClientHello). + hsType, err := r.readByte() + if err != nil { + return "", err + } + if hsType != 1 { + return "", fmt.Errorf("not a ClientHello (handshake type %d)", hsType) + } + // Handshake length (3 bytes). + if _, err := r.readBytes(3); err != nil { + return "", err + } + // client_version (2) + random (32). + if _, err := r.readBytes(2 + 32); err != nil { + return "", err + } + // session_id. + sidLen, err := r.readByte() + if err != nil { + return "", err + } + if _, err := r.readBytes(int(sidLen)); err != nil { + return "", err + } + // cipher_suites length (2). + csLen, err := r.readUint16() + if err != nil { + return "", err + } + if _, err := r.readBytes(int(csLen)); err != nil { + return "", err + } + // compression_methods length (1). + cmLen, err := r.readByte() + if err != nil { + return "", err + } + if _, err := r.readBytes(int(cmLen)); err != nil { + return "", err + } + // Extensions length (2). Optional — TLS 1.0 ClientHello can skip it. + if r.remaining() < 2 { + return "", ErrNoSNI + } + extTotalLen, err := r.readUint16() + if err != nil { + return "", err + } + if int(extTotalLen) > r.remaining() { + return "", fmt.Errorf("extensions truncated") + } + + end := r.pos + int(extTotalLen) + for r.pos < end { + extType, err := r.readUint16() + if err != nil { + return "", err + } + extLen, err := r.readUint16() + if err != nil { + return "", err + } + extData, err := r.readBytes(int(extLen)) + if err != nil { + return "", err + } + // server_name extension is type 0. + if extType != 0 { + continue + } + return parseServerName(extData) + } + return "", ErrNoSNI +} + +// parseServerName parses the body of a server_name extension and returns +// the first host_name (type 0) entry. +func parseServerName(data []byte) (string, error) { + r := newReader(data) + // server_name_list length (2). + listLen, err := r.readUint16() + if err != nil { + return "", err + } + if int(listLen) > r.remaining() { + return "", fmt.Errorf("server_name list truncated") + } + end := r.pos + int(listLen) + for r.pos < end { + nameType, err := r.readByte() + if err != nil { + return "", err + } + nameLen, err := r.readUint16() + if err != nil { + return "", err + } + nameBytes, err := r.readBytes(int(nameLen)) + if err != nil { + return "", err + } + if nameType == 0 { // host_name + return strings.ToLower(string(nameBytes)), nil + } + } + return "", ErrNoSNI +} + +// reader is a tiny byte-slice cursor used by parseSNI/parseServerName. +type reader struct { + buf []byte + pos int +} + +func newReader(buf []byte) *reader { return &reader{buf: buf} } + +func (r *reader) remaining() int { return len(r.buf) - r.pos } + +func (r *reader) readByte() (byte, error) { + if r.pos >= len(r.buf) { + return 0, io.ErrUnexpectedEOF + } + b := r.buf[r.pos] + r.pos++ + return b, nil +} + +func (r *reader) readUint16() (uint16, error) { + if r.pos+2 > len(r.buf) { + return 0, io.ErrUnexpectedEOF + } + v := binary.BigEndian.Uint16(r.buf[r.pos : r.pos+2]) + r.pos += 2 + return v, nil +} + +func (r *reader) readBytes(n int) ([]byte, error) { + if r.pos+n > len(r.buf) { + return nil, io.ErrUnexpectedEOF + } + b := r.buf[r.pos : r.pos+n] + r.pos += n + return b, nil +} diff --git a/core/pkg/sniproxy/sni_test.go b/core/pkg/sniproxy/sni_test.go new file mode 100644 index 0000000..581dc44 --- /dev/null +++ b/core/pkg/sniproxy/sni_test.go @@ -0,0 +1,173 @@ +package sniproxy + +import ( + "crypto/tls" + "errors" + "io" + "net" + "sync" + "testing" + "time" +) + +// dialAndPeek dials a TLS handshake to the given listener and returns +// what PeekClientHello on the server side parsed. +func dialAndPeek(t *testing.T, ln net.Listener, sni string) (string, []byte, error) { + t.Helper() + + type result struct { + sni string + peeked []byte + err error + } + resCh := make(chan result, 1) + + // Server side: accept once, peek SNI. + go func() { + conn, err := ln.Accept() + if err != nil { + resCh <- result{err: err} + return + } + defer conn.Close() + s, p, err := PeekClientHello(conn, 2*time.Second) + resCh <- result{sni: s, peeked: p, err: err} + }() + + // Client side: kick off a TLS handshake. We don't care if it + // completes (no server cert) — we only need ClientHello to be sent. + // Use a goroutine so the test doesn't deadlock waiting on Handshake. + go func() { + conn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + return + } + defer conn.Close() + _ = conn.SetDeadline(time.Now().Add(2 * time.Second)) + c := tls.Client(conn, &tls.Config{ + ServerName: sni, + InsecureSkipVerify: true, + }) + _ = c.Handshake() // expected to fail; we only needed the ClientHello + }() + + select { + case r := <-resCh: + return r.sni, r.peeked, r.err + case <-time.After(5 * time.Second): + return "", nil, errors.New("test timeout") + } +} + +func TestPeekClientHello_returns_sni(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + sni, peeked, err := dialAndPeek(t, ln, "example.com") + if err != nil { + t.Fatalf("PeekClientHello: %v", err) + } + if sni != "example.com" { + t.Errorf("expected sni=example.com, got %q", sni) + } + if len(peeked) == 0 { + t.Error("expected non-empty peeked bytes") + } +} + +func TestPeekClientHello_lowercases_sni(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + sni, _, err := dialAndPeek(t, ln, "Example.COM") + if err != nil { + t.Fatal(err) + } + if sni != "example.com" { + t.Errorf("expected lowercase, got %q", sni) + } +} + +func TestPeekClientHello_non_tls_returns_error(t *testing.T) { + a, b := net.Pipe() + defer a.Close() + defer b.Close() + + go func() { + // Send something that isn't a TLS handshake record. + _, _ = a.Write([]byte("GET / HTTP/1.1\r\n\r\n")) + _ = a.Close() + }() + + _, _, err := PeekClientHello(b, 1*time.Second) + if err == nil { + t.Fatal("expected error for non-TLS bytes") + } +} + +func TestPeekClientHello_short_record_returns_error(t *testing.T) { + a, b := net.Pipe() + defer a.Close() + defer b.Close() + + go func() { + // One byte, then close — too short for record header. + _, _ = a.Write([]byte{22}) + _ = a.Close() + }() + + _, _, err := PeekClientHello(b, 1*time.Second) + if err == nil { + t.Fatal("expected error for short record") + } + // EOF or read error is acceptable. + if !errors.Is(err, io.EOF) && err.Error() == "" { + t.Logf("error: %v", err) + } +} + +func TestPeekClientHello_concurrent_safe(t *testing.T) { + // Verify no shared state leaks between PeekClientHello calls. + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + var wg sync.WaitGroup + for i := 0; i < 4; i++ { + wg.Add(1) + go func() { + defer wg.Done() + conn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + return + } + defer conn.Close() + _ = conn.SetDeadline(time.Now().Add(2 * time.Second)) + c := tls.Client(conn, &tls.Config{ServerName: "x.example.com", InsecureSkipVerify: true}) + _ = c.Handshake() + }() + } + for i := 0; i < 4; i++ { + conn, err := ln.Accept() + if err != nil { + t.Fatal(err) + } + sni, _, err := PeekClientHello(conn, 2*time.Second) + conn.Close() + if err != nil { + t.Errorf("peek %d: %v", i, err) + } + if sni != "x.example.com" { + t.Errorf("peek %d: got %q", i, sni) + } + } + wg.Wait() +} diff --git a/core/sni-router b/core/sni-router new file mode 100755 index 0000000..44aa374 Binary files /dev/null and b/core/sni-router differ diff --git a/core/systemd/orama-sni-router.service b/core/systemd/orama-sni-router.service new file mode 100644 index 0000000..d41837a --- /dev/null +++ b/core/systemd/orama-sni-router.service @@ -0,0 +1,38 @@ +[Unit] +Description=Orama SNI Router (TLS-level :443 → backend forwarder) +Documentation=https://github.com/DeBrosOfficial/network +After=network.target +Before=caddy.service +PartOf=orama-node.service + +[Service] +Type=simple +WorkingDirectory=/opt/orama +EnvironmentFile=-/opt/orama/.orama/data/sni-router.env +ExecStart=/opt/orama/bin/orama-sni-router --config sni-router.yaml + +# Bind privileged ports (:80, :443) without running as root. +AmbientCapabilities=CAP_NET_BIND_SERVICE +CapabilityBoundingSet=CAP_NET_BIND_SERVICE + +User=orama +Group=orama +NoNewPrivileges=yes +ProtectSystem=strict +ProtectHome=yes +PrivateTmp=yes +LimitNOFILE=65536 + +TimeoutStopSec=15s +KillMode=mixed +KillSignal=SIGTERM + +Restart=on-failure +RestartSec=5s + +StandardOutput=journal +StandardError=journal +SyslogIdentifier=orama-sni-router + +[Install] +WantedBy=multi-user.target