diff --git a/core/Makefile b/core/Makefile
index da8ab1a..91595b1 100644
--- a/core/Makefile
+++ b/core/Makefile
@@ -63,7 +63,7 @@ test-e2e-quick:
.PHONY: build clean test deps tidy fmt vet lint install-hooks push-devnet push-testnet rollout-devnet rollout-testnet release
-VERSION := 0.120.0
+VERSION := 0.121.0
COMMIT ?= $(shell git rev-parse --short HEAD 2>/dev/null || echo unknown)
DATE ?= $(shell date -u +%Y-%m-%dT%H:%M:%SZ)
LDFLAGS := -X 'main.version=$(VERSION)' -X 'main.commit=$(COMMIT)' -X 'main.date=$(DATE)'
@@ -80,6 +80,7 @@ build: deps
go build -ldflags "$(LDFLAGS) -X 'github.com/DeBrosOfficial/network/pkg/gateway.BuildVersion=$(VERSION)' -X 'github.com/DeBrosOfficial/network/pkg/gateway.BuildCommit=$(COMMIT)' -X 'github.com/DeBrosOfficial/network/pkg/gateway.BuildTime=$(DATE)'" -o bin/gateway ./cmd/gateway
go build -ldflags "$(LDFLAGS)" -o bin/sfu ./cmd/sfu
go build -ldflags "$(LDFLAGS)" -o bin/turn ./cmd/turn
+ go build -ldflags "$(LDFLAGS)" -o bin/orama-sni-router ./cmd/sni-router
@echo "Build complete! Run ./bin/orama version"
# Cross-compile CLI for Linux (only binary needed locally; VPS builds everything else from source)
diff --git a/core/cmd/sni-router/main.go b/core/cmd/sni-router/main.go
new file mode 100644
index 0000000..cc727df
--- /dev/null
+++ b/core/cmd/sni-router/main.go
@@ -0,0 +1,242 @@
+// Command sni-router is a TLS-level Server Name Indication router.
+//
+// It listens on a public TCP port (typically :443), peeks at the TLS
+// ClientHello SNI on each connection, and forwards the raw stream to
+// a configured backend. It does NOT terminate TLS — encrypted bytes
+// pass through verbatim. This lets one port serve multiple TLS-speaking
+// backends (HTTPS for the gateway, TURN-over-TLS for stealth WebRTC).
+//
+// See pkg/sniproxy for the underlying library.
+//
+// Configuration: YAML file at --config (defaults to ~/.orama/sni-router.yaml).
+//
+// Example sni-router.yaml:
+//
+// listen: ":443"
+// client_hello_timeout: 5s
+// backend_dial_timeout: 5s
+// max_concurrent_conns: 10000
+// fallback:
+// name: caddy
+// addr: "127.0.0.1:8443"
+// routes:
+// - match: "cdn.example.com"
+// backend:
+// name: turn-tls
+// addr: "127.0.0.1:5349"
+// - match: "turn.example.com"
+// backend:
+// name: turn-tls
+// addr: "127.0.0.1:5349"
+// - match: "*.ns-myapp.example.com"
+// backend:
+// name: gateway
+// addr: "127.0.0.1:8443"
+package main
+
+import (
+ "flag"
+ "fmt"
+ "net"
+ "os"
+ "os/signal"
+ "path/filepath"
+ "strings"
+ "syscall"
+ "time"
+
+ "github.com/DeBrosOfficial/network/pkg/config"
+ "github.com/DeBrosOfficial/network/pkg/logging"
+ "github.com/DeBrosOfficial/network/pkg/sniproxy"
+ "go.uber.org/zap"
+)
+
+var (
+ version = "dev"
+ commit = "unknown"
+)
+
+// yamlBackend mirrors sniproxy.Backend for YAML decoding.
+type yamlBackend struct {
+ Name string `yaml:"name"`
+ Network string `yaml:"network"`
+ Addr string `yaml:"addr"`
+}
+
+// yamlRoute mirrors sniproxy.Route for YAML decoding.
+type yamlRoute struct {
+ Match string `yaml:"match"`
+ Backend yamlBackend `yaml:"backend"`
+}
+
+// yamlConfig is the on-disk configuration shape.
+type yamlConfig struct {
+ Listen string `yaml:"listen"`
+ ClientHelloTimeout time.Duration `yaml:"client_hello_timeout"`
+ BackendDialTimeout time.Duration `yaml:"backend_dial_timeout"`
+ MaxConcurrentConns int `yaml:"max_concurrent_conns"`
+ Fallback yamlBackend `yaml:"fallback"`
+ Routes []yamlRoute `yaml:"routes"`
+}
+
+func main() {
+ logger, err := logging.NewColoredLogger(logging.ComponentSNI, true)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "failed to init logger: %v\n", err)
+ os.Exit(1)
+ }
+
+ logger.ComponentInfo(logging.ComponentSNI, "Starting SNI router",
+ zap.String("version", version),
+ zap.String("commit", commit))
+
+ cfg := parseConfig(logger)
+
+ router := sniproxy.NewRouter(toBackend(cfg.Fallback))
+ router.Replace(toRoutes(cfg.Routes), toBackend(cfg.Fallback))
+
+ srv := sniproxy.NewServer(router, sniproxy.Config{
+ ClientHelloTimeout: cfg.ClientHelloTimeout,
+ BackendDialTimeout: cfg.BackendDialTimeout,
+ MaxConcurrentConns: cfg.MaxConcurrentConns,
+ }, logger.Logger)
+
+ ln, err := net.Listen("tcp", cfg.Listen)
+ if err != nil {
+ logger.ComponentError(logging.ComponentSNI, "Failed to listen",
+ zap.String("addr", cfg.Listen), zap.Error(err))
+ os.Exit(1)
+ }
+
+ logger.ComponentInfo(logging.ComponentSNI, "SNI router listening",
+ zap.String("addr", cfg.Listen),
+ zap.Int("routes", len(cfg.Routes)),
+ zap.String("fallback", cfg.Fallback.Addr),
+ )
+
+ // Run Serve in a goroutine so the main goroutine can wait on signals.
+ serveErrCh := make(chan error, 1)
+ go func() {
+ serveErrCh <- srv.Serve(ln)
+ }()
+
+ // Wait for termination signal or unrecoverable Serve error.
+ quit := make(chan os.Signal, 1)
+ signal.Notify(quit, os.Interrupt, syscall.SIGTERM)
+
+ select {
+ case sig := <-quit:
+ logger.ComponentInfo(logging.ComponentSNI, "Shutdown signal received",
+ zap.String("signal", sig.String()))
+ case err := <-serveErrCh:
+ logger.ComponentError(logging.ComponentSNI, "Serve returned",
+ zap.Error(err))
+ }
+
+ // Stop accepting new connections, then drain in-flight ones.
+ _ = ln.Close()
+ srv.Close()
+
+ logger.ComponentInfo(logging.ComponentSNI, "SNI router shutdown complete")
+}
+
+func parseConfig(logger *logging.ColoredLogger) yamlConfig {
+ configFlag := flag.String("config", "", "Config file path (absolute or filename in ~/.orama)")
+ flag.Parse()
+
+ var configPath string
+ var err error
+ if *configFlag != "" {
+ if filepath.IsAbs(*configFlag) {
+ configPath = *configFlag
+ } else {
+ configPath, err = config.DefaultPath(*configFlag)
+ if err != nil {
+ logger.ComponentError(logging.ComponentSNI, "Failed to determine config path",
+ zap.Error(err))
+ os.Exit(1)
+ }
+ }
+ } else {
+ configPath, err = config.DefaultPath("sni-router.yaml")
+ if err != nil {
+ logger.ComponentError(logging.ComponentSNI, "Failed to determine config path",
+ zap.Error(err))
+ os.Exit(1)
+ }
+ }
+
+ data, err := os.ReadFile(configPath)
+ if err != nil {
+ logger.ComponentError(logging.ComponentSNI, "Config file not found",
+ zap.String("path", configPath), zap.Error(err))
+ fmt.Fprintf(os.Stderr, "\nConfig file not found at %s\n", configPath)
+ os.Exit(1)
+ }
+
+ var y yamlConfig
+ if err := config.DecodeStrict(strings.NewReader(string(data)), &y); err != nil {
+ logger.ComponentError(logging.ComponentSNI, "Failed to parse SNI router config",
+ zap.Error(err))
+ fmt.Fprintf(os.Stderr, "Configuration parse error: %v\n", err)
+ os.Exit(1)
+ }
+
+ if errs := validateConfig(&y); len(errs) > 0 {
+ fmt.Fprintf(os.Stderr, "\nSNI router configuration errors (%d):\n", len(errs))
+ for _, e := range errs {
+ fmt.Fprintf(os.Stderr, " - %s\n", e)
+ }
+ fmt.Fprintf(os.Stderr, "\nPlease fix the configuration and try again.\n")
+ os.Exit(1)
+ }
+
+ logger.ComponentInfo(logging.ComponentSNI, "Loaded SNI router configuration",
+ zap.String("path", configPath),
+ )
+
+ return y
+}
+
+// validateConfig returns a non-empty slice of human-readable errors on misconfig.
+func validateConfig(y *yamlConfig) []string {
+ var errs []string
+ if y.Listen == "" {
+ errs = append(errs, "listen: required (e.g. \":443\")")
+ }
+ if y.Fallback.Addr == "" {
+ errs = append(errs, "fallback.addr: required (where to send unmatched SNIs, typically Caddy)")
+ }
+ for i, r := range y.Routes {
+ if r.Match == "" {
+ errs = append(errs, fmt.Sprintf("routes[%d].match: required", i))
+ }
+ if r.Backend.Addr == "" {
+ errs = append(errs, fmt.Sprintf("routes[%d].backend.addr: required", i))
+ }
+ }
+ return errs
+}
+
+func toBackend(b yamlBackend) sniproxy.Backend {
+ network := b.Network
+ if network == "" {
+ network = "tcp"
+ }
+ return sniproxy.Backend{
+ Name: b.Name,
+ Network: network,
+ Addr: b.Addr,
+ }
+}
+
+func toRoutes(in []yamlRoute) []sniproxy.Route {
+ out := make([]sniproxy.Route, len(in))
+ for i, r := range in {
+ out[i] = sniproxy.Route{
+ Match: r.Match,
+ Backend: toBackend(r.Backend),
+ }
+ }
+ return out
+}
diff --git a/core/docs/STEALTH_TURN.md b/core/docs/STEALTH_TURN.md
new file mode 100644
index 0000000..1005e89
--- /dev/null
+++ b/core/docs/STEALTH_TURN.md
@@ -0,0 +1,187 @@
+# Stealth TURN Deployment Guide
+
+## What this is
+
+A TLS-level SNI router that lets Orama serve TURN-over-TLS on `:443`,
+sharing the port with Caddy HTTPS. From a network observer's
+perspective, TURN traffic is indistinguishable from ordinary HTTPS —
+useful for users in regions that block standard VoIP ports (UAE, Saudi
+Arabia, China, Iran).
+
+## Architecture
+
+```
+ Internet
+ │
+ ▼
+ TCP :443
+ │
+ ┌─────────┴─────────┐
+ │ orama-sni-router │ peeks SNI, forwards bytes
+ └─────────┬─────────┘
+ │
+ ┌───────────────┼────────────────┐
+ ▼ ▼
+ cdn. *.,
+ turn. (everything else)
+ │ │
+ ▼ ▼
+ Pion TURN-TLS Caddy
+ 127.0.0.1:5349 127.0.0.1:8443
+ (existing) (moved from :443)
+```
+
+The router does **not** terminate TLS. It reads the unencrypted TLS
+ClientHello (first ~5 KB), inspects the SNI extension, and dials the
+matching backend. Encrypted bytes pass through verbatim.
+
+## Components
+
+- **Library:** `pkg/sniproxy/` — ClientHello parser, route table, TCP server
+- **Binary:** `cmd/sni-router/` (built as `bin/orama-sni-router`)
+- **Systemd unit:** `systemd/orama-sni-router.service`
+- **Config:** `~/.orama/sni-router.yaml`
+
+## Deployment cutover
+
+⚠️ **This change touches production `:443`. Stage on one node first, watch for 24h, then roll out.**
+
+### 1. Reconfigure Caddy to listen on `:8443`
+
+Update wherever the Caddy config is generated (`pkg/environments/production/installers/caddy.go`)
+so Caddy binds `:8443` (HTTPS) and `:8080` (HTTP) instead of `:443` and `:80`.
+
+Drop `CAP_NET_BIND_SERVICE` from Caddy's systemd unit — it no longer needs privileged ports.
+
+### 2. Provision the cert SAN for `cdn.`
+
+Caddy's automatic Let's Encrypt flow needs to issue a cert covering
+`cdn.` and `cdn.ns-*.` so Pion TURN can read it
+on startup. Add these names to Caddy's TLS config block.
+
+### 3. Drop `sni-router.yaml` config
+
+Example for a single-namespace node:
+
+```yaml
+listen: ":443"
+client_hello_timeout: 5s
+backend_dial_timeout: 5s
+max_concurrent_conns: 10000
+fallback:
+ name: caddy
+ addr: "127.0.0.1:8443"
+routes:
+ - match: "cdn.example.com"
+ backend:
+ name: turn-tls
+ addr: "127.0.0.1:5349"
+ - match: "turn.example.com"
+ backend:
+ name: turn-tls
+ addr: "127.0.0.1:5349"
+```
+
+For multi-namespace, add per-namespace TURN backends (each namespace's
+TURN-TLS port is allocated by `pkg/namespace`):
+
+```yaml
+ - match: "cdn.ns-myapp.example.com"
+ backend: { name: "turn-myapp", addr: "127.0.0.1:5349" }
+ - match: "cdn.ns-other.example.com"
+ backend: { name: "turn-other", addr: "127.0.0.1:5350" }
+```
+
+### 4. Deploy + start in order
+
+```bash
+# Install binary
+sudo cp bin-linux/orama-sni-router /opt/orama/bin/
+
+# Install service
+sudo cp systemd/orama-sni-router.service /etc/systemd/system/
+sudo systemctl daemon-reload
+
+# Stop Caddy briefly (it's about to lose :443)
+sudo systemctl stop caddy
+
+# Start the SNI router (it takes :443)
+sudo systemctl enable --now orama-sni-router
+
+# Restart Caddy on its new port
+sudo systemctl start caddy
+
+# Verify
+curl -v https://cdn.:443 # should hit TURN backend (TLS handshake will fail; that's fine)
+curl -v https://:443 # should hit Caddy (normal HTTPS response)
+```
+
+### 5. Enable stealth in the gateway
+
+Once the SNI router is live, tell the gateway to advertise the stealth URI:
+
+```go
+// in gateway dependencies / startup
+webrtcHandlers.SetStealthCDNDomain("cdn.")
+```
+
+The credentials handler will start including `turns:cdn.:443`
+in `POST /v1/webrtc/turn/credentials` responses automatically.
+
+### 6. Monitor
+
+```bash
+journalctl -u orama-sni-router.service -f
+journalctl -u caddy.service -f
+```
+
+Watch for:
+- `Connection limit reached` warnings (bump `max_concurrent_conns`)
+- `backend dial failed` warnings (Caddy isn't listening on `:8443`, or TURN isn't on `:5349`)
+- `ClientHello peek failed` debugs (curious clients sending non-TLS to `:443` — usually port scanners)
+
+## Rollback
+
+If anything is wrong:
+
+```bash
+sudo systemctl stop orama-sni-router
+# Reconfigure Caddy back to :443 and restart
+sudo systemctl restart caddy
+```
+
+Caddy reclaiming `:443` from the disabled router is the fastest way back to
+the previous topology.
+
+## Known gaps
+
+- **Dynamic route source:** today's router reads YAML once at startup. To
+ pick up new namespaces without restart, implement a `RouteSource` that
+ polls `pkg/namespace` for active TURN deployments. The library is
+ already designed for `Router.Replace` to be called concurrently.
+- **TLS cert hot-reload:** Pion TURN reads the cert once at startup. When
+ Caddy renews `cdn.`, Pion needs to be restarted to pick up
+ the new cert. A small file-watcher service (or a periodic restart in
+ off-peak hours) handles this for now.
+
+## What clients see
+
+Once enabled, the credentials response gains one entry:
+
+```json
+{
+ "username": "...",
+ "password": "...",
+ "ttl": 600,
+ "uris": [
+ "turn:turn.example.com:3478?transport=udp",
+ "turn:turn.example.com:3478?transport=tcp",
+ "turns:turn.example.com:5349",
+ "turns:cdn.example.com:443"
+ ]
+}
+```
+
+Browsers iterate ICE candidates; users in restricted regions will silently
+succeed via the `:443` URI when others fail. No client-side change is
+required.
diff --git a/core/migrations/021_pubsub_trigger_patterns.sql b/core/migrations/021_pubsub_trigger_patterns.sql
new file mode 100644
index 0000000..5b91e35
--- /dev/null
+++ b/core/migrations/021_pubsub_trigger_patterns.sql
@@ -0,0 +1,28 @@
+-- =============================================================================
+-- 021_pubsub_trigger_patterns.sql
+--
+-- Add `topic_pattern` column alongside the existing `topic` column to
+-- function_pubsub_triggers. The new column may contain SQLite GLOB
+-- patterns (e.g. "presence:*") in addition to exact topic names.
+--
+-- This is intentionally ADDITIVE rather than a column rename to remain
+-- safe under rolling upgrades:
+-- - Old binaries continue reading `topic` and keep working.
+-- - New binaries read `topic_pattern` (which is back-filled from
+-- `topic` for existing rows) and write BOTH columns.
+-- A future migration can DROP COLUMN topic once every node is on the
+-- new release.
+-- =============================================================================
+
+ALTER TABLE function_pubsub_triggers
+ ADD COLUMN topic_pattern TEXT NOT NULL DEFAULT '';
+
+UPDATE function_pubsub_triggers
+SET topic_pattern = topic
+WHERE topic_pattern = '';
+
+CREATE INDEX IF NOT EXISTS idx_function_pubsub_triggers_function
+ ON function_pubsub_triggers(function_id);
+
+CREATE INDEX IF NOT EXISTS idx_function_pubsub_triggers_enabled
+ ON function_pubsub_triggers(enabled);
diff --git a/core/migrations/022_aggregation_windows.sql b/core/migrations/022_aggregation_windows.sql
new file mode 100644
index 0000000..05f68f0
--- /dev/null
+++ b/core/migrations/022_aggregation_windows.sql
@@ -0,0 +1,20 @@
+-- =============================================================================
+-- 022_aggregation_windows.sql
+--
+-- Add per-trigger aggregation parameters to function_pubsub_triggers.
+--
+-- aggregation_window_ms = 0 means "no aggregation, invoke once per event"
+-- (the existing behaviour). Any positive value enables buffering of events
+-- in-memory on the dispatching node; the function is invoked once per
+-- window with a batched payload.
+--
+-- aggregation_max_batch_size caps the per-window batch. When the buffer
+-- reaches this size, the dispatcher flushes immediately even if the
+-- window timer hasn't fired yet.
+-- =============================================================================
+
+ALTER TABLE function_pubsub_triggers
+ ADD COLUMN aggregation_window_ms INTEGER NOT NULL DEFAULT 0;
+
+ALTER TABLE function_pubsub_triggers
+ ADD COLUMN aggregation_max_batch_size INTEGER NOT NULL DEFAULT 100;
diff --git a/core/migrations/023_push_devices.sql b/core/migrations/023_push_devices.sql
new file mode 100644
index 0000000..bc6c908
--- /dev/null
+++ b/core/migrations/023_push_devices.sql
@@ -0,0 +1,33 @@
+-- =============================================================================
+-- 023_push_devices.sql
+--
+-- Per-namespace, per-user push notification device registry.
+--
+-- token_encrypted is AES-256-GCM ciphertext (prefix 'enc:') derived via
+-- pkg/secrets. Tokens are sensitive — they let the holder spam a user's
+-- device — so they are never returned via any API or written to logs.
+--
+-- provider matches a registered push.PushProvider name:
+-- 'ntfy', 'expo', 'apns', 'fcm' (future), ...
+-- =============================================================================
+
+CREATE TABLE IF NOT EXISTS push_devices (
+ id TEXT PRIMARY KEY,
+ namespace TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ device_id TEXT NOT NULL,
+ provider TEXT NOT NULL,
+ token_encrypted TEXT NOT NULL,
+ platform TEXT,
+ app_version TEXT,
+ created_at INTEGER NOT NULL,
+ updated_at INTEGER NOT NULL,
+ last_seen INTEGER,
+ UNIQUE(namespace, user_id, device_id)
+);
+
+CREATE INDEX IF NOT EXISTS idx_push_devices_user
+ ON push_devices(namespace, user_id);
+
+CREATE INDEX IF NOT EXISTS idx_push_devices_provider
+ ON push_devices(provider);
diff --git a/core/pkg/cli/build/builder.go b/core/pkg/cli/build/builder.go
index 2e61dcf..51a32e1 100644
--- a/core/pkg/cli/build/builder.go
+++ b/core/pkg/cli/build/builder.go
@@ -157,6 +157,7 @@ func (b *Builder) buildOramaBinaries() error {
{Name: "identity", Package: "./cmd/identity/"},
{Name: "sfu", Package: "./cmd/sfu/"},
{Name: "turn", Package: "./cmd/turn/"},
+ {Name: "orama-sni-router", Package: "./cmd/sni-router/"},
}
for _, bin := range binaries {
diff --git a/core/pkg/cli/production/clean/clean.go b/core/pkg/cli/production/clean/clean.go
index f473683..fe4b61f 100644
--- a/core/pkg/cli/production/clean/clean.go
+++ b/core/pkg/cli/production/clean/clean.go
@@ -171,7 +171,7 @@ rm -f /tmp/orama-*.sh /tmp/network-source.tar.gz /tmp/orama-*.tar.gz
# Nuclear: remove binaries
if [ -n "$NUCLEAR" ]; then
rm -f /usr/local/bin/orama /usr/local/bin/orama-node /usr/local/bin/gateway
- rm -f /usr/local/bin/identity /usr/local/bin/sfu /usr/local/bin/turn
+ rm -f /usr/local/bin/identity /usr/local/bin/sfu /usr/local/bin/turn /usr/local/bin/orama-sni-router
rm -f /usr/local/bin/olric-server /usr/local/bin/ipfs /usr/local/bin/ipfs-cluster-service
rm -f /usr/local/bin/rqlited /usr/local/bin/coredns
rm -f /usr/bin/caddy
diff --git a/core/pkg/client/interface.go b/core/pkg/client/interface.go
index 2c7e40b..1fff4c9 100644
--- a/core/pkg/client/interface.go
+++ b/core/pkg/client/interface.go
@@ -47,10 +47,29 @@ type DatabaseClient interface {
type PubSubClient interface {
Subscribe(ctx context.Context, topic string, handler MessageHandler) error
Publish(ctx context.Context, topic string, data []byte) error
+ // PublishBatch publishes multiple messages in parallel, one per topic.
+ // See pubsub.Manager.PublishBatch for semantics (fail-fast vs. best-effort).
+ PublishBatch(ctx context.Context, msgs []TopicMessage, opts PublishBatchOptions) error
+ // PublishSame sends the same payload to every topic in parallel.
+ PublishSame(ctx context.Context, topics []string, data []byte, opts PublishBatchOptions) error
Unsubscribe(ctx context.Context, topic string) error
ListTopics(ctx context.Context) ([]string, error)
}
+// TopicMessage is one entry in a batch publish.
+// Mirrors pubsub.TopicMessage to avoid forcing client callers to import pkg/pubsub.
+type TopicMessage struct {
+ Topic string
+ Data []byte
+}
+
+// PublishBatchOptions controls batch publish behavior.
+// Mirrors pubsub.PublishBatchOptions.
+type PublishBatchOptions struct {
+ BestEffort bool
+ MaxConcurrency int
+}
+
// NetworkInfo provides network status and peer information
type NetworkInfo interface {
GetPeers(ctx context.Context) ([]PeerInfo, error)
diff --git a/core/pkg/client/pubsub_bridge.go b/core/pkg/client/pubsub_bridge.go
index 653301e..79780b0 100644
--- a/core/pkg/client/pubsub_bridge.go
+++ b/core/pkg/client/pubsub_bridge.go
@@ -4,13 +4,13 @@ import (
"context"
"fmt"
- "github.com/DeBrosOfficial/network/pkg/pubsub"
+ pkgpubsub "github.com/DeBrosOfficial/network/pkg/pubsub"
)
// pubSubBridge bridges between our PubSubClient interface and the pubsub package
type pubSubBridge struct {
client *Client
- adapter *pubsub.ClientAdapter
+ adapter *pkgpubsub.ClientAdapter
}
func (p *pubSubBridge) Subscribe(ctx context.Context, topic string, handler MessageHandler) error {
@@ -31,6 +31,26 @@ func (p *pubSubBridge) Publish(ctx context.Context, topic string, data []byte) e
return p.adapter.Publish(ctx, topic, data)
}
+func (p *pubSubBridge) PublishBatch(ctx context.Context, msgs []TopicMessage, opts PublishBatchOptions) error {
+ if err := p.client.requireAccess(ctx); err != nil {
+ return fmt.Errorf("authentication required: %w - run CLI commands to authenticate automatically", err)
+ }
+ pkgMsgs := make([]pkgpubsub.TopicMessage, len(msgs))
+ for i, m := range msgs {
+ pkgMsgs[i] = pkgpubsub.TopicMessage{Topic: m.Topic, Data: m.Data}
+ }
+ pkgOpts := pkgpubsub.PublishBatchOptions{BestEffort: opts.BestEffort, MaxConcurrency: opts.MaxConcurrency}
+ return p.adapter.PublishBatch(ctx, pkgMsgs, pkgOpts)
+}
+
+func (p *pubSubBridge) PublishSame(ctx context.Context, topics []string, data []byte, opts PublishBatchOptions) error {
+ if err := p.client.requireAccess(ctx); err != nil {
+ return fmt.Errorf("authentication required: %w - run CLI commands to authenticate automatically", err)
+ }
+ pkgOpts := pkgpubsub.PublishBatchOptions{BestEffort: opts.BestEffort, MaxConcurrency: opts.MaxConcurrency}
+ return p.adapter.PublishSame(ctx, topics, data, pkgOpts)
+}
+
func (p *pubSubBridge) Unsubscribe(ctx context.Context, topic string) error {
if err := p.client.requireAccess(ctx); err != nil {
return fmt.Errorf("authentication required: %w - run CLI commands to authenticate automatically", err)
diff --git a/core/pkg/gateway/config.go b/core/pkg/gateway/config.go
index 41cdebb..323ae48 100644
--- a/core/pkg/gateway/config.go
+++ b/core/pkg/gateway/config.go
@@ -56,4 +56,15 @@ type Config struct {
SFUPort int // Local SFU signaling port to proxy WebSocket connections to
TURNDomain string // TURN server domain for credential generation
TURNSecret string // HMAC-SHA1 shared secret for TURN credential generation
+
+ // StealthCDNDomain, when set, makes the WebRTC credentials handler
+ // advertise turns::443 (served by the SNI router).
+ StealthCDNDomain string
+
+ // Push notification configuration. Push is enabled when at least one
+ // provider URL/token is set. Tokens stored in the push_devices table
+ // are encrypted at rest via pkg/secrets using the cluster secret.
+ NtfyBaseURL string // ntfy server URL (e.g. "http://localhost:8080")
+ NtfyAuthToken string // optional bearer token for ntfy
+ ExpoAccessToken string // optional Expo access token
}
diff --git a/core/pkg/gateway/dependencies.go b/core/pkg/gateway/dependencies.go
index eaad2dd..f109cd6 100644
--- a/core/pkg/gateway/dependencies.go
+++ b/core/pkg/gateway/dependencies.go
@@ -19,6 +19,9 @@ import (
"github.com/DeBrosOfficial/network/pkg/logging"
"github.com/DeBrosOfficial/network/pkg/olric"
"github.com/DeBrosOfficial/network/pkg/pubsub"
+ "github.com/DeBrosOfficial/network/pkg/push"
+ pushexpo "github.com/DeBrosOfficial/network/pkg/push/providers/expo"
+ pushntfy "github.com/DeBrosOfficial/network/pkg/push/providers/ntfy"
"github.com/DeBrosOfficial/network/pkg/rqlite"
"github.com/DeBrosOfficial/network/pkg/serverless"
"github.com/DeBrosOfficial/network/pkg/serverless/hostfunctions"
@@ -63,6 +66,11 @@ type Dependencies struct {
// PubSub trigger dispatcher (used to wire into PubSubHandlers)
PubSubDispatcher *triggers.PubSubDispatcher
+ // Push notification dispatcher (nil when push isn't configured —
+ // hostfunc + HTTP handlers degrade to no-op / 503).
+ PushDispatcher *push.PushDispatcher
+ PushDeviceStore push.PushDeviceStore
+
// Authentication service
AuthService *auth.Service
}
@@ -412,6 +420,19 @@ func initializeServerless(logger *logging.ColoredLogger, cfg *Config, deps *Depe
secretsMgr = smImpl
}
+ // Initialize push notification dispatcher if any provider is configured.
+ // Devices are stored encrypted in RQLite (see migration 023). Providers
+ // are registered based on gateway config; missing config = provider absent.
+ pushDispatcher, pushStore, err := buildPushDispatcher(cfg, deps.ORMClient, logger)
+ if err != nil {
+ // Non-fatal: log and continue. Functions calling push_send will get nil
+ // (silent no-op) and HTTP /v1/push/* endpoints return 503.
+ logger.ComponentWarn(logging.ComponentGeneral,
+ "push notifications disabled (init failed)", zap.Error(err))
+ }
+ deps.PushDispatcher = pushDispatcher
+ deps.PushDeviceStore = pushStore
+
// Create host functions provider (allows functions to call Orama services)
hostFuncsCfg := hostfunctions.HostFunctionsConfig{
IPFSAPIURL: cfg.IPFSAPIURL,
@@ -424,6 +445,7 @@ func initializeServerless(logger *logging.ColoredLogger, cfg *Config, deps *Depe
pubsubAdapter, // pubsub adapter for serverless functions
deps.ServerlessWSMgr,
secretsMgr,
+ pushDispatcher, // may be nil — PushSend hostfunc handles that
hostFuncsCfg,
logger.Logger,
)
@@ -686,3 +708,40 @@ func injectRQLiteAuth(dsn, username, password string) string {
}
return dsn
}
+
+// buildPushDispatcher constructs a push.PushDispatcher + device store with
+// all enabled providers. Returns (nil, nil, nil) when no provider is
+// configured — that's a supported state, not an error. Returns (nil, nil,
+// err) on hard init failures (e.g. cluster secret missing for the
+// encrypted device store).
+func buildPushDispatcher(cfg *Config, db rqlite.Client, logger *logging.ColoredLogger) (*push.PushDispatcher, push.PushDeviceStore, error) {
+ if cfg.NtfyBaseURL == "" && cfg.ExpoAccessToken == "" {
+ // No providers configured — push is disabled.
+ return nil, nil, nil
+ }
+ if cfg.ClusterSecret == "" {
+ // Devices are encrypted at rest using a cluster-secret-derived key.
+ // Without it we can't store anything safely.
+ return nil, nil, fmt.Errorf("push enabled but ClusterSecret is empty")
+ }
+ store, err := push.NewRqliteDeviceStore(db, cfg.ClusterSecret, logger.Logger)
+ if err != nil {
+ return nil, nil, fmt.Errorf("init push device store: %w", err)
+ }
+ d := push.New(store, logger.Logger)
+ if cfg.NtfyBaseURL != "" {
+ d.Register(pushntfy.New(pushntfy.Config{
+ BaseURL: cfg.NtfyBaseURL,
+ AuthToken: cfg.NtfyAuthToken,
+ }, logger.Logger))
+ logger.ComponentInfo(logging.ComponentGeneral, "push provider registered: ntfy",
+ zap.String("base_url", cfg.NtfyBaseURL))
+ }
+ if cfg.ExpoAccessToken != "" {
+ d.Register(pushexpo.New(pushexpo.Config{
+ AccessToken: cfg.ExpoAccessToken,
+ }, logger.Logger))
+ logger.ComponentInfo(logging.ComponentGeneral, "push provider registered: expo")
+ }
+ return d, store, nil
+}
diff --git a/core/pkg/gateway/gateway.go b/core/pkg/gateway/gateway.go
index 531f883..61bfa2b 100644
--- a/core/pkg/gateway/gateway.go
+++ b/core/pkg/gateway/gateway.go
@@ -28,6 +28,7 @@ import (
"github.com/DeBrosOfficial/network/pkg/gateway/handlers/cache"
deploymentshandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/deployments"
pubsubhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/pubsub"
+ pushhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/push"
serverlesshandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/serverless"
enrollhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/enroll"
joinhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/join"
@@ -43,6 +44,7 @@ import (
"github.com/DeBrosOfficial/network/pkg/olric"
"github.com/DeBrosOfficial/network/pkg/rqlite"
"github.com/DeBrosOfficial/network/pkg/serverless"
+ "github.com/DeBrosOfficial/network/pkg/serverless/triggers"
_ "github.com/mattn/go-sqlite3"
"go.uber.org/zap"
)
@@ -83,13 +85,15 @@ type Gateway struct {
mu sync.RWMutex
presenceMu sync.RWMutex
pubsubHandlers *pubsubhandlers.PubSubHandlers
+ pushHandlers *pushhandlers.Handlers
// Serverless function engine
- serverlessEngine *serverless.Engine
- serverlessRegistry *serverless.Registry
- serverlessInvoker *serverless.Invoker
- serverlessWSMgr *serverless.WSManager
- serverlessHandlers *serverlesshandlers.ServerlessHandlers
+ serverlessEngine *serverless.Engine
+ serverlessRegistry *serverless.Registry
+ serverlessInvoker *serverless.Invoker
+ serverlessWSMgr *serverless.WSManager
+ serverlessHandlers *serverlesshandlers.ServerlessHandlers
+ pubsubDispatcher *triggers.PubSubDispatcher
// Authentication service
authService *auth.Service
@@ -342,11 +346,20 @@ func New(logger *logging.ColoredLogger, cfg *Config) (*Gateway, error) {
// Wire PubSub trigger dispatch if serverless is available
if deps.PubSubDispatcher != nil {
+ gw.pubsubDispatcher = deps.PubSubDispatcher
gw.pubsubHandlers.SetOnPublish(func(ctx context.Context, namespace, topic string, data []byte) {
deps.PubSubDispatcher.Dispatch(ctx, namespace, topic, data, 0)
})
}
+ // Push notification handlers — disabled when no provider is configured.
+ // The handlers themselves return 503 if dispatcher/store is nil; we
+ // register them unconditionally so the routes always exist with a
+ // predictable shape.
+ if deps.PushDispatcher != nil {
+ gw.pushHandlers = pushhandlers.NewHandlers(deps.PushDispatcher, deps.PushDeviceStore, logger)
+ }
+
if cfg.WebRTCEnabled && cfg.SFUPort > 0 {
gw.webrtcHandlers = webrtchandlers.NewWebRTCHandlers(
logger,
diff --git a/core/pkg/gateway/handlers/pubsub/handlers_test.go b/core/pkg/gateway/handlers/pubsub/handlers_test.go
index 71263b2..7e3b37c 100644
--- a/core/pkg/gateway/handlers/pubsub/handlers_test.go
+++ b/core/pkg/gateway/handlers/pubsub/handlers_test.go
@@ -21,10 +21,12 @@ import (
// mockPubSubClient implements client.PubSubClient for testing
type mockPubSubClient struct {
- PublishFunc func(ctx context.Context, topic string, data []byte) error
- SubscribeFunc func(ctx context.Context, topic string, handler client.MessageHandler) error
- UnsubscribeFunc func(ctx context.Context, topic string) error
- ListTopicsFunc func(ctx context.Context) ([]string, error)
+ PublishFunc func(ctx context.Context, topic string, data []byte) error
+ PublishBatchFunc func(ctx context.Context, msgs []client.TopicMessage, opts client.PublishBatchOptions) error
+ PublishSameFunc func(ctx context.Context, topics []string, data []byte, opts client.PublishBatchOptions) error
+ SubscribeFunc func(ctx context.Context, topic string, handler client.MessageHandler) error
+ UnsubscribeFunc func(ctx context.Context, topic string) error
+ ListTopicsFunc func(ctx context.Context) ([]string, error)
}
func (m *mockPubSubClient) Publish(ctx context.Context, topic string, data []byte) error {
@@ -34,6 +36,20 @@ func (m *mockPubSubClient) Publish(ctx context.Context, topic string, data []byt
return nil
}
+func (m *mockPubSubClient) PublishBatch(ctx context.Context, msgs []client.TopicMessage, opts client.PublishBatchOptions) error {
+ if m.PublishBatchFunc != nil {
+ return m.PublishBatchFunc(ctx, msgs, opts)
+ }
+ return nil
+}
+
+func (m *mockPubSubClient) PublishSame(ctx context.Context, topics []string, data []byte, opts client.PublishBatchOptions) error {
+ if m.PublishSameFunc != nil {
+ return m.PublishSameFunc(ctx, topics, data, opts)
+ }
+ return nil
+}
+
func (m *mockPubSubClient) Subscribe(ctx context.Context, topic string, handler client.MessageHandler) error {
if m.SubscribeFunc != nil {
return m.SubscribeFunc(ctx, topic, handler)
diff --git a/core/pkg/gateway/handlers/pubsub/publish_batch_handler_test.go b/core/pkg/gateway/handlers/pubsub/publish_batch_handler_test.go
new file mode 100644
index 0000000..21f05cc
--- /dev/null
+++ b/core/pkg/gateway/handlers/pubsub/publish_batch_handler_test.go
@@ -0,0 +1,156 @@
+package pubsub
+
+import (
+ "bytes"
+ "context"
+ "encoding/base64"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/DeBrosOfficial/network/pkg/client"
+)
+
+func TestPublishBatchHandler_invalid_method(t *testing.T) {
+ h := newTestHandlers(&mockNetworkClient{pubsub: &mockPubSubClient{}})
+
+ req := withNamespace(httptest.NewRequest(http.MethodGet, "/v1/pubsub/publish-batch", nil), "ns")
+ rr := httptest.NewRecorder()
+ h.PublishBatchHandler(rr, req)
+
+ if rr.Code != http.StatusMethodNotAllowed {
+ t.Errorf("expected 405, got %d", rr.Code)
+ }
+}
+
+func TestPublishBatchHandler_missing_namespace(t *testing.T) {
+ h := newTestHandlers(&mockNetworkClient{pubsub: &mockPubSubClient{}})
+
+ body, _ := json.Marshal(PublishBatchRequest{Messages: []PublishBatchEntry{{Topic: "a", DataB64: "AA=="}}})
+ req := httptest.NewRequest(http.MethodPost, "/v1/pubsub/publish-batch", bytes.NewReader(body))
+ rr := httptest.NewRecorder()
+ h.PublishBatchHandler(rr, req)
+
+ if rr.Code != http.StatusForbidden {
+ t.Errorf("expected 403, got %d (body: %s)", rr.Code, rr.Body.String())
+ }
+}
+
+func TestPublishBatchHandler_empty_messages_rejected(t *testing.T) {
+ h := newTestHandlers(&mockNetworkClient{pubsub: &mockPubSubClient{}})
+
+ body, _ := json.Marshal(PublishBatchRequest{Messages: []PublishBatchEntry{}})
+ req := withNamespace(httptest.NewRequest(http.MethodPost, "/v1/pubsub/publish-batch", bytes.NewReader(body)), "ns")
+ rr := httptest.NewRecorder()
+ h.PublishBatchHandler(rr, req)
+
+ if rr.Code != http.StatusBadRequest {
+ t.Errorf("expected 400 for empty messages, got %d", rr.Code)
+ }
+}
+
+func TestPublishBatchHandler_oversize_batch_rejected(t *testing.T) {
+ h := newTestHandlers(&mockNetworkClient{pubsub: &mockPubSubClient{}})
+
+ entries := make([]PublishBatchEntry, MaxPublishBatchSize+1)
+ for i := range entries {
+ entries[i] = PublishBatchEntry{Topic: "t", DataB64: "AA=="}
+ }
+ body, _ := json.Marshal(PublishBatchRequest{Messages: entries})
+ req := withNamespace(httptest.NewRequest(http.MethodPost, "/v1/pubsub/publish-batch", bytes.NewReader(body)), "ns")
+ rr := httptest.NewRecorder()
+ h.PublishBatchHandler(rr, req)
+
+ if rr.Code != http.StatusBadRequest {
+ t.Errorf("expected 400 for oversize batch, got %d", rr.Code)
+ }
+}
+
+func TestPublishBatchHandler_invalid_base64_rejected(t *testing.T) {
+ h := newTestHandlers(&mockNetworkClient{pubsub: &mockPubSubClient{}})
+
+ body, _ := json.Marshal(PublishBatchRequest{Messages: []PublishBatchEntry{
+ {Topic: "good", DataB64: base64.StdEncoding.EncodeToString([]byte("ok"))},
+ {Topic: "bad", DataB64: "!!!not-base64"},
+ }})
+ req := withNamespace(httptest.NewRequest(http.MethodPost, "/v1/pubsub/publish-batch", bytes.NewReader(body)), "ns")
+ rr := httptest.NewRecorder()
+ h.PublishBatchHandler(rr, req)
+
+ if rr.Code != http.StatusBadRequest {
+ t.Errorf("expected 400 for invalid base64, got %d", rr.Code)
+ }
+}
+
+func TestPublishBatchHandler_missing_topic_rejected(t *testing.T) {
+ h := newTestHandlers(&mockNetworkClient{pubsub: &mockPubSubClient{}})
+
+ body, _ := json.Marshal(PublishBatchRequest{Messages: []PublishBatchEntry{
+ {Topic: "", DataB64: base64.StdEncoding.EncodeToString([]byte("x"))},
+ }})
+ req := withNamespace(httptest.NewRequest(http.MethodPost, "/v1/pubsub/publish-batch", bytes.NewReader(body)), "ns")
+ rr := httptest.NewRecorder()
+ h.PublishBatchHandler(rr, req)
+
+ if rr.Code != http.StatusBadRequest {
+ t.Errorf("expected 400 for missing topic, got %d", rr.Code)
+ }
+}
+
+func TestPublishBatchHandler_happy_calls_PublishBatch(t *testing.T) {
+ var (
+ called int32
+ gotMessages []client.TopicMessage
+ mu sync.Mutex
+ )
+ mock := &mockPubSubClient{
+ PublishBatchFunc: func(ctx context.Context, msgs []client.TopicMessage, opts client.PublishBatchOptions) error {
+ atomic.AddInt32(&called, 1)
+ mu.Lock()
+ gotMessages = append(gotMessages, msgs...)
+ mu.Unlock()
+ return nil
+ },
+ }
+ h := newTestHandlers(&mockNetworkClient{pubsub: mock})
+
+ entries := []PublishBatchEntry{
+ {Topic: "a", DataB64: base64.StdEncoding.EncodeToString([]byte("data-a"))},
+ {Topic: "b", DataB64: base64.StdEncoding.EncodeToString([]byte("data-b"))},
+ }
+ body, _ := json.Marshal(PublishBatchRequest{Messages: entries})
+ req := withNamespace(httptest.NewRequest(http.MethodPost, "/v1/pubsub/publish-batch", bytes.NewReader(body)), "test-ns")
+ rr := httptest.NewRecorder()
+
+ h.PublishBatchHandler(rr, req)
+
+ if rr.Code != http.StatusOK {
+ t.Fatalf("expected 200, got %d (body: %s)", rr.Code, rr.Body.String())
+ }
+
+ // PublishBatch is invoked from a goroutine; give it a moment to run.
+ deadline := time.Now().Add(2 * time.Second)
+ for atomic.LoadInt32(&called) == 0 {
+ if time.Now().After(deadline) {
+ t.Fatal("PublishBatch was not called within 2s")
+ }
+ time.Sleep(10 * time.Millisecond)
+ }
+
+ mu.Lock()
+ defer mu.Unlock()
+ if len(gotMessages) != 2 {
+ t.Fatalf("expected 2 messages forwarded, got %d", len(gotMessages))
+ }
+ if gotMessages[0].Topic != "a" || string(gotMessages[0].Data) != "data-a" {
+ t.Errorf("unexpected first message: %+v", gotMessages[0])
+ }
+ if gotMessages[1].Topic != "b" || string(gotMessages[1].Data) != "data-b" {
+ t.Errorf("unexpected second message: %+v", gotMessages[1])
+ }
+}
+
diff --git a/core/pkg/gateway/handlers/pubsub/publish_handler.go b/core/pkg/gateway/handlers/pubsub/publish_handler.go
index a3cedd5..7c7e9fc 100644
--- a/core/pkg/gateway/handlers/pubsub/publish_handler.go
+++ b/core/pkg/gateway/handlers/pubsub/publish_handler.go
@@ -5,6 +5,7 @@ import (
"encoding/base64"
"encoding/json"
"net/http"
+ "strconv"
"time"
"github.com/DeBrosOfficial/network/pkg/client"
@@ -12,6 +13,10 @@ import (
"go.uber.org/zap"
)
+// MaxPublishBatchSize is the maximum number of messages allowed in a single
+// /v1/pubsub/publish-batch request. Mirrors pubsub.MaxBatchSize.
+const MaxPublishBatchSize = pubsub.MaxBatchSize
+
// PublishHandler handles POST /v1/pubsub/publish {topic, data_base64}
func (p *PubSubHandlers) PublishHandler(w http.ResponseWriter, r *http.Request) {
if p.client == nil {
@@ -39,9 +44,133 @@ func (p *PubSubHandlers) PublishHandler(w http.ResponseWriter, r *http.Request)
return
}
- // Check for local websocket subscribers FIRST and deliver directly
+ p.deliverLocal(ns, body.Topic, data)
+
+ // Publish to libp2p asynchronously for cross-node delivery.
+ go func() {
+ publishCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+ ctx := pubsub.WithNamespace(client.WithInternalAuth(publishCtx), ns)
+ if err := p.client.PubSub().Publish(ctx, body.Topic, data); err != nil {
+ p.logger.ComponentWarn("gateway", "async libp2p publish failed",
+ zap.String("topic", body.Topic),
+ zap.Error(err))
+ }
+ }()
+
+ writeJSON(w, http.StatusOK, map[string]any{"status": "ok"})
+}
+
+// PublishBatchRequest is the request body for POST /v1/pubsub/publish-batch.
+type PublishBatchRequest struct {
+ Messages []PublishBatchEntry `json:"messages"`
+ BestEffort bool `json:"best_effort,omitempty"`
+}
+
+// PublishBatchEntry is one message in a batch publish request.
+type PublishBatchEntry struct {
+ Topic string `json:"topic"`
+ DataB64 string `json:"data_base64"`
+}
+
+// PublishBatchResponse is the response body for /v1/pubsub/publish-batch.
+//
+// libp2p delivery is asynchronous and not awaited here, mirroring the
+// single-publish handler's fire-and-forget contract. Per-topic failures
+// are not surfaced via this response — operators should consult logs /
+// metrics for delivery health.
+type PublishBatchResponse struct {
+ Status string `json:"status"` // always "ok" — request was accepted
+}
+
+// MaxPerMessageBytes caps an individual message payload inside a batch.
+// Mirrors the 1MB cap on /v1/pubsub/publish.
+const MaxPerMessageBytes = 1 << 20
+
+// PublishBatchHandler handles POST /v1/pubsub/publish-batch.
+// Accepts up to MaxPublishBatchSize messages and publishes them in parallel,
+// preserving namespace isolation. Local subscribers receive messages
+// immediately; libp2p delivery is async.
+func (p *PubSubHandlers) PublishBatchHandler(w http.ResponseWriter, r *http.Request) {
+ if p.client == nil {
+ writeError(w, http.StatusServiceUnavailable, "client not initialized")
+ return
+ }
+ if r.Method != http.MethodPost {
+ writeError(w, http.StatusMethodNotAllowed, "method not allowed")
+ return
+ }
+ ns := resolveNamespaceFromRequest(r)
+ if ns == "" {
+ writeError(w, http.StatusForbidden, "namespace not resolved")
+ return
+ }
+
+ // Limit body size: MaxPublishBatchSize messages * ~1MB each = up to ~100MB.
+ // Cap conservatively at 16MB to discourage huge payloads.
+ r.Body = http.MaxBytesReader(w, r.Body, 16<<20)
+
+ var body PublishBatchRequest
+ if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
+ writeError(w, http.StatusBadRequest, "invalid body: expected {messages:[{topic,data_base64}]}")
+ return
+ }
+ if len(body.Messages) == 0 {
+ writeError(w, http.StatusBadRequest, "messages required")
+ return
+ }
+ if len(body.Messages) > MaxPublishBatchSize {
+ writeError(w, http.StatusBadRequest, "too many messages: max is 100 per batch")
+ return
+ }
+
+ // Decode all messages up-front so we can fail fast on bad input.
+ decoded := make([]pubsub.TopicMessage, 0, len(body.Messages))
+ for i, m := range body.Messages {
+ if m.Topic == "" {
+ writeError(w, http.StatusBadRequest, "message missing topic at index "+strconv.Itoa(i))
+ return
+ }
+ data, err := base64.StdEncoding.DecodeString(m.DataB64)
+ if err != nil {
+ writeError(w, http.StatusBadRequest, "invalid base64 data at index "+strconv.Itoa(i))
+ return
+ }
+ if len(data) > MaxPerMessageBytes {
+ writeError(w, http.StatusBadRequest, "message too large at index "+strconv.Itoa(i))
+ return
+ }
+ decoded = append(decoded, pubsub.TopicMessage{Topic: m.Topic, Data: data})
+ }
+
+ // Deliver locally + dispatch triggers per topic synchronously (fast in-process).
+ for _, msg := range decoded {
+ p.deliverLocal(ns, msg.Topic, msg.Data)
+ }
+
+ // Async libp2p batch publish, similar to PublishHandler's approach.
+ go func() {
+ publishCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+ ctx := pubsub.WithNamespace(client.WithInternalAuth(publishCtx), ns)
+ opts := pubsub.PublishBatchOptions{BestEffort: body.BestEffort}
+ err := p.client.PubSub().PublishBatch(ctx, toClientMessages(decoded), clientOpts(opts))
+ if err != nil {
+ p.logger.ComponentWarn("gateway", "async libp2p batch publish failed",
+ zap.Int("messages", len(decoded)),
+ zap.Error(err))
+ }
+ }()
+
+ writeJSON(w, http.StatusOK, PublishBatchResponse{Status: "ok"})
+}
+
+// deliverLocal handles local-subscriber delivery and fires PubSub triggers.
+// It does NOT publish to libp2p — callers handle that themselves (single
+// or batched) so this helper stays focused on in-process fan-out.
+func (p *PubSubHandlers) deliverLocal(ns, topic string, data []byte) {
p.mu.RLock()
- localSubs := p.getLocalSubscribers(body.Topic, ns)
+ localSubs := p.getLocalSubscribers(topic, ns)
p.mu.RUnlock()
localDeliveryCount := 0
@@ -50,48 +179,38 @@ func (p *PubSubHandlers) PublishHandler(w http.ResponseWriter, r *http.Request)
select {
case sub.msgChan <- data:
localDeliveryCount++
- p.logger.ComponentDebug("gateway", "delivered to local subscriber",
- zap.String("topic", body.Topic))
default:
- // Drop if buffer full
p.logger.ComponentWarn("gateway", "local subscriber buffer full, dropping message",
- zap.String("topic", body.Topic))
+ zap.String("topic", topic))
}
}
}
p.logger.ComponentInfo("gateway", "pubsub publish: processing message",
- zap.String("topic", body.Topic),
+ zap.String("topic", topic),
zap.String("namespace", ns),
zap.Int("data_len", len(data)),
zap.Int("local_subscribers", len(localSubs)),
zap.Int("local_delivered", localDeliveryCount))
- // Fire PubSub triggers for serverless functions (non-blocking)
+ // Fire PubSub triggers for serverless functions (non-blocking).
if p.onPublish != nil {
- go p.onPublish(context.Background(), ns, body.Topic, data)
+ go p.onPublish(context.Background(), ns, topic, data)
}
+}
- // Publish to libp2p asynchronously for cross-node delivery
- // This prevents blocking the HTTP response if libp2p network is slow
- go func() {
- publishCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
- defer cancel()
+// toClientMessages converts pubsub.TopicMessage to client.TopicMessage for
+// passing through the PubSubClient interface.
+func toClientMessages(msgs []pubsub.TopicMessage) []client.TopicMessage {
+ out := make([]client.TopicMessage, len(msgs))
+ for i, m := range msgs {
+ out[i] = client.TopicMessage{Topic: m.Topic, Data: m.Data}
+ }
+ return out
+}
- ctx := pubsub.WithNamespace(client.WithInternalAuth(publishCtx), ns)
- if err := p.client.PubSub().Publish(ctx, body.Topic, data); err != nil {
- p.logger.ComponentWarn("gateway", "async libp2p publish failed",
- zap.String("topic", body.Topic),
- zap.Error(err))
- } else {
- p.logger.ComponentDebug("gateway", "async libp2p publish succeeded",
- zap.String("topic", body.Topic))
- }
- }()
-
- // Return immediately after local delivery
- // Local WebSocket subscribers already received the message
- writeJSON(w, http.StatusOK, map[string]any{"status": "ok"})
+func clientOpts(o pubsub.PublishBatchOptions) client.PublishBatchOptions {
+ return client.PublishBatchOptions{BestEffort: o.BestEffort, MaxConcurrency: o.MaxConcurrency}
}
// TopicsHandler lists topics within the caller's namespace
diff --git a/core/pkg/gateway/handlers/push/handlers.go b/core/pkg/gateway/handlers/push/handlers.go
new file mode 100644
index 0000000..36eff7b
--- /dev/null
+++ b/core/pkg/gateway/handlers/push/handlers.go
@@ -0,0 +1,291 @@
+package push
+
+import (
+ "encoding/json"
+ "net/http"
+ "strings"
+ "time"
+
+ "github.com/DeBrosOfficial/network/pkg/push"
+ "go.uber.org/zap"
+)
+
+// validProviders is the allowlist for the `provider` field on RegisterDevice.
+// Keep in sync with what the dispatcher actually has registered at startup.
+var validProviders = map[string]struct{}{
+ "ntfy": {},
+ "expo": {},
+ "apns": {}, // future — accepted at registration so apps can pre-flight
+}
+
+// MaxTokenBytes caps the device-token length to prevent abuse.
+// Real ntfy topic paths and Expo tokens are well under this.
+const MaxTokenBytes = 512
+
+// RegisterDeviceHandler handles POST /v1/push/devices.
+//
+// The caller must be authenticated; their JWT subject (Sub) is used as the
+// user_id. API-key callers are allowed only if the body explicitly carries
+// a user_id — currently rejected to keep the surface small.
+func (h *Handlers) RegisterDeviceHandler(w http.ResponseWriter, r *http.Request) {
+ if h.store == nil {
+ writeError(w, http.StatusServiceUnavailable, "push: device store not configured")
+ return
+ }
+ if r.Method != http.MethodPost {
+ writeError(w, http.StatusMethodNotAllowed, "method not allowed")
+ return
+ }
+
+ ns := resolveNamespace(r)
+ if ns == "" {
+ writeError(w, http.StatusForbidden, "namespace not resolved")
+ return
+ }
+ userID := resolveCallerUserID(r)
+ if userID == "" {
+ // We require a JWT-authenticated user to bind the device to.
+ // API-key-only callers can't register devices on behalf of users.
+ writeError(w, http.StatusUnauthorized, "user authentication required")
+ return
+ }
+
+ r.Body = http.MaxBytesReader(w, r.Body, 4096)
+ var body RegisterDeviceRequest
+ if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
+ writeError(w, http.StatusBadRequest, "invalid body")
+ return
+ }
+ body.DeviceID = strings.TrimSpace(body.DeviceID)
+ body.Provider = strings.TrimSpace(body.Provider)
+ body.Token = strings.TrimSpace(body.Token)
+
+ if body.DeviceID == "" {
+ writeError(w, http.StatusBadRequest, "device_id required")
+ return
+ }
+ if _, ok := validProviders[body.Provider]; !ok {
+ writeError(w, http.StatusBadRequest, "unknown provider: "+body.Provider)
+ return
+ }
+ if body.Token == "" {
+ writeError(w, http.StatusBadRequest, "token required")
+ return
+ }
+ if len(body.Token) > MaxTokenBytes {
+ writeError(w, http.StatusBadRequest, "token too long")
+ return
+ }
+
+ now := time.Now().Unix()
+ dev := push.PushDevice{
+ Namespace: ns,
+ UserID: userID,
+ DeviceID: body.DeviceID,
+ Provider: body.Provider,
+ Token: body.Token,
+ Platform: body.Platform,
+ AppVer: body.AppVersion,
+ LastSeen: now,
+ }
+ if err := h.store.Upsert(boundCtx(r), dev); err != nil {
+ h.logger.ComponentWarn("push", "device upsert failed",
+ zap.String("namespace", ns),
+ zap.String("user_id", userID),
+ zap.Error(err))
+ writeError(w, http.StatusInternalServerError, "registration failed")
+ return
+ }
+
+ writeJSON(w, http.StatusOK, RegisterDeviceResponse{Status: "ok"})
+}
+
+// ListDevicesHandler handles GET /v1/push/devices.
+//
+// Returns the caller's own devices; tokens are NEVER included in the
+// response. Other namespaces / other users are inaccessible.
+func (h *Handlers) ListDevicesHandler(w http.ResponseWriter, r *http.Request) {
+ if h.store == nil {
+ writeError(w, http.StatusServiceUnavailable, "push: device store not configured")
+ return
+ }
+ if r.Method != http.MethodGet {
+ writeError(w, http.StatusMethodNotAllowed, "method not allowed")
+ return
+ }
+
+ ns := resolveNamespace(r)
+ if ns == "" {
+ writeError(w, http.StatusForbidden, "namespace not resolved")
+ return
+ }
+ userID := resolveCallerUserID(r)
+ if userID == "" {
+ writeError(w, http.StatusUnauthorized, "user authentication required")
+ return
+ }
+
+ devs, err := h.store.ListForUser(boundCtx(r), ns, userID)
+ if err != nil {
+ writeError(w, http.StatusInternalServerError, "list failed")
+ return
+ }
+ views := make([]PushDeviceView, len(devs))
+ for i, d := range devs {
+ views[i] = PushDeviceView{
+ ID: d.ID,
+ DeviceID: d.DeviceID,
+ Provider: d.Provider,
+ Platform: d.Platform,
+ AppVersion: d.AppVer,
+ CreatedAt: d.CreatedAt,
+ UpdatedAt: d.UpdatedAt,
+ LastSeen: d.LastSeen,
+ }
+ }
+ writeJSON(w, http.StatusOK, map[string]interface{}{"devices": views})
+}
+
+// DeleteDeviceHandler handles DELETE /v1/push/devices/{id}.
+//
+// `{id}` is the database row ID returned at registration / by ListDevices.
+// Only devices belonging to the caller (matched by namespace + user_id +
+// the device ID lookup) can be deleted.
+func (h *Handlers) DeleteDeviceHandler(w http.ResponseWriter, r *http.Request) {
+ if h.store == nil {
+ writeError(w, http.StatusServiceUnavailable, "push: device store not configured")
+ return
+ }
+ if r.Method != http.MethodDelete {
+ writeError(w, http.StatusMethodNotAllowed, "method not allowed")
+ return
+ }
+
+ ns := resolveNamespace(r)
+ if ns == "" {
+ writeError(w, http.StatusForbidden, "namespace not resolved")
+ return
+ }
+ userID := resolveCallerUserID(r)
+ if userID == "" {
+ writeError(w, http.StatusUnauthorized, "user authentication required")
+ return
+ }
+
+ id := extractIDFromPath(r.URL.Path, "/v1/push/devices/")
+ if id == "" {
+ writeError(w, http.StatusBadRequest, "device id required in path")
+ return
+ }
+
+ // Authorization check: confirm the device belongs to the caller.
+ devs, err := h.store.ListForUser(boundCtx(r), ns, userID)
+ if err != nil {
+ writeError(w, http.StatusInternalServerError, "lookup failed")
+ return
+ }
+ owns := false
+ for _, d := range devs {
+ if d.ID == id {
+ owns = true
+ break
+ }
+ }
+ if !owns {
+ // 404, not 403 — don't leak whether the ID exists in another scope.
+ writeError(w, http.StatusNotFound, "not found")
+ return
+ }
+
+ if err := h.store.Delete(boundCtx(r), ns, id); err != nil {
+ h.logger.ComponentWarn("push", "device delete failed",
+ zap.String("namespace", ns),
+ zap.String("device_row_id", id),
+ zap.Error(err))
+ writeError(w, http.StatusInternalServerError, "delete failed")
+ return
+ }
+ writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
+}
+
+// SendHandler handles POST /v1/push/send.
+//
+// SECURITY: this endpoint sends arbitrary push messages to any user_id
+// in the caller's namespace. It MUST be gated to a small set of trusted
+// callers — typically only the namespace's own serverless functions
+// (which can send via the WASM `push_send` hostfunc directly without
+// going through HTTP) and the namespace operator.
+//
+// The current implementation accepts any JWT-authenticated caller within
+// the namespace. **Add an explicit allow-list or admin-scope check before
+// exposing this in production.** The WASM hostfunc bypasses this issue
+// because trigger registration already gates which functions exist.
+func (h *Handlers) SendHandler(w http.ResponseWriter, r *http.Request) {
+ if h.dispatcher == nil {
+ writeError(w, http.StatusServiceUnavailable, "push: dispatcher not configured")
+ return
+ }
+ if r.Method != http.MethodPost {
+ writeError(w, http.StatusMethodNotAllowed, "method not allowed")
+ return
+ }
+
+ ns := resolveNamespace(r)
+ if ns == "" {
+ writeError(w, http.StatusForbidden, "namespace not resolved")
+ return
+ }
+ if resolveCallerUserID(r) == "" {
+ writeError(w, http.StatusUnauthorized, "user authentication required")
+ return
+ }
+
+ r.Body = http.MaxBytesReader(w, r.Body, 64*1024) // generous for Data payloads
+ var body SendRequest
+ if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
+ writeError(w, http.StatusBadRequest, "invalid body")
+ return
+ }
+ body.UserID = strings.TrimSpace(body.UserID)
+ if body.UserID == "" {
+ writeError(w, http.StatusBadRequest, "user_id required")
+ return
+ }
+
+ msg := push.PushMessage{
+ Title: body.Title,
+ Body: body.Body,
+ Channel: body.Channel,
+ Priority: pickPriority(body.Priority),
+ Badge: body.Badge,
+ Sound: body.Sound,
+ Data: body.Data,
+ }
+ if err := h.dispatcher.SendToUser(boundCtx(r), ns, body.UserID, msg); err != nil {
+ // Treat as non-fatal: some devices may have failed but others may
+ // have succeeded. Surface as 502 to signal partial trouble; logs
+ // have the per-device detail.
+ h.logger.ComponentWarn("push", "send to user partially failed",
+ zap.String("namespace", ns),
+ zap.String("user_id", body.UserID),
+ zap.Error(err))
+ writeError(w, http.StatusBadGateway, "one or more devices failed")
+ return
+ }
+ writeJSON(w, http.StatusOK, SendResponse{Status: "ok"})
+}
+
+// extractIDFromPath returns the trailing path segment after `prefix`, or
+// empty string if the path doesn't match. Used because the gateway uses
+// the standard `net/http` mux which doesn't extract path params.
+func extractIDFromPath(urlPath, prefix string) string {
+ if !strings.HasPrefix(urlPath, prefix) {
+ return ""
+ }
+ rest := urlPath[len(prefix):]
+ // Drop any query string (shouldn't normally appear in path here).
+ if i := strings.IndexAny(rest, "?#/"); i >= 0 {
+ rest = rest[:i]
+ }
+ return rest
+}
diff --git a/core/pkg/gateway/handlers/push/handlers_test.go b/core/pkg/gateway/handlers/push/handlers_test.go
new file mode 100644
index 0000000..19509d4
--- /dev/null
+++ b/core/pkg/gateway/handlers/push/handlers_test.go
@@ -0,0 +1,330 @@
+package push
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "errors"
+ "net/http"
+ "net/http/httptest"
+ "sync/atomic"
+ "testing"
+
+ authsvc "github.com/DeBrosOfficial/network/pkg/gateway/auth"
+ "github.com/DeBrosOfficial/network/pkg/gateway/ctxkeys"
+ "github.com/DeBrosOfficial/network/pkg/logging"
+ "github.com/DeBrosOfficial/network/pkg/push"
+ "go.uber.org/zap"
+)
+
+// fakeStore is an in-memory PushDeviceStore for tests.
+type fakeStore struct {
+ devices []push.PushDevice
+ upsertFn func(push.PushDevice) error
+ deleteFn func(ns, id string) error
+ listErr error
+}
+
+func (s *fakeStore) Upsert(ctx context.Context, dev push.PushDevice) error {
+ if s.upsertFn != nil {
+ return s.upsertFn(dev)
+ }
+ if dev.ID == "" {
+ dev.ID = "row-" + dev.DeviceID
+ }
+ s.devices = append(s.devices, dev)
+ return nil
+}
+func (s *fakeStore) Delete(ctx context.Context, ns, id string) error {
+ if s.deleteFn != nil {
+ return s.deleteFn(ns, id)
+ }
+ for i, d := range s.devices {
+ if d.ID == id && d.Namespace == ns {
+ s.devices = append(s.devices[:i], s.devices[i+1:]...)
+ return nil
+ }
+ }
+ return errors.New("not found")
+}
+func (s *fakeStore) ListForUser(ctx context.Context, ns, userID string) ([]push.PushDevice, error) {
+ if s.listErr != nil {
+ return nil, s.listErr
+ }
+ out := []push.PushDevice{}
+ for _, d := range s.devices {
+ if d.Namespace == ns && d.UserID == userID {
+ out = append(out, d)
+ }
+ }
+ return out, nil
+}
+
+// withAuth populates the namespace + JWT claims (caller user ID).
+func withAuth(r *http.Request, namespace, userID string) *http.Request {
+ ctx := r.Context()
+ if namespace != "" {
+ ctx = context.WithValue(ctx, ctxkeys.NamespaceOverride, namespace)
+ }
+ if userID != "" {
+ ctx = context.WithValue(ctx, ctxkeys.JWT, &authsvc.JWTClaims{Sub: userID, Namespace: namespace})
+ }
+ return r.WithContext(ctx)
+}
+
+func newHandlers(store push.PushDeviceStore, dispatcher *push.PushDispatcher) *Handlers {
+ logger := &logging.ColoredLogger{Logger: zap.NewNop()}
+ return NewHandlers(dispatcher, store, logger)
+}
+
+// --- RegisterDeviceHandler ---
+
+func TestRegister_happy_path(t *testing.T) {
+ store := &fakeStore{}
+ h := newHandlers(store, nil)
+
+ body, _ := json.Marshal(RegisterDeviceRequest{
+ DeviceID: "iphone-abc",
+ Provider: "ntfy",
+ Token: "ns/myapp/user-1",
+ Platform: "ios",
+ })
+ req := withAuth(httptest.NewRequest(http.MethodPost, "/v1/push/devices", bytes.NewReader(body)), "myapp", "user-1")
+ rr := httptest.NewRecorder()
+ h.RegisterDeviceHandler(rr, req)
+
+ if rr.Code != http.StatusOK {
+ t.Fatalf("expected 200, got %d (body: %s)", rr.Code, rr.Body.String())
+ }
+ if len(store.devices) != 1 {
+ t.Fatalf("expected 1 device stored, got %d", len(store.devices))
+ }
+ d := store.devices[0]
+ if d.Namespace != "myapp" || d.UserID != "user-1" || d.Token != "ns/myapp/user-1" {
+ t.Errorf("unexpected device: %+v", d)
+ }
+}
+
+func TestRegister_unauthenticated_rejected(t *testing.T) {
+ h := newHandlers(&fakeStore{}, nil)
+ body, _ := json.Marshal(RegisterDeviceRequest{DeviceID: "x", Provider: "ntfy", Token: "t"})
+
+ // No JWT in context.
+ req := withAuth(httptest.NewRequest(http.MethodPost, "/v1/push/devices", bytes.NewReader(body)), "ns", "")
+ rr := httptest.NewRecorder()
+ h.RegisterDeviceHandler(rr, req)
+
+ if rr.Code != http.StatusUnauthorized {
+ t.Errorf("expected 401, got %d", rr.Code)
+ }
+}
+
+func TestRegister_unknown_provider_rejected(t *testing.T) {
+ h := newHandlers(&fakeStore{}, nil)
+ body, _ := json.Marshal(RegisterDeviceRequest{DeviceID: "x", Provider: "weirdmail", Token: "t"})
+ req := withAuth(httptest.NewRequest(http.MethodPost, "/v1/push/devices", bytes.NewReader(body)), "ns", "u")
+ rr := httptest.NewRecorder()
+ h.RegisterDeviceHandler(rr, req)
+
+ if rr.Code != http.StatusBadRequest {
+ t.Errorf("expected 400, got %d", rr.Code)
+ }
+}
+
+func TestRegister_oversize_token_rejected(t *testing.T) {
+ h := newHandlers(&fakeStore{}, nil)
+ huge := make([]byte, MaxTokenBytes+1)
+ for i := range huge {
+ huge[i] = 'a'
+ }
+ body, _ := json.Marshal(RegisterDeviceRequest{DeviceID: "x", Provider: "ntfy", Token: string(huge)})
+ req := withAuth(httptest.NewRequest(http.MethodPost, "/v1/push/devices", bytes.NewReader(body)), "ns", "u")
+ rr := httptest.NewRecorder()
+ h.RegisterDeviceHandler(rr, req)
+
+ if rr.Code != http.StatusBadRequest {
+ t.Errorf("expected 400, got %d", rr.Code)
+ }
+}
+
+func TestRegister_no_store_returns_503(t *testing.T) {
+ h := newHandlers(nil, nil)
+ req := withAuth(httptest.NewRequest(http.MethodPost, "/v1/push/devices", bytes.NewReader([]byte(`{}`))), "ns", "u")
+ rr := httptest.NewRecorder()
+ h.RegisterDeviceHandler(rr, req)
+ if rr.Code != http.StatusServiceUnavailable {
+ t.Errorf("expected 503, got %d", rr.Code)
+ }
+}
+
+// --- ListDevicesHandler ---
+
+func TestList_returns_only_callers_devices_without_tokens(t *testing.T) {
+ store := &fakeStore{
+ devices: []push.PushDevice{
+ {ID: "1", Namespace: "myapp", UserID: "u1", DeviceID: "d1", Provider: "ntfy", Token: "secret-token-1"},
+ {ID: "2", Namespace: "myapp", UserID: "u1", DeviceID: "d2", Provider: "expo", Token: "secret-token-2"},
+ {ID: "3", Namespace: "myapp", UserID: "u2", DeviceID: "d3", Provider: "ntfy", Token: "secret-token-3"},
+ {ID: "4", Namespace: "other", UserID: "u1", DeviceID: "d4", Provider: "ntfy", Token: "secret-token-4"},
+ },
+ }
+ h := newHandlers(store, nil)
+
+ req := withAuth(httptest.NewRequest(http.MethodGet, "/v1/push/devices", nil), "myapp", "u1")
+ rr := httptest.NewRecorder()
+ h.ListDevicesHandler(rr, req)
+
+ if rr.Code != http.StatusOK {
+ t.Fatalf("expected 200, got %d", rr.Code)
+ }
+ var resp struct {
+ Devices []PushDeviceView `json:"devices"`
+ }
+ if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
+ t.Fatalf("decode: %v", err)
+ }
+ if len(resp.Devices) != 2 {
+ t.Fatalf("expected 2 devices, got %d", len(resp.Devices))
+ }
+ // Tokens must NOT appear in response — they're not even in the struct.
+ if bytes.Contains(rr.Body.Bytes(), []byte("secret-token")) {
+ t.Errorf("response leaked a token: %s", rr.Body.String())
+ }
+}
+
+// --- DeleteDeviceHandler ---
+
+func TestDelete_owns_device_succeeds(t *testing.T) {
+ store := &fakeStore{
+ devices: []push.PushDevice{
+ {ID: "row-d1", Namespace: "myapp", UserID: "u1", DeviceID: "d1"},
+ },
+ }
+ h := newHandlers(store, nil)
+
+ req := withAuth(httptest.NewRequest(http.MethodDelete, "/v1/push/devices/row-d1", nil), "myapp", "u1")
+ rr := httptest.NewRecorder()
+ h.DeleteDeviceHandler(rr, req)
+
+ if rr.Code != http.StatusOK {
+ t.Fatalf("expected 200, got %d (body: %s)", rr.Code, rr.Body.String())
+ }
+ if len(store.devices) != 0 {
+ t.Errorf("expected device removed")
+ }
+}
+
+func TestDelete_other_users_device_returns_404(t *testing.T) {
+ store := &fakeStore{
+ devices: []push.PushDevice{
+ {ID: "row-d1", Namespace: "myapp", UserID: "other-user", DeviceID: "d1"},
+ },
+ }
+ h := newHandlers(store, nil)
+
+ req := withAuth(httptest.NewRequest(http.MethodDelete, "/v1/push/devices/row-d1", nil), "myapp", "u1")
+ rr := httptest.NewRecorder()
+ h.DeleteDeviceHandler(rr, req)
+
+ if rr.Code != http.StatusNotFound {
+ t.Errorf("expected 404, got %d", rr.Code)
+ }
+ if len(store.devices) != 1 {
+ t.Errorf("expected device NOT removed")
+ }
+}
+
+func TestDelete_missing_id_returns_400(t *testing.T) {
+ h := newHandlers(&fakeStore{}, nil)
+ req := withAuth(httptest.NewRequest(http.MethodDelete, "/v1/push/devices/", nil), "myapp", "u1")
+ rr := httptest.NewRecorder()
+ h.DeleteDeviceHandler(rr, req)
+ if rr.Code != http.StatusBadRequest {
+ t.Errorf("expected 400, got %d", rr.Code)
+ }
+}
+
+// --- SendHandler ---
+
+func TestSend_dispatcher_called_for_user(t *testing.T) {
+ var sent int32
+ dispatcher := push.New(&fakeStore{
+ devices: []push.PushDevice{
+ {ID: "row-1", Namespace: "myapp", UserID: "target-user", Provider: "fake", Token: "tok"},
+ },
+ }, zap.NewNop())
+ dispatcher.Register(&fakePushProvider{
+ name: "fake",
+ fn: func(ctx context.Context, msg push.PushMessage) error { atomic.AddInt32(&sent, 1); return nil },
+ })
+
+ h := newHandlers(&fakeStore{}, dispatcher)
+
+ body, _ := json.Marshal(SendRequest{
+ UserID: "target-user", Title: "hi", Body: "world",
+ })
+ req := withAuth(httptest.NewRequest(http.MethodPost, "/v1/push/send", bytes.NewReader(body)), "myapp", "u1")
+ rr := httptest.NewRecorder()
+ h.SendHandler(rr, req)
+
+ if rr.Code != http.StatusOK {
+ t.Fatalf("expected 200, got %d (body: %s)", rr.Code, rr.Body.String())
+ }
+ if atomic.LoadInt32(&sent) != 1 {
+ t.Errorf("expected provider called once, got %d", sent)
+ }
+}
+
+func TestSend_no_dispatcher_returns_503(t *testing.T) {
+ h := newHandlers(&fakeStore{}, nil)
+ req := withAuth(httptest.NewRequest(http.MethodPost, "/v1/push/send", bytes.NewReader([]byte(`{"user_id":"u"}`))), "myapp", "u1")
+ rr := httptest.NewRecorder()
+ h.SendHandler(rr, req)
+ if rr.Code != http.StatusServiceUnavailable {
+ t.Errorf("expected 503, got %d", rr.Code)
+ }
+}
+
+func TestSend_missing_user_id_returns_400(t *testing.T) {
+ dispatcher := push.New(&fakeStore{}, zap.NewNop())
+ h := newHandlers(&fakeStore{}, dispatcher)
+
+ body, _ := json.Marshal(SendRequest{})
+ req := withAuth(httptest.NewRequest(http.MethodPost, "/v1/push/send", bytes.NewReader(body)), "myapp", "u1")
+ rr := httptest.NewRecorder()
+ h.SendHandler(rr, req)
+ if rr.Code != http.StatusBadRequest {
+ t.Errorf("expected 400, got %d", rr.Code)
+ }
+}
+
+// --- helpers ---
+
+type fakePushProvider struct {
+ name string
+ fn func(ctx context.Context, msg push.PushMessage) error
+}
+
+func (p *fakePushProvider) Name() string { return p.name }
+func (p *fakePushProvider) Send(ctx context.Context, msg push.PushMessage) error {
+ if p.fn != nil {
+ return p.fn(ctx, msg)
+ }
+ return nil
+}
+
+func TestExtractIDFromPath(t *testing.T) {
+ cases := []struct {
+ path, prefix, want string
+ }{
+ {"/v1/push/devices/abc", "/v1/push/devices/", "abc"},
+ {"/v1/push/devices/abc?x=1", "/v1/push/devices/", "abc"},
+ {"/v1/push/devices/", "/v1/push/devices/", ""},
+ {"/v1/other/abc", "/v1/push/devices/", ""},
+ }
+ for _, c := range cases {
+ if got := extractIDFromPath(c.path, c.prefix); got != c.want {
+ t.Errorf("extractIDFromPath(%q, %q) = %q, want %q", c.path, c.prefix, got, c.want)
+ }
+ }
+}
diff --git a/core/pkg/gateway/handlers/push/types.go b/core/pkg/gateway/handlers/push/types.go
new file mode 100644
index 0000000..8410409
--- /dev/null
+++ b/core/pkg/gateway/handlers/push/types.go
@@ -0,0 +1,150 @@
+// Package push provides HTTP handlers for managing push-notification
+// device registrations and sending pushes.
+//
+// Endpoints:
+//
+// GET /v1/push/devices — list caller's registered devices (tokens omitted)
+// POST /v1/push/devices — register / update a device
+// DELETE /v1/push/devices/{id} — unregister a device
+// POST /v1/push/send — send a push to a user (admin/internal scope)
+//
+// Device tokens are stored AES-256-GCM-encrypted in RQLite via the
+// pkg/push.RqliteDeviceStore. Tokens are NEVER returned by any endpoint —
+// the GET endpoint omits the token field for safety.
+package push
+
+import (
+ "context"
+ "encoding/json"
+ "net/http"
+
+ "github.com/DeBrosOfficial/network/pkg/gateway/auth"
+ "github.com/DeBrosOfficial/network/pkg/gateway/ctxkeys"
+ "github.com/DeBrosOfficial/network/pkg/logging"
+ "github.com/DeBrosOfficial/network/pkg/push"
+)
+
+// Handlers serves the /v1/push/* HTTP endpoints. Construct via NewHandlers;
+// it's safe for concurrent use.
+type Handlers struct {
+ dispatcher *push.PushDispatcher
+ store push.PushDeviceStore
+ logger *logging.ColoredLogger
+}
+
+// NewHandlers constructs a Handlers. Either argument may be nil — in which
+// case the corresponding endpoints return 503 Service Unavailable.
+func NewHandlers(dispatcher *push.PushDispatcher, store push.PushDeviceStore, logger *logging.ColoredLogger) *Handlers {
+ return &Handlers{
+ dispatcher: dispatcher,
+ store: store,
+ logger: logger,
+ }
+}
+
+// RegisterDeviceRequest is the body of POST /v1/push/devices.
+//
+// `device_id` is an app-supplied stable identifier (e.g. the OS-assigned
+// device UUID). Combined with (namespace, user_id) it uniquely identifies
+// the registration; re-posting with the same device_id updates the token.
+//
+// `token` is provider-specific:
+// - ntfy: the topic path the device subscribes to (e.g. "ns/myapp/user-1")
+// - expo: an ExponentPushToken[...]
+// - apns: a hex APNs device token (future)
+type RegisterDeviceRequest struct {
+ DeviceID string `json:"device_id"`
+ Provider string `json:"provider"` // "ntfy" | "expo" | "apns"
+ Token string `json:"token"`
+ Platform string `json:"platform,omitempty"` // "ios" | "android" | "web"
+ AppVersion string `json:"app_version,omitempty"`
+}
+
+// RegisterDeviceResponse is the body of POST /v1/push/devices.
+type RegisterDeviceResponse struct {
+ Status string `json:"status"`
+}
+
+// PushDeviceView is the safe (token-omitting) representation returned
+// by GET /v1/push/devices.
+type PushDeviceView struct {
+ ID string `json:"id"`
+ DeviceID string `json:"device_id"`
+ Provider string `json:"provider"`
+ Platform string `json:"platform,omitempty"`
+ AppVersion string `json:"app_version,omitempty"`
+ CreatedAt int64 `json:"created_at"`
+ UpdatedAt int64 `json:"updated_at"`
+ LastSeen int64 `json:"last_seen,omitempty"`
+}
+
+// SendRequest is the body of POST /v1/push/send.
+//
+// The dispatcher fans out to all of `user_id`'s registered devices in
+// the caller's namespace. Auth scope: see SendHandler — currently
+// requires the caller to act on behalf of their own namespace; finer
+// per-user authorization is the app's responsibility.
+type SendRequest struct {
+ UserID string `json:"user_id"`
+ Title string `json:"title"`
+ Body string `json:"body"`
+ Channel string `json:"channel,omitempty"`
+ Priority string `json:"priority,omitempty"` // "high" | "normal" | "" (default)
+ Badge int `json:"badge,omitempty"`
+ Sound string `json:"sound,omitempty"`
+ Data map[string]interface{} `json:"data,omitempty"`
+}
+
+// SendResponse is the body of POST /v1/push/send.
+type SendResponse struct {
+ Status string `json:"status"`
+}
+
+// resolveNamespace pulls the namespace set by auth middleware out of context.
+func resolveNamespace(r *http.Request) string {
+ if v := r.Context().Value(ctxkeys.NamespaceOverride); v != nil {
+ if s, ok := v.(string); ok {
+ return s
+ }
+ }
+ return ""
+}
+
+// resolveCallerUserID extracts the JWT subject (typically the wallet) of
+// the caller, or empty if the request was authenticated by API key only.
+func resolveCallerUserID(r *http.Request) string {
+ if v := r.Context().Value(ctxkeys.JWT); v != nil {
+ if claims, ok := v.(*auth.JWTClaims); ok && claims != nil {
+ return claims.Sub
+ }
+ }
+ return ""
+}
+
+func writeError(w http.ResponseWriter, code int, message string) {
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(code)
+ _ = json.NewEncoder(w).Encode(map[string]string{"error": message})
+}
+
+func writeJSON(w http.ResponseWriter, code int, v interface{}) {
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(code)
+ _ = json.NewEncoder(w).Encode(v)
+}
+
+// pickPriority maps the wire-format priority string to the typed enum.
+func pickPriority(s string) push.PushPriority {
+ switch s {
+ case "high":
+ return push.PriorityHigh
+ case "normal":
+ return push.PriorityNormal
+ default:
+ return push.PriorityNormal
+ }
+}
+
+// boundCtx returns a request-scoped context with no extra wrapping;
+// kept as a seam for future scope (rate-limit context etc.).
+func boundCtx(r *http.Request) context.Context { return r.Context() }
diff --git a/core/pkg/gateway/handlers/webrtc/credentials.go b/core/pkg/gateway/handlers/webrtc/credentials.go
index 405b734..446323a 100644
--- a/core/pkg/gateway/handlers/webrtc/credentials.go
+++ b/core/pkg/gateway/handlers/webrtc/credentials.go
@@ -42,6 +42,11 @@ func (h *WebRTCHandlers) CredentialsHandler(w http.ResponseWriter, r *http.Reque
fmt.Sprintf("turns:%s:5349", h.turnDomain),
)
}
+ // Stealth: TURNS via the SNI router on :443. Looks like ordinary HTTPS
+ // to a passive observer / DPI; usable in restricted regions.
+ if h.stealthCDNDomain != "" {
+ uris = append(uris, fmt.Sprintf("turns:%s:443", h.stealthCDNDomain))
+ }
h.logger.ComponentInfo(logging.ComponentGeneral, "Issued TURN credentials",
zap.String("namespace", ns),
diff --git a/core/pkg/gateway/handlers/webrtc/types.go b/core/pkg/gateway/handlers/webrtc/types.go
index 62167f0..2580f59 100644
--- a/core/pkg/gateway/handlers/webrtc/types.go
+++ b/core/pkg/gateway/handlers/webrtc/types.go
@@ -17,10 +17,21 @@ type WebRTCHandlers struct {
turnDomain string // TURN server domain for building URIs
turnSecret string // HMAC-SHA1 shared secret for TURN credential generation
+ // stealthCDNDomain, when non-empty, causes CredentialsHandler to also
+ // advertise turns://:443 — the stealth TURN URI served
+ // via the in-house SNI router. See pkg/sniproxy.
+ stealthCDNDomain string
+
// proxyWebSocket is injected from the gateway to reuse its WebSocket proxy logic
proxyWebSocket func(w http.ResponseWriter, r *http.Request, targetHost string) bool
}
+// SetStealthCDNDomain enables the stealth TURN URI in CredentialsHandler.
+// Pass empty string to disable. Safe to call before serving begins.
+func (h *WebRTCHandlers) SetStealthCDNDomain(domain string) {
+ h.stealthCDNDomain = domain
+}
+
// NewWebRTCHandlers creates a new WebRTCHandlers instance.
func NewWebRTCHandlers(
logger *logging.ColoredLogger,
diff --git a/core/pkg/gateway/lifecycle.go b/core/pkg/gateway/lifecycle.go
index 049336d..4fc0de7 100644
--- a/core/pkg/gateway/lifecycle.go
+++ b/core/pkg/gateway/lifecycle.go
@@ -12,6 +12,23 @@ import (
// It closes the serverless engine, network client, database connections,
// Olric cache client, and IPFS client in sequence.
func (g *Gateway) Close() {
+ // Flush PubSub aggregator buffers before tearing down the engine.
+ // Pending events are dispatched via the invoker which still needs the
+ // engine to be alive, so this MUST happen before the engine close.
+ // Aggregator state is local to this node — events not flushed here are
+ // lost (intended trade-off for high-frequency lossy streams).
+ if g.pubsubDispatcher != nil {
+ if agg := g.pubsubDispatcher.Aggregator(); agg != nil {
+ // 5s budget — same as the engine close timeout below.
+ // In-flight flushes call back into the invoker which still
+ // needs the engine to be alive.
+ if !agg.Shutdown(5 * time.Second) {
+ g.logger.ComponentWarn(logging.ComponentGeneral,
+ "PubSub aggregator shutdown timed out; some buffered events may be lost")
+ }
+ }
+ }
+
// Close serverless engine first
if g.serverlessEngine != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
diff --git a/core/pkg/gateway/middleware.go b/core/pkg/gateway/middleware.go
index 1cb5a07..00aedb1 100644
--- a/core/pkg/gateway/middleware.go
+++ b/core/pkg/gateway/middleware.go
@@ -643,6 +643,9 @@ func requiresNamespaceOwnership(p string) bool {
if strings.HasPrefix(p, "/v1/webrtc/") {
return true
}
+ if strings.HasPrefix(p, "/v1/push/") {
+ return true
+ }
return false
}
diff --git a/core/pkg/gateway/routes.go b/core/pkg/gateway/routes.go
index 4d3cd08..a03762c 100644
--- a/core/pkg/gateway/routes.go
+++ b/core/pkg/gateway/routes.go
@@ -119,10 +119,30 @@ func (g *Gateway) Routes() http.Handler {
if g.pubsubHandlers != nil {
mux.HandleFunc("/v1/pubsub/ws", g.pubsubHandlers.WebsocketHandler)
mux.HandleFunc("/v1/pubsub/publish", g.pubsubHandlers.PublishHandler)
+ mux.HandleFunc("/v1/pubsub/publish-batch", g.pubsubHandlers.PublishBatchHandler)
mux.HandleFunc("/v1/pubsub/topics", g.pubsubHandlers.TopicsHandler)
mux.HandleFunc("/v1/pubsub/presence", g.pubsubHandlers.PresenceHandler)
}
+ // push notifications
+ if g.pushHandlers != nil {
+ // GET + POST share the path; the handler dispatches by method.
+ mux.HandleFunc("/v1/push/devices", func(w http.ResponseWriter, r *http.Request) {
+ switch r.Method {
+ case http.MethodGet:
+ g.pushHandlers.ListDevicesHandler(w, r)
+ case http.MethodPost:
+ g.pushHandlers.RegisterDeviceHandler(w, r)
+ default:
+ http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
+ }
+ })
+ // DELETE /v1/push/devices/{id} — uses path-prefix routing because
+ // net/http mux doesn't extract path params; the handler parses {id}.
+ mux.HandleFunc("/v1/push/devices/", g.pushHandlers.DeleteDeviceHandler)
+ mux.HandleFunc("/v1/push/send", g.pushHandlers.SendHandler)
+ }
+
// operator node management (wallet JWT auth via middleware)
if g.operatorHandler != nil {
mux.HandleFunc("/v1/operator/invite", g.operatorHandler.HandleInvite)
diff --git a/core/pkg/logging/logger.go b/core/pkg/logging/logger.go
index 4b78345..a741ae4 100644
--- a/core/pkg/logging/logger.go
+++ b/core/pkg/logging/logger.go
@@ -57,6 +57,7 @@ const (
ComponentGateway Component = "GATEWAY"
ComponentSFU Component = "SFU"
ComponentTURN Component = "TURN"
+ ComponentSNI Component = "SNI"
)
// getComponentColor returns the color for a specific component
diff --git a/core/pkg/pubsub/adapter.go b/core/pkg/pubsub/adapter.go
index de8f4c5..3e03097 100644
--- a/core/pkg/pubsub/adapter.go
+++ b/core/pkg/pubsub/adapter.go
@@ -29,6 +29,18 @@ func (a *ClientAdapter) Publish(ctx context.Context, topic string, data []byte)
return a.manager.Publish(ctx, topic, data)
}
+// PublishBatch publishes multiple messages in parallel.
+// See Manager.PublishBatch for semantics.
+func (a *ClientAdapter) PublishBatch(ctx context.Context, msgs []TopicMessage, opts PublishBatchOptions) error {
+ return a.manager.PublishBatch(ctx, msgs, opts)
+}
+
+// PublishSame sends the same payload to every topic in parallel.
+// See Manager.PublishSame for semantics.
+func (a *ClientAdapter) PublishSame(ctx context.Context, topics []string, data []byte, opts PublishBatchOptions) error {
+ return a.manager.PublishSame(ctx, topics, data, opts)
+}
+
// Unsubscribe unsubscribes from a topic
func (a *ClientAdapter) Unsubscribe(ctx context.Context, topic string) error {
return a.manager.Unsubscribe(ctx, topic)
diff --git a/core/pkg/pubsub/publish.go b/core/pkg/pubsub/publish.go
index 3fb309a..56c7163 100644
--- a/core/pkg/pubsub/publish.go
+++ b/core/pkg/pubsub/publish.go
@@ -3,9 +3,56 @@ package pubsub
import (
"context"
"fmt"
+ "strings"
+ "sync"
"time"
+
+ "golang.org/x/sync/errgroup"
)
+// defaultBatchConcurrency is the default cap on in-flight publishes within a single batch.
+const defaultBatchConcurrency = 32
+
+// MaxBatchSize is the maximum number of messages allowed per PublishBatch call.
+// The HTTP handler enforces this; the Manager itself does not, so internal callers
+// (e.g. the SDK) can pass larger batches if they accept the responsibility.
+const MaxBatchSize = 100
+
+// TopicMessage is one entry in a batch publish.
+type TopicMessage struct {
+ Topic string
+ Data []byte
+}
+
+// PublishBatchOptions controls batch publish behavior.
+type PublishBatchOptions struct {
+ // BestEffort, if true, attempts every publish even when some fail and returns
+ // a *BatchError summarizing per-topic failures. Default (false) is fail-fast:
+ // the first failure cancels remaining in-flight publishes and returns that error.
+ BestEffort bool
+
+ // MaxConcurrency caps the number of in-flight publishes within this batch.
+ // 0 means use defaultBatchConcurrency.
+ MaxConcurrency int
+}
+
+// BatchError aggregates per-topic errors returned when PublishBatch is called
+// with BestEffort=true and at least one publish failed.
+type BatchError struct {
+ Errors map[string]error // topic -> error
+}
+
+func (e *BatchError) Error() string {
+ if len(e.Errors) == 0 {
+ return "batch publish: no errors"
+ }
+ names := make([]string, 0, len(e.Errors))
+ for t := range e.Errors {
+ names = append(names, t)
+ }
+ return fmt.Sprintf("batch publish: %d topic(s) failed: %s", len(e.Errors), strings.Join(names, ", "))
+}
+
// Publish publishes a message to a topic
func (m *Manager) Publish(ctx context.Context, topic string, data []byte) error {
if m.pubsub == nil {
@@ -58,3 +105,90 @@ func (m *Manager) Publish(ctx context.Context, topic string, data []byte) error
return nil
}
+
+// PublishBatch publishes multiple messages in parallel, one per topic.
+// Default behavior is fail-fast: the first publish error cancels remaining work
+// and is returned. If opts.BestEffort is set, every publish is attempted and a
+// *BatchError is returned if any failed.
+//
+// Concurrency is bounded by opts.MaxConcurrency (default 32).
+// Empty msgs slice is a no-op and returns nil.
+func (m *Manager) PublishBatch(ctx context.Context, msgs []TopicMessage, opts PublishBatchOptions) error {
+ if len(msgs) == 0 {
+ return nil
+ }
+ if m.pubsub == nil {
+ return fmt.Errorf("pubsub not initialized")
+ }
+
+ maxConc := opts.MaxConcurrency
+ if maxConc <= 0 {
+ maxConc = defaultBatchConcurrency
+ }
+ if maxConc > len(msgs) {
+ maxConc = len(msgs)
+ }
+ sem := make(chan struct{}, maxConc)
+
+ if !opts.BestEffort {
+ g, gctx := errgroup.WithContext(ctx)
+ for _, msg := range msgs {
+ msg := msg
+ g.Go(func() error {
+ select {
+ case sem <- struct{}{}:
+ case <-gctx.Done():
+ return gctx.Err()
+ }
+ defer func() { <-sem }()
+ return m.Publish(gctx, msg.Topic, msg.Data)
+ })
+ }
+ return g.Wait()
+ }
+
+ // Best-effort path: attempt every publish, collect per-topic errors.
+ var (
+ wg sync.WaitGroup
+ errMu sync.Mutex
+ errMap = map[string]error{}
+ )
+ for _, msg := range msgs {
+ msg := msg
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ select {
+ case sem <- struct{}{}:
+ case <-ctx.Done():
+ errMu.Lock()
+ errMap[msg.Topic] = ctx.Err()
+ errMu.Unlock()
+ return
+ }
+ defer func() { <-sem }()
+ if err := m.Publish(ctx, msg.Topic, msg.Data); err != nil {
+ errMu.Lock()
+ errMap[msg.Topic] = err
+ errMu.Unlock()
+ }
+ }()
+ }
+ wg.Wait()
+ if len(errMap) > 0 {
+ return &BatchError{Errors: errMap}
+ }
+ return nil
+}
+
+// PublishSame is a convenience wrapper that sends the same payload to every topic.
+func (m *Manager) PublishSame(ctx context.Context, topics []string, data []byte, opts PublishBatchOptions) error {
+ if len(topics) == 0 {
+ return nil
+ }
+ msgs := make([]TopicMessage, len(topics))
+ for i, t := range topics {
+ msgs[i] = TopicMessage{Topic: t, Data: data}
+ }
+ return m.PublishBatch(ctx, msgs, opts)
+}
diff --git a/core/pkg/pubsub/publish_batch_test.go b/core/pkg/pubsub/publish_batch_test.go
new file mode 100644
index 0000000..4d13349
--- /dev/null
+++ b/core/pkg/pubsub/publish_batch_test.go
@@ -0,0 +1,196 @@
+package pubsub
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+func TestPublishBatch_empty_slice_returns_nil(t *testing.T) {
+ mgr, cleanup := createTestManager(t, "test-ns")
+ defer cleanup()
+
+ if err := mgr.PublishBatch(context.Background(), nil, PublishBatchOptions{}); err != nil {
+ t.Fatalf("expected nil error for empty slice, got: %v", err)
+ }
+ if err := mgr.PublishBatch(context.Background(), []TopicMessage{}, PublishBatchOptions{}); err != nil {
+ t.Fatalf("expected nil error for empty slice, got: %v", err)
+ }
+}
+
+func TestPublishBatch_happy_path(t *testing.T) {
+ mgr, cleanup := createTestManager(t, "test-ns")
+ defer cleanup()
+
+ msgs := []TopicMessage{
+ {Topic: "a", Data: []byte("data-a")},
+ {Topic: "b", Data: []byte("data-b")},
+ {Topic: "c", Data: []byte("data-c")},
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+
+ if err := mgr.PublishBatch(ctx, msgs, PublishBatchOptions{}); err != nil {
+ t.Fatalf("PublishBatch failed: %v", err)
+ }
+}
+
+func TestPublishSame_uses_same_payload(t *testing.T) {
+ mgr, cleanup := createTestManager(t, "test-ns")
+ defer cleanup()
+
+ topics := []string{"x", "y", "z"}
+ data := []byte("shared")
+
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+
+ if err := mgr.PublishSame(ctx, topics, data, PublishBatchOptions{}); err != nil {
+ t.Fatalf("PublishSame failed: %v", err)
+ }
+}
+
+func TestPublishSame_empty_returns_nil(t *testing.T) {
+ mgr, cleanup := createTestManager(t, "test-ns")
+ defer cleanup()
+ if err := mgr.PublishSame(context.Background(), nil, []byte("x"), PublishBatchOptions{}); err != nil {
+ t.Fatalf("expected nil for empty topics, got: %v", err)
+ }
+}
+
+func TestPublishBatch_context_cancel_returns_error(t *testing.T) {
+ mgr, cleanup := createTestManager(t, "test-ns")
+ defer cleanup()
+
+ msgs := make([]TopicMessage, 50)
+ for i := range msgs {
+ msgs[i] = TopicMessage{Topic: fmt.Sprintf("topic-%d", i), Data: []byte("d")}
+ }
+
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel() // cancel immediately
+
+ err := mgr.PublishBatch(ctx, msgs, PublishBatchOptions{})
+ if err == nil {
+ t.Fatal("expected context.Canceled error, got nil")
+ }
+ if !errors.Is(err, context.Canceled) {
+ t.Logf("got error (acceptable as long as it's an error): %v", err)
+ }
+}
+
+func TestPublishBatch_concurrency_limit(t *testing.T) {
+ // Verify PublishBatch with low MaxConcurrency completes without deadlocking.
+ // Each Publish in a no-peer test environment waits up to 2s for mesh formation,
+ // so we use a small batch size to keep wall time bounded.
+ mgr, cleanup := createTestManager(t, "test-ns")
+ defer cleanup()
+
+ msgs := make([]TopicMessage, 8)
+ for i := range msgs {
+ msgs[i] = TopicMessage{Topic: fmt.Sprintf("ct-%d", i), Data: []byte("d")}
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+
+ if err := mgr.PublishBatch(ctx, msgs, PublishBatchOptions{MaxConcurrency: 2}); err != nil {
+ t.Fatalf("PublishBatch with low concurrency failed: %v", err)
+ }
+}
+
+// TestPublishBatch_caps_concurrency_above_msg_count verifies that MaxConcurrency
+// is clamped to len(msgs) — passing 100 with 3 messages should not panic on
+// channel capacity.
+func TestPublishBatch_caps_concurrency_above_msg_count(t *testing.T) {
+ mgr, cleanup := createTestManager(t, "test-ns")
+ defer cleanup()
+
+ msgs := []TopicMessage{
+ {Topic: "a", Data: []byte("1")},
+ {Topic: "b", Data: []byte("2")},
+ {Topic: "c", Data: []byte("3")},
+ }
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+
+ if err := mgr.PublishBatch(ctx, msgs, PublishBatchOptions{MaxConcurrency: 100}); err != nil {
+ t.Fatalf("PublishBatch failed: %v", err)
+ }
+}
+
+func TestBatchError_Error_summarizes(t *testing.T) {
+ be := &BatchError{Errors: map[string]error{
+ "topic-a": errors.New("boom"),
+ "topic-b": errors.New("kaboom"),
+ }}
+ s := be.Error()
+ if s == "" {
+ t.Fatal("expected non-empty error string")
+ }
+ // Should mention both topics.
+ if !contains(s, "topic-a") || !contains(s, "topic-b") {
+ t.Errorf("error string %q should mention both failing topics", s)
+ }
+}
+
+func TestBatchError_Error_empty_map(t *testing.T) {
+ be := &BatchError{}
+ if s := be.Error(); s == "" {
+ t.Fatal("expected non-empty string even for empty map")
+ }
+}
+
+// contains is a tiny helper to avoid importing strings just for this.
+func contains(s, substr string) bool {
+ for i := 0; i+len(substr) <= len(s); i++ {
+ if s[i:i+len(substr)] == substr {
+ return true
+ }
+ }
+ return false
+}
+
+// TestPublishBatch_concurrent_publishes_thread_safe ensures concurrent
+// PublishBatch invocations don't race on internal state.
+func TestPublishBatch_concurrent_publishes_thread_safe(t *testing.T) {
+ mgr, cleanup := createTestManager(t, "test-ns")
+ defer cleanup()
+
+ const goroutines = 8
+ const msgsPerGoroutine = 5
+
+ ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
+ defer cancel()
+
+ var wg sync.WaitGroup
+ var failures int64
+
+ for g := 0; g < goroutines; g++ {
+ wg.Add(1)
+ go func(gid int) {
+ defer wg.Done()
+ msgs := make([]TopicMessage, msgsPerGoroutine)
+ for i := range msgs {
+ msgs[i] = TopicMessage{
+ Topic: fmt.Sprintf("g%d-t%d", gid, i),
+ Data: []byte("d"),
+ }
+ }
+ if err := mgr.PublishBatch(ctx, msgs, PublishBatchOptions{}); err != nil {
+ atomic.AddInt64(&failures, 1)
+ t.Logf("goroutine %d failed: %v", gid, err)
+ }
+ }(g)
+ }
+ wg.Wait()
+
+ if failures > 0 {
+ t.Errorf("%d concurrent batches failed", failures)
+ }
+}
diff --git a/core/pkg/push/device_store_rqlite.go b/core/pkg/push/device_store_rqlite.go
new file mode 100644
index 0000000..1f1971e
--- /dev/null
+++ b/core/pkg/push/device_store_rqlite.go
@@ -0,0 +1,172 @@
+package push
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ "github.com/DeBrosOfficial/network/pkg/rqlite"
+ "github.com/DeBrosOfficial/network/pkg/secrets"
+ "github.com/google/uuid"
+ "go.uber.org/zap"
+)
+
+// SecretsKeyPurpose is the HKDF purpose string for push token encryption.
+// Used in pkg/secrets.DeriveKey to derive a domain-separated AES key.
+const SecretsKeyPurpose = "push-device-tokens"
+
+// RqliteDeviceStore is a PushDeviceStore backed by RQLite + AES-256-GCM
+// at-rest encryption of the push token.
+type RqliteDeviceStore struct {
+ db rqlite.Client
+ encKey []byte // derived once at construction
+ logger *zap.Logger
+}
+
+// NewRqliteDeviceStore derives the per-cluster encryption key from the
+// cluster secret and returns a ready-to-use store. The cluster secret is
+// the same one used for other at-rest encryption (see pkg/secrets).
+func NewRqliteDeviceStore(db rqlite.Client, clusterSecret string, logger *zap.Logger) (*RqliteDeviceStore, error) {
+ if logger == nil {
+ logger = zap.NewNop()
+ }
+ key, err := secrets.DeriveKey(clusterSecret, SecretsKeyPurpose)
+ if err != nil {
+ return nil, fmt.Errorf("derive push-device key: %w", err)
+ }
+ return &RqliteDeviceStore{
+ db: db,
+ encKey: key,
+ logger: logger.Named("push-store"),
+ }, nil
+}
+
+// deviceRow is the scan target for SELECT queries.
+type deviceRow struct {
+ ID string
+ Namespace string
+ UserID string
+ DeviceID string
+ Provider string
+ TokenEncrypted string
+ Platform string
+ AppVersion string
+ CreatedAt int64
+ UpdatedAt int64
+ LastSeen int64
+}
+
+// Upsert implements PushDeviceStore.
+func (s *RqliteDeviceStore) Upsert(ctx context.Context, dev PushDevice) error {
+ if dev.Namespace == "" || dev.UserID == "" || dev.DeviceID == "" {
+ return fmt.Errorf("namespace, user_id, device_id required")
+ }
+ if dev.Provider == "" {
+ return fmt.Errorf("provider required")
+ }
+ if dev.Token == "" {
+ return ErrEmptyToken
+ }
+
+ encToken, err := secrets.Encrypt(dev.Token, s.encKey)
+ if err != nil {
+ return fmt.Errorf("encrypt token: %w", err)
+ }
+
+ now := time.Now().Unix()
+ if dev.CreatedAt == 0 {
+ dev.CreatedAt = now
+ }
+ dev.UpdatedAt = now
+
+ id := dev.ID
+ if id == "" {
+ id = uuid.New().String()
+ }
+
+ // SQLite UPSERT keyed on (namespace, user_id, device_id) per the migration's
+ // UNIQUE constraint. On conflict we replace token + provider + metadata
+ // while preserving the original id and created_at.
+ query := `
+ INSERT INTO push_devices
+ (id, namespace, user_id, device_id, provider, token_encrypted,
+ platform, app_version, created_at, updated_at, last_seen)
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+ ON CONFLICT(namespace, user_id, device_id) DO UPDATE SET
+ provider = excluded.provider,
+ token_encrypted = excluded.token_encrypted,
+ platform = excluded.platform,
+ app_version = excluded.app_version,
+ updated_at = excluded.updated_at,
+ last_seen = excluded.last_seen
+ `
+ _, err = s.db.Exec(ctx, query,
+ id, dev.Namespace, dev.UserID, dev.DeviceID, dev.Provider, encToken,
+ dev.Platform, dev.AppVer, dev.CreatedAt, dev.UpdatedAt, dev.LastSeen,
+ )
+ if err != nil {
+ return fmt.Errorf("upsert push device: %w", err)
+ }
+ return nil
+}
+
+// Delete implements PushDeviceStore.
+func (s *RqliteDeviceStore) Delete(ctx context.Context, namespace, id string) error {
+ if namespace == "" || id == "" {
+ return fmt.Errorf("namespace and id required")
+ }
+ query := `DELETE FROM push_devices WHERE id = ? AND namespace = ?`
+ res, err := s.db.Exec(ctx, query, id, namespace)
+ if err != nil {
+ return fmt.Errorf("delete push device: %w", err)
+ }
+ n, _ := res.RowsAffected()
+ if n == 0 {
+ return fmt.Errorf("push device not found: %s", id)
+ }
+ return nil
+}
+
+// ListForUser implements PushDeviceStore. Returns devices with decrypted tokens.
+// Caller MUST treat tokens as sensitive.
+func (s *RqliteDeviceStore) ListForUser(ctx context.Context, namespace, userID string) ([]PushDevice, error) {
+ if namespace == "" || userID == "" {
+ return nil, nil
+ }
+ query := `
+ SELECT id, namespace, user_id, device_id, provider, token_encrypted,
+ COALESCE(platform, ''), COALESCE(app_version, ''),
+ created_at, updated_at, COALESCE(last_seen, 0)
+ FROM push_devices
+ WHERE namespace = ? AND user_id = ?
+ `
+ var rows []deviceRow
+ if err := s.db.Query(ctx, &rows, query, namespace, userID); err != nil {
+ return nil, fmt.Errorf("query push devices: %w", err)
+ }
+
+ out := make([]PushDevice, 0, len(rows))
+ for _, r := range rows {
+ token, err := secrets.Decrypt(r.TokenEncrypted, s.encKey)
+ if err != nil {
+ s.logger.Warn("failed to decrypt push token; skipping device",
+ zap.String("device_id", r.DeviceID),
+ zap.Error(err))
+ continue
+ }
+ out = append(out, PushDevice{
+ ID: r.ID,
+ Namespace: r.Namespace,
+ UserID: r.UserID,
+ DeviceID: r.DeviceID,
+ Provider: r.Provider,
+ Token: token,
+ Platform: r.Platform,
+ AppVer: r.AppVersion,
+ CreatedAt: r.CreatedAt,
+ UpdatedAt: r.UpdatedAt,
+ LastSeen: r.LastSeen,
+ })
+ }
+ return out, nil
+}
diff --git a/core/pkg/push/dispatcher.go b/core/pkg/push/dispatcher.go
new file mode 100644
index 0000000..f43017d
--- /dev/null
+++ b/core/pkg/push/dispatcher.go
@@ -0,0 +1,97 @@
+package push
+
+import (
+ "context"
+ "fmt"
+ "sync"
+
+ "go.uber.org/zap"
+)
+
+// PushDispatcher routes push messages to the matching provider for each
+// of a user's registered devices.
+type PushDispatcher struct {
+ mu sync.RWMutex
+ providers map[string]PushProvider
+ devices PushDeviceStore
+ logger *zap.Logger
+}
+
+// New creates a dispatcher with the given device store. Register
+// providers before sending.
+func New(devices PushDeviceStore, logger *zap.Logger) *PushDispatcher {
+ if logger == nil {
+ logger = zap.NewNop()
+ }
+ return &PushDispatcher{
+ providers: map[string]PushProvider{},
+ devices: devices,
+ logger: logger.Named("push"),
+ }
+}
+
+// Register makes a provider available to dispatch. Calling Register with
+// the same name twice replaces the previous provider — useful in tests.
+func (d *PushDispatcher) Register(p PushProvider) {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ d.providers[p.Name()] = p
+}
+
+// Provider returns the registered provider by name, or nil.
+func (d *PushDispatcher) Provider(name string) PushProvider {
+ d.mu.RLock()
+ defer d.mu.RUnlock()
+ return d.providers[name]
+}
+
+// SendToUser fans out the message to every registered device for the
+// user. Each provider failure is logged but does not stop subsequent
+// devices. Returns the first encountered error (if any) so callers can
+// surface a partial-failure signal.
+//
+// SendToUser returns nil if the user has no registered devices — that
+// is normal, not an error.
+func (d *PushDispatcher) SendToUser(
+ ctx context.Context,
+ namespace, userID string,
+ msg PushMessage,
+) error {
+ devs, err := d.devices.ListForUser(ctx, namespace, userID)
+ if err != nil {
+ return fmt.Errorf("list devices: %w", err)
+ }
+ if len(devs) == 0 {
+ return nil
+ }
+
+ var firstErr error
+ for _, dev := range devs {
+ d.mu.RLock()
+ p, ok := d.providers[dev.Provider]
+ d.mu.RUnlock()
+ if !ok {
+ d.logger.Warn("push: dropping device with unregistered provider",
+ zap.String("provider", dev.Provider),
+ zap.String("device_id", dev.DeviceID),
+ )
+ if firstErr == nil {
+ firstErr = fmt.Errorf("%w: %s", ErrUnknownProvider, dev.Provider)
+ }
+ continue
+ }
+ m := msg
+ m.DeviceToken = dev.Token
+ if err := p.Send(ctx, m); err != nil {
+ d.logger.Warn("push: provider send failed",
+ zap.String("provider", dev.Provider),
+ zap.String("device_id", dev.DeviceID),
+ zap.Error(err),
+ )
+ if firstErr == nil {
+ firstErr = err
+ }
+ }
+ }
+ return firstErr
+}
diff --git a/core/pkg/push/dispatcher_test.go b/core/pkg/push/dispatcher_test.go
new file mode 100644
index 0000000..46c93ac
--- /dev/null
+++ b/core/pkg/push/dispatcher_test.go
@@ -0,0 +1,149 @@
+package push
+
+import (
+ "context"
+ "errors"
+ "sync/atomic"
+ "testing"
+
+ "go.uber.org/zap"
+)
+
+// fakeProvider records every Send call.
+type fakeProvider struct {
+ name string
+ sent int32
+ lastToken string
+ err error
+}
+
+func (f *fakeProvider) Name() string { return f.name }
+func (f *fakeProvider) Send(ctx context.Context, msg PushMessage) error {
+ atomic.AddInt32(&f.sent, 1)
+ f.lastToken = msg.DeviceToken
+ return f.err
+}
+
+// fakeStore is an in-memory PushDeviceStore.
+type fakeStore struct {
+ devices []PushDevice
+ err error
+}
+
+func (s *fakeStore) Upsert(ctx context.Context, dev PushDevice) error {
+ if s.err != nil {
+ return s.err
+ }
+ s.devices = append(s.devices, dev)
+ return nil
+}
+func (s *fakeStore) Delete(ctx context.Context, ns, id string) error { return nil }
+func (s *fakeStore) ListForUser(ctx context.Context, ns, userID string) ([]PushDevice, error) {
+ if s.err != nil {
+ return nil, s.err
+ }
+ out := []PushDevice{}
+ for _, d := range s.devices {
+ if d.Namespace == ns && d.UserID == userID {
+ out = append(out, d)
+ }
+ }
+ return out, nil
+}
+
+func TestSendToUser_no_devices_returns_nil(t *testing.T) {
+ d := New(&fakeStore{}, zap.NewNop())
+ if err := d.SendToUser(context.Background(), "ns", "u", PushMessage{Title: "x"}); err != nil {
+ t.Fatalf("expected nil for no devices, got: %v", err)
+ }
+}
+
+func TestSendToUser_routes_to_correct_provider(t *testing.T) {
+ store := &fakeStore{devices: []PushDevice{
+ {Namespace: "ns", UserID: "u", Provider: "ntfy", Token: "ntfy-tok"},
+ {Namespace: "ns", UserID: "u", Provider: "expo", Token: "expo-tok"},
+ }}
+ ntfy := &fakeProvider{name: "ntfy"}
+ expo := &fakeProvider{name: "expo"}
+
+ d := New(store, zap.NewNop())
+ d.Register(ntfy)
+ d.Register(expo)
+
+ if err := d.SendToUser(context.Background(), "ns", "u", PushMessage{Title: "hi"}); err != nil {
+ t.Fatalf("SendToUser: %v", err)
+ }
+
+ if atomic.LoadInt32(&ntfy.sent) != 1 || ntfy.lastToken != "ntfy-tok" {
+ t.Errorf("ntfy provider not called correctly: sent=%d token=%s", ntfy.sent, ntfy.lastToken)
+ }
+ if atomic.LoadInt32(&expo.sent) != 1 || expo.lastToken != "expo-tok" {
+ t.Errorf("expo provider not called correctly: sent=%d token=%s", expo.sent, expo.lastToken)
+ }
+}
+
+func TestSendToUser_unknown_provider_returns_error_continues(t *testing.T) {
+ store := &fakeStore{devices: []PushDevice{
+ {Namespace: "ns", UserID: "u", Provider: "ghost", Token: "tok"},
+ {Namespace: "ns", UserID: "u", Provider: "ntfy", Token: "real"},
+ }}
+ ntfy := &fakeProvider{name: "ntfy"}
+
+ d := New(store, zap.NewNop())
+ d.Register(ntfy)
+
+ err := d.SendToUser(context.Background(), "ns", "u", PushMessage{})
+ if err == nil {
+ t.Fatal("expected error for unknown provider")
+ }
+ if !errors.Is(err, ErrUnknownProvider) {
+ t.Errorf("expected ErrUnknownProvider, got %v", err)
+ }
+ // ntfy should still have been called.
+ if atomic.LoadInt32(&ntfy.sent) != 1 {
+ t.Error("ntfy should have been called for the second device")
+ }
+}
+
+func TestSendToUser_provider_failure_returned_but_other_devices_still_processed(t *testing.T) {
+ store := &fakeStore{devices: []PushDevice{
+ {Namespace: "ns", UserID: "u", Provider: "expo", Token: "tok-1"},
+ {Namespace: "ns", UserID: "u", Provider: "ntfy", Token: "tok-2"},
+ }}
+ expoErr := errors.New("expo down")
+ expo := &fakeProvider{name: "expo", err: expoErr}
+ ntfy := &fakeProvider{name: "ntfy"}
+
+ d := New(store, zap.NewNop())
+ d.Register(expo)
+ d.Register(ntfy)
+
+ err := d.SendToUser(context.Background(), "ns", "u", PushMessage{})
+ if !errors.Is(err, expoErr) {
+ t.Errorf("expected expo error, got %v", err)
+ }
+ if atomic.LoadInt32(&ntfy.sent) != 1 {
+ t.Error("ntfy should have been called even though expo failed")
+ }
+}
+
+func TestSendToUser_store_error_propagated(t *testing.T) {
+ storeErr := errors.New("store boom")
+ d := New(&fakeStore{err: storeErr}, zap.NewNop())
+
+ err := d.SendToUser(context.Background(), "ns", "u", PushMessage{})
+ if err == nil || !errors.Is(err, storeErr) {
+ t.Errorf("expected store error, got %v", err)
+ }
+}
+
+func TestRegister_replaces_existing_provider(t *testing.T) {
+ d := New(&fakeStore{}, zap.NewNop())
+ a := &fakeProvider{name: "ntfy"}
+ b := &fakeProvider{name: "ntfy"}
+ d.Register(a)
+ d.Register(b)
+ if d.Provider("ntfy") != b {
+ t.Error("expected second Register to replace the first")
+ }
+}
diff --git a/core/pkg/push/providers/expo/expo.go b/core/pkg/push/providers/expo/expo.go
new file mode 100644
index 0000000..38c95c7
--- /dev/null
+++ b/core/pkg/push/providers/expo/expo.go
@@ -0,0 +1,160 @@
+// Package expo wraps the Expo push relay as a push.PushProvider.
+//
+// This is a thin port of the legacy gateway.PushNotificationService —
+// behaviour preserved, surface adapted to the provider abstraction.
+//
+// Long term Expo is intended to be replaced with direct APNs (iOS) +
+// ntfy (Android). This provider exists so the gateway can keep using
+// Expo while the migration happens.
+package expo
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "time"
+
+ "github.com/DeBrosOfficial/network/pkg/push"
+ "go.uber.org/zap"
+)
+
+const expoAPIURL = "https://exp.host/--/api/v2/push/send"
+
+// Config holds Expo provider settings.
+type Config struct {
+ // AccessToken is an optional Expo access token. When set, it's sent
+ // as a Bearer token, which Expo uses for higher-priority delivery
+ // and to attribute the send to your account.
+ AccessToken string
+ // Timeout bounds each Send call. 0 selects 10 seconds (matching the
+ // previous PushNotificationService default).
+ Timeout time.Duration
+}
+
+// Provider is the Expo push.PushProvider implementation.
+type Provider struct {
+ accessToken string
+ httpClient *http.Client
+ logger *zap.Logger
+}
+
+// New creates a Provider with the given config.
+func New(cfg Config, logger *zap.Logger) *Provider {
+ if logger == nil {
+ logger = zap.NewNop()
+ }
+ timeout := cfg.Timeout
+ if timeout <= 0 {
+ timeout = 10 * time.Second
+ }
+ return &Provider{
+ accessToken: cfg.AccessToken,
+ httpClient: &http.Client{Timeout: timeout},
+ logger: logger.Named("expo"),
+ }
+}
+
+// Name implements push.PushProvider.
+func (p *Provider) Name() string { return "expo" }
+
+// expoMessage matches the wire format Expo expects.
+type expoMessage struct {
+ To string `json:"to"`
+ Title string `json:"title,omitempty"`
+ Body string `json:"body,omitempty"`
+ Data map[string]interface{} `json:"data,omitempty"`
+ Sound string `json:"sound,omitempty"`
+ Badge int `json:"badge,omitempty"`
+ Priority string `json:"priority,omitempty"`
+ MutableContent bool `json:"mutableContent,omitempty"`
+ ChannelID string `json:"channelId,omitempty"`
+}
+
+// expoTicket is the per-message response.
+type expoTicket struct {
+ Status string `json:"status"`
+ Message string `json:"message,omitempty"`
+}
+
+type expoResponse struct {
+ Data []expoTicket `json:"data"`
+}
+
+// Send delivers a push via the Expo relay.
+func (p *Provider) Send(ctx context.Context, msg push.PushMessage) error {
+ if msg.DeviceToken == "" {
+ return push.ErrEmptyToken
+ }
+
+ priority := "default"
+ if msg.Priority == push.PriorityHigh {
+ priority = "high"
+ }
+
+ wire := expoMessage{
+ To: msg.DeviceToken,
+ Title: msg.Title,
+ Body: msg.Body,
+ Data: msg.Data,
+ Sound: msg.Sound,
+ Badge: msg.Badge,
+ Priority: priority,
+ MutableContent: true, // for iOS Notification Service Extension
+ ChannelID: msg.Channel,
+ }
+ if wire.Sound == "" {
+ wire.Sound = "default"
+ }
+
+ body, err := json.Marshal(wire)
+ if err != nil {
+ return fmt.Errorf("expo: marshal: %w", err)
+ }
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, expoAPIURL, bytes.NewReader(body))
+ if err != nil {
+ return fmt.Errorf("expo: build request: %w", err)
+ }
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Accept", "application/json")
+ if p.accessToken != "" {
+ req.Header.Set("Authorization", "Bearer "+p.accessToken)
+ }
+
+ resp, err := p.httpClient.Do(req)
+ if err != nil {
+ return fmt.Errorf("expo: post: %w", err)
+ }
+ defer resp.Body.Close()
+
+ respBody, err := io.ReadAll(io.LimitReader(resp.Body, 16<<10))
+ if err != nil {
+ return fmt.Errorf("expo: read response: %w", err)
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ return fmt.Errorf("expo: http %d: %s", resp.StatusCode, string(respBody))
+ }
+
+ var er expoResponse
+ if err := json.Unmarshal(respBody, &er); err != nil {
+ // Older Expo responses sometimes return a bare array; try that fallback.
+ var tickets []expoTicket
+ if err2 := json.Unmarshal(respBody, &tickets); err2 == nil {
+ er.Data = tickets
+ } else {
+ return fmt.Errorf("expo: parse response: %w", err)
+ }
+ }
+
+ for _, t := range er.Data {
+ if t.Status != "" && t.Status != "ok" {
+ return fmt.Errorf("expo: ticket status %q: %s", t.Status, t.Message)
+ }
+ }
+
+ return nil
+}
diff --git a/core/pkg/push/providers/expo/expo_test.go b/core/pkg/push/providers/expo/expo_test.go
new file mode 100644
index 0000000..fa89aaa
--- /dev/null
+++ b/core/pkg/push/providers/expo/expo_test.go
@@ -0,0 +1,118 @@
+package expo
+
+import (
+ "context"
+ "encoding/json"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/DeBrosOfficial/network/pkg/push"
+)
+
+// roundTripFunc lets us mock http.Client transport for the Expo provider so
+// we can assert against requests without hitting exp.host.
+type roundTripFunc func(req *http.Request) (*http.Response, error)
+
+func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { return f(req) }
+
+func newTestProvider(rt roundTripFunc) *Provider {
+ p := New(Config{}, nil)
+ p.httpClient.Transport = rt
+ return p
+}
+
+func TestSend_empty_token_returns_ErrEmptyToken(t *testing.T) {
+ p := New(Config{}, nil)
+ err := p.Send(context.Background(), push.PushMessage{Body: "x"})
+ if err != push.ErrEmptyToken {
+ t.Errorf("expected ErrEmptyToken, got %v", err)
+ }
+}
+
+func TestSend_happy_path(t *testing.T) {
+ var gotPayload map[string]interface{}
+ var gotAuth string
+
+ p := newTestProvider(func(req *http.Request) (*http.Response, error) {
+ gotAuth = req.Header.Get("Authorization")
+ body, _ := io.ReadAll(req.Body)
+ _ = json.Unmarshal(body, &gotPayload)
+ resp := httptest.NewRecorder()
+ resp.WriteHeader(200)
+ _, _ = resp.WriteString(`{"data":[{"status":"ok"}]}`)
+ return resp.Result(), nil
+ })
+ p.accessToken = "secret-token"
+
+ err := p.Send(context.Background(), push.PushMessage{
+ DeviceToken: "ExponentPushToken[abc]",
+ Title: "T", Body: "B",
+ Priority: push.PriorityHigh,
+ })
+ if err != nil {
+ t.Fatalf("Send failed: %v", err)
+ }
+ if gotAuth != "Bearer secret-token" {
+ t.Errorf("auth header wrong: %s", gotAuth)
+ }
+ if gotPayload["to"] != "ExponentPushToken[abc]" {
+ t.Errorf("to wrong: %v", gotPayload["to"])
+ }
+ if gotPayload["priority"] != "high" {
+ t.Errorf("priority wrong: %v", gotPayload["priority"])
+ }
+}
+
+func TestSend_ticket_error_returns_error(t *testing.T) {
+ p := newTestProvider(func(req *http.Request) (*http.Response, error) {
+ resp := httptest.NewRecorder()
+ resp.WriteHeader(200)
+ _, _ = resp.WriteString(`{"data":[{"status":"error","message":"DeviceNotRegistered"}]}`)
+ return resp.Result(), nil
+ })
+ err := p.Send(context.Background(), push.PushMessage{DeviceToken: "x", Body: "y"})
+ if err == nil {
+ t.Fatal("expected error for ticket failure")
+ }
+}
+
+func TestSend_http_error_returns_error(t *testing.T) {
+ p := newTestProvider(func(req *http.Request) (*http.Response, error) {
+ resp := httptest.NewRecorder()
+ resp.WriteHeader(500)
+ _, _ = resp.WriteString(`upstream broken`)
+ return resp.Result(), nil
+ })
+ err := p.Send(context.Background(), push.PushMessage{DeviceToken: "x", Body: "y"})
+ if err == nil {
+ t.Fatal("expected error for HTTP 500")
+ }
+}
+
+func TestSend_normal_priority_maps_to_default(t *testing.T) {
+ var gotPayload map[string]interface{}
+ p := newTestProvider(func(req *http.Request) (*http.Response, error) {
+ body, _ := io.ReadAll(req.Body)
+ _ = json.Unmarshal(body, &gotPayload)
+ resp := httptest.NewRecorder()
+ resp.WriteHeader(200)
+ _, _ = resp.WriteString(`{"data":[{"status":"ok"}]}`)
+ return resp.Result(), nil
+ })
+ if err := p.Send(context.Background(), push.PushMessage{
+ DeviceToken: "x", Body: "y", Priority: push.PriorityNormal,
+ }); err != nil {
+ t.Fatal(err)
+ }
+ if gotPayload["priority"] != "default" {
+ t.Errorf("expected priority=default, got %v", gotPayload["priority"])
+ }
+}
+
+func TestName(t *testing.T) {
+ if New(Config{}, nil).Name() != "expo" {
+ t.Error("expected Name=expo")
+ }
+}
diff --git a/core/pkg/push/providers/ntfy/ntfy.go b/core/pkg/push/providers/ntfy/ntfy.go
new file mode 100644
index 0000000..adc96b6
--- /dev/null
+++ b/core/pkg/push/providers/ntfy/ntfy.go
@@ -0,0 +1,132 @@
+// Package ntfy implements a push.PushProvider backed by an ntfy server.
+//
+// ntfy delivers notifications via plain HTTP POST to /.
+// We map PushMessage fields to ntfy headers:
+// - Title -> "Title"
+// - Priority -> "Priority"
+// - Channel -> "Tags"
+// - Data -> base64-encoded JSON in "X-Data"
+//
+// See https://docs.ntfy.sh/publish/#publish-as-json for details.
+package ntfy
+
+import (
+ "context"
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "strings"
+ "time"
+
+ "github.com/DeBrosOfficial/network/pkg/push"
+ "go.uber.org/zap"
+)
+
+// Config holds per-provider settings.
+type Config struct {
+ // BaseURL is the ntfy HTTP endpoint (e.g. "http://localhost:8080" or
+ // "https://push.example.com"). Trailing slash is tolerated.
+ BaseURL string
+ // AuthToken is an optional per-namespace bearer token. Leave empty to
+ // disable authentication.
+ AuthToken string
+ // Timeout bounds each Send call. 0 selects 5 seconds.
+ Timeout time.Duration
+}
+
+// Provider is the ntfy push.PushProvider implementation.
+type Provider struct {
+ baseURL string
+ authToken string
+ httpClient *http.Client
+ logger *zap.Logger
+}
+
+// New creates a Provider with the given config.
+func New(cfg Config, logger *zap.Logger) *Provider {
+ if logger == nil {
+ logger = zap.NewNop()
+ }
+ timeout := cfg.Timeout
+ if timeout <= 0 {
+ timeout = 5 * time.Second
+ }
+ return &Provider{
+ baseURL: strings.TrimRight(cfg.BaseURL, "/"),
+ authToken: cfg.AuthToken,
+ httpClient: &http.Client{Timeout: timeout},
+ logger: logger.Named("ntfy"),
+ }
+}
+
+// Name implements push.PushProvider.
+func (p *Provider) Name() string { return "ntfy" }
+
+// Send delivers a push notification to the device's ntfy topic.
+func (p *Provider) Send(ctx context.Context, msg push.PushMessage) error {
+ if msg.DeviceToken == "" {
+ return push.ErrEmptyToken
+ }
+ if p.baseURL == "" {
+ return fmt.Errorf("ntfy: base URL not configured")
+ }
+
+ // URL-escape each path segment of the device token. ntfy topics can be
+ // hierarchical (e.g. "ns/myapp/user-1") and we want to preserve those
+ // '/' separators while escaping any other special characters that
+ // could let a malicious token escape the topic path.
+ parts := strings.Split(msg.DeviceToken, "/")
+ for i, p := range parts {
+ parts[i] = url.PathEscape(p)
+ }
+ endpointURL := p.baseURL + "/" + strings.Join(parts, "/")
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, strings.NewReader(msg.Body))
+ if err != nil {
+ return fmt.Errorf("ntfy: build request: %w", err)
+ }
+
+ if msg.Title != "" {
+ req.Header.Set("Title", msg.Title)
+ }
+ if msg.Priority == push.PriorityHigh {
+ req.Header.Set("Priority", "high")
+ } else if msg.Priority == push.PriorityNormal {
+ req.Header.Set("Priority", "default")
+ }
+ if msg.Channel != "" {
+ // ntfy uses "Tags" for both visual emoji and operator-defined tags.
+ req.Header.Set("Tags", msg.Channel)
+ }
+ if msg.Badge > 0 {
+ req.Header.Set("X-Badge", fmt.Sprintf("%d", msg.Badge))
+ }
+ if len(msg.Data) > 0 {
+ b, err := json.Marshal(msg.Data)
+ if err != nil {
+ return fmt.Errorf("ntfy: marshal data: %w", err)
+ }
+ req.Header.Set("X-Data", base64.StdEncoding.EncodeToString(b))
+ }
+ if p.authToken != "" {
+ req.Header.Set("Authorization", "Bearer "+p.authToken)
+ }
+
+ resp, err := p.httpClient.Do(req)
+ if err != nil {
+ return fmt.Errorf("ntfy: post: %w", err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode >= 400 {
+ body, _ := io.ReadAll(io.LimitReader(resp.Body, 512))
+ return fmt.Errorf("ntfy: http %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
+ }
+
+ // Drain body to allow connection reuse.
+ _, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4096))
+ return nil
+}
diff --git a/core/pkg/push/providers/ntfy/ntfy_test.go b/core/pkg/push/providers/ntfy/ntfy_test.go
new file mode 100644
index 0000000..d6f08a3
--- /dev/null
+++ b/core/pkg/push/providers/ntfy/ntfy_test.go
@@ -0,0 +1,191 @@
+package ntfy
+
+import (
+ "context"
+ "encoding/base64"
+ "encoding/json"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/DeBrosOfficial/network/pkg/push"
+)
+
+func TestSend_happy_path(t *testing.T) {
+ var (
+ gotPath string
+ gotBody string
+ gotTitle string
+ gotPriority string
+ gotAuth string
+ )
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ gotPath = r.URL.Path
+ gotTitle = r.Header.Get("Title")
+ gotPriority = r.Header.Get("Priority")
+ gotAuth = r.Header.Get("Authorization")
+ b, _ := io.ReadAll(r.Body)
+ gotBody = string(b)
+ w.WriteHeader(http.StatusOK)
+ }))
+ defer srv.Close()
+
+ p := New(Config{BaseURL: srv.URL, AuthToken: "secret"}, nil)
+ err := p.Send(context.Background(), push.PushMessage{
+ DeviceToken: "ns/myapp/user-1",
+ Title: "Hello",
+ Body: "World",
+ Priority: push.PriorityHigh,
+ })
+ if err != nil {
+ t.Fatalf("Send failed: %v", err)
+ }
+ if gotPath != "/ns/myapp/user-1" {
+ t.Errorf("expected path /ns/myapp/user-1, got %s", gotPath)
+ }
+ if gotTitle != "Hello" {
+ t.Errorf("expected Title=Hello, got %s", gotTitle)
+ }
+ if gotPriority != "high" {
+ t.Errorf("expected Priority=high, got %s", gotPriority)
+ }
+ if gotAuth != "Bearer secret" {
+ t.Errorf("expected Authorization=Bearer secret, got %s", gotAuth)
+ }
+ if gotBody != "World" {
+ t.Errorf("expected body=World, got %s", gotBody)
+ }
+}
+
+func TestSend_includes_data_header_when_data_set(t *testing.T) {
+ var gotData string
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ gotData = r.Header.Get("X-Data")
+ w.WriteHeader(http.StatusOK)
+ }))
+ defer srv.Close()
+
+ p := New(Config{BaseURL: srv.URL}, nil)
+ err := p.Send(context.Background(), push.PushMessage{
+ DeviceToken: "topic",
+ Body: "x",
+ Data: map[string]interface{}{"call_id": "abc-123"},
+ })
+ if err != nil {
+ t.Fatalf("Send: %v", err)
+ }
+ decoded, err := base64.StdEncoding.DecodeString(gotData)
+ if err != nil {
+ t.Fatalf("X-Data not valid base64: %v", err)
+ }
+ var got map[string]interface{}
+ if err := json.Unmarshal(decoded, &got); err != nil {
+ t.Fatalf("X-Data not valid JSON: %v", err)
+ }
+ if got["call_id"] != "abc-123" {
+ t.Errorf("data round-trip failed: got %v", got)
+ }
+}
+
+func TestSend_no_data_no_data_header(t *testing.T) {
+ var gotData string
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ gotData = r.Header.Get("X-Data")
+ w.WriteHeader(http.StatusOK)
+ }))
+ defer srv.Close()
+
+ p := New(Config{BaseURL: srv.URL}, nil)
+ if err := p.Send(context.Background(), push.PushMessage{DeviceToken: "t", Body: "x"}); err != nil {
+ t.Fatal(err)
+ }
+ if gotData != "" {
+ t.Errorf("expected no X-Data header, got %q", gotData)
+ }
+}
+
+func TestSend_no_auth_header_when_token_empty(t *testing.T) {
+ var gotAuth string
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ gotAuth = r.Header.Get("Authorization")
+ w.WriteHeader(http.StatusOK)
+ }))
+ defer srv.Close()
+
+ p := New(Config{BaseURL: srv.URL}, nil)
+ if err := p.Send(context.Background(), push.PushMessage{DeviceToken: "t", Body: "x"}); err != nil {
+ t.Fatal(err)
+ }
+ if gotAuth != "" {
+ t.Errorf("expected no Authorization header, got %q", gotAuth)
+ }
+}
+
+func TestSend_4xx_returns_error_with_body_excerpt(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusForbidden)
+ _, _ = w.Write([]byte("forbidden topic"))
+ }))
+ defer srv.Close()
+
+ p := New(Config{BaseURL: srv.URL}, nil)
+ err := p.Send(context.Background(), push.PushMessage{DeviceToken: "t", Body: "x"})
+ if err == nil {
+ t.Fatal("expected error for 403")
+ }
+ if !strings.Contains(err.Error(), "403") || !strings.Contains(err.Error(), "forbidden") {
+ t.Errorf("error should mention status and body, got: %v", err)
+ }
+}
+
+func TestSend_empty_token_returns_ErrEmptyToken(t *testing.T) {
+ p := New(Config{BaseURL: "http://example.invalid"}, nil)
+ err := p.Send(context.Background(), push.PushMessage{Body: "x"})
+ if err == nil {
+ t.Fatal("expected error for empty token")
+ }
+ if err != push.ErrEmptyToken {
+ t.Errorf("expected ErrEmptyToken, got %v", err)
+ }
+}
+
+func TestSend_short_timeout_returns_error(t *testing.T) {
+ // Server that blocks for 2s — provider with 100ms timeout should give up.
+ blockUntil := make(chan struct{})
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ select {
+ case <-blockUntil:
+ case <-time.After(2 * time.Second):
+ }
+ }))
+ defer func() { close(blockUntil); srv.Close() }()
+
+ p := New(Config{BaseURL: srv.URL, Timeout: 100 * time.Millisecond}, nil)
+ start := time.Now()
+ err := p.Send(context.Background(), push.PushMessage{DeviceToken: "t", Body: "x"})
+ elapsed := time.Since(start)
+ if err == nil {
+ t.Error("expected timeout error")
+ }
+ if elapsed > 1*time.Second {
+ t.Errorf("expected fast timeout, took %v", elapsed)
+ }
+}
+
+func TestSend_no_baseURL_returns_error(t *testing.T) {
+ p := New(Config{}, nil)
+ err := p.Send(context.Background(), push.PushMessage{DeviceToken: "t", Body: "x"})
+ if err == nil {
+ t.Fatal("expected error for missing base URL")
+ }
+}
+
+func TestName(t *testing.T) {
+ p := New(Config{BaseURL: "http://x"}, nil)
+ if p.Name() != "ntfy" {
+ t.Errorf("expected Name=ntfy, got %s", p.Name())
+ }
+}
diff --git a/core/pkg/push/types.go b/core/pkg/push/types.go
new file mode 100644
index 0000000..44cdbe6
--- /dev/null
+++ b/core/pkg/push/types.go
@@ -0,0 +1,91 @@
+// Package push provides a generic push-notification abstraction for Orama.
+//
+// Apps register devices with a provider name ("ntfy", "expo", "apns", ...)
+// and a provider-specific token. The PushDispatcher routes outbound push
+// messages to the matching provider so call sites stay backend-agnostic.
+//
+// Long-term the platform aims to drop Expo in favour of direct APNs +
+// ntfy. The abstraction makes that swap a configuration change rather
+// than a code change.
+package push
+
+import (
+ "context"
+ "errors"
+)
+
+// PushPriority signals delivery urgency to the provider.
+// Providers that don't support priorities ignore the value.
+type PushPriority string
+
+const (
+ PriorityNormal PushPriority = "normal"
+ PriorityHigh PushPriority = "high"
+)
+
+// PushMessage is the provider-agnostic message format.
+//
+// DeviceToken is the provider-specific identifier (e.g. an ntfy topic,
+// an Expo push token, an APNs device token). The PushDispatcher fills
+// it in per-device before calling Send.
+type PushMessage struct {
+ DeviceToken string
+ Title string
+ Body string
+ Data map[string]interface{}
+ Badge int
+ Sound string
+ Channel string // "messages", "calls", etc — provider may map to its own channel concept
+ Priority PushPriority
+}
+
+// PushProvider is implemented by each backend (ntfy, expo, apns).
+type PushProvider interface {
+ Name() string
+ // Send delivers a single push. Returning an error counts as a delivery
+ // failure for that device; the dispatcher logs it and continues.
+ Send(ctx context.Context, msg PushMessage) error
+}
+
+// PushDevice represents a registered push target for a user.
+//
+// Token is plaintext in this struct — encryption happens at the storage
+// layer. Callers who load Devices from the store must treat tokens as
+// sensitive material (don't log them).
+type PushDevice struct {
+ ID string
+ Namespace string
+ UserID string
+ DeviceID string // app-provided
+ Provider string // matches PushProvider.Name()
+ Token string
+ Platform string // "ios" | "android" | "web"
+ AppVer string
+ CreatedAt int64 // unix seconds
+ UpdatedAt int64
+ LastSeen int64
+}
+
+// PushDeviceStore persists per-user device registrations.
+type PushDeviceStore interface {
+ // Upsert registers or updates a device. The Token is encrypted by the
+ // implementation before being written to durable storage.
+ Upsert(ctx context.Context, dev PushDevice) error
+
+ // Delete removes a single device by ID, scoped to the namespace.
+ Delete(ctx context.Context, namespace, id string) error
+
+ // ListForUser returns all devices for a user within a namespace.
+ // Tokens in the returned slice are decrypted.
+ ListForUser(ctx context.Context, namespace, userID string) ([]PushDevice, error)
+}
+
+// Sentinel errors.
+var (
+ // ErrUnknownProvider is returned by the dispatcher when a device
+ // references a provider that isn't registered.
+ ErrUnknownProvider = errors.New("push: unknown provider")
+ // ErrEmptyToken is returned by providers when called with an empty
+ // DeviceToken.
+ ErrEmptyToken = errors.New("push: empty device token")
+)
diff --git a/core/pkg/serverless/aggregator/aggregator.go b/core/pkg/serverless/aggregator/aggregator.go
new file mode 100644
index 0000000..8e58ea2
--- /dev/null
+++ b/core/pkg/serverless/aggregator/aggregator.go
@@ -0,0 +1,295 @@
+// Package aggregator buffers PubSub trigger events per
+// (namespace, function, trigger) and flushes them as a single batched
+// invocation. It's used by the PubSub trigger dispatcher when a trigger
+// declares aggregation_window_ms > 0.
+//
+// State is local to each node — buffers are not replicated. This is by
+// design: aggregation is intended for high-frequency, lossy event streams
+// (presence, VAD signals, metrics). Crash recovery is not provided; an
+// orderly shutdown flushes pending buffers via Shutdown().
+package aggregator
+
+import (
+ "context"
+ "encoding/json"
+ "sync"
+ "time"
+
+ "go.uber.org/zap"
+)
+
+// DefaultMaxBatchSize is used when a trigger sets MaxBatchSize=0.
+const DefaultMaxBatchSize = 100
+
+// Event is one buffered message, mirroring the dispatcher's PubSubEvent
+// shape but kept local to avoid an import cycle.
+type Event struct {
+ Topic string `json:"topic"`
+ Data json.RawMessage `json:"data"`
+ Namespace string `json:"namespace"`
+ TriggerDepth int `json:"trigger_depth"`
+ Timestamp int64 `json:"timestamp"`
+}
+
+// BatchedPayload is what the function receives on a flush.
+// `Batched: true` lets a function differentiate single vs. batched mode
+// by parsing this discriminator first.
+type BatchedPayload struct {
+ Batched bool `json:"batched"`
+ Events []Event `json:"events"`
+}
+
+// FlushFn is invoked when a buffer flushes. It receives the marshalled
+// BatchedPayload and a context with a sane timeout. The aggregator does
+// not retry on flush errors — that's the invoker's responsibility.
+type FlushFn func(ctx context.Context, payload []byte)
+
+// FlushReason annotates why a flush happened. Useful for metrics.
+type FlushReason string
+
+const (
+ FlushReasonTimer FlushReason = "timer"
+ FlushReasonSize FlushReason = "size"
+ FlushReasonShutdown FlushReason = "shutdown"
+)
+
+// FlushFnWithReason is like FlushFn but also receives the reason.
+// Internal use; FlushFn is the simple public form.
+type FlushFnWithReason func(ctx context.Context, payload []byte, reason FlushReason)
+
+// bufferKey identifies a single in-memory buffer.
+type bufferKey struct {
+ Namespace string
+ FunctionID string
+ TriggerID string
+}
+
+type bufferEntry struct {
+ events []Event
+ timer *time.Timer
+ windowMs int
+ maxBatch int
+ flushFn FlushFnWithReason
+}
+
+// Aggregator buffers events per (namespace, function, trigger) and flushes
+// either when the window timer fires or when MaxBatch is reached.
+type Aggregator struct {
+ mu sync.Mutex
+ buffers map[bufferKey]*bufferEntry
+ logger *zap.Logger
+ flushTimeout time.Duration
+ // inflight tracks dispatched flush goroutines so Shutdown can wait
+ // for them to finish (or time out) before returning.
+ inflight sync.WaitGroup
+}
+
+// New creates an Aggregator. flushTimeout bounds the context passed to FlushFn.
+// 0 selects a sane default (60s).
+func New(logger *zap.Logger, flushTimeout time.Duration) *Aggregator {
+ if flushTimeout <= 0 {
+ flushTimeout = 60 * time.Second
+ }
+ if logger == nil {
+ logger = zap.NewNop()
+ }
+ return &Aggregator{
+ buffers: map[bufferKey]*bufferEntry{},
+ logger: logger.Named("aggregator"),
+ flushTimeout: flushTimeout,
+ }
+}
+
+// BufferRequest carries everything needed to add an event.
+type BufferRequest struct {
+ Namespace string
+ FunctionID string
+ TriggerID string
+ WindowMs int
+ MaxBatchSize int
+ FlushFn FlushFn // simple public form; internally promoted to FlushFnWithReason
+ Event Event
+}
+
+// Buffer adds an event to the matching buffer. Returns immediately —
+// the function is invoked later, asynchronously, when the window or
+// size threshold fires.
+//
+// If WindowMs <= 0, this method panics with a programming-error message
+// to surface misuse: callers should not buffer non-aggregating triggers.
+func (a *Aggregator) Buffer(req BufferRequest) {
+ if req.WindowMs <= 0 {
+ // Aggregator should never be called for non-aggregating triggers.
+ // Panicking here makes the caller bug obvious during development.
+ panic("aggregator: Buffer called with WindowMs <= 0")
+ }
+ maxBatch := req.MaxBatchSize
+ if maxBatch <= 0 {
+ maxBatch = DefaultMaxBatchSize
+ }
+
+ key := bufferKey{Namespace: req.Namespace, FunctionID: req.FunctionID, TriggerID: req.TriggerID}
+
+ a.mu.Lock()
+ defer a.mu.Unlock()
+
+ entry, ok := a.buffers[key]
+ if !ok {
+ // Promote the user-facing FlushFn into the reason-aware variant.
+ // We capture req.FlushFn so subsequent Buffer calls keep using it.
+ userFn := req.FlushFn
+ entry = &bufferEntry{
+ events: make([]Event, 0, maxBatch),
+ windowMs: req.WindowMs,
+ maxBatch: maxBatch,
+ flushFn: func(ctx context.Context, payload []byte, reason FlushReason) {
+ if userFn != nil {
+ userFn(ctx, payload)
+ }
+ },
+ }
+ a.buffers[key] = entry
+ }
+
+ entry.events = append(entry.events, req.Event)
+
+ // Start the window timer on the first event of a new window.
+ if entry.timer == nil {
+ // Capture key by value for the closure.
+ k := key
+ entry.timer = time.AfterFunc(time.Duration(entry.windowMs)*time.Millisecond, func() {
+ a.flushByTimer(k)
+ })
+ }
+
+ // Size-triggered flush.
+ if len(entry.events) >= entry.maxBatch {
+ a.flushLocked(key, entry, FlushReasonSize)
+ }
+}
+
+// flushByTimer is invoked by time.AfterFunc; it acquires the lock then flushes.
+func (a *Aggregator) flushByTimer(key bufferKey) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ entry, ok := a.buffers[key]
+ if !ok {
+ // Buffer already flushed by size threshold and the bucket removed.
+ return
+ }
+ if len(entry.events) == 0 {
+ // Defensive: empty buffer — drop it so the map stays bounded.
+ delete(a.buffers, key)
+ return
+ }
+ a.flushLocked(key, entry, FlushReasonTimer)
+}
+
+// flushLocked must be called with a.mu held. It snapshots the current
+// events, removes the bucket entry, then dispatches the flush in a
+// goroutine so the caller doesn't block on the function invocation.
+//
+// Removing the bucket on flush keeps the buffers map bounded over the
+// lifetime of the process. If a subsequent event arrives for the same
+// (namespace, function, trigger) tuple, Buffer recreates the entry.
+func (a *Aggregator) flushLocked(key bufferKey, entry *bufferEntry, reason FlushReason) {
+ if entry.timer != nil {
+ entry.timer.Stop()
+ entry.timer = nil
+ }
+ if len(entry.events) == 0 {
+ // Empty bucket — drop it so the map doesn't accumulate.
+ delete(a.buffers, key)
+ return
+ }
+ events := entry.events
+
+ payload, err := json.Marshal(BatchedPayload{Batched: true, Events: events})
+ if err != nil {
+ a.logger.Error("failed to marshal batched payload",
+ zap.String("namespace", key.Namespace),
+ zap.String("function_id", key.FunctionID),
+ zap.String("trigger_id", key.TriggerID),
+ zap.Int("batch_size", len(events)),
+ zap.Error(err),
+ )
+ // Still drop the bucket — there's no point retrying with the same data.
+ delete(a.buffers, key)
+ return
+ }
+
+ a.logger.Debug("aggregator flush",
+ zap.String("namespace", key.Namespace),
+ zap.String("function_id", key.FunctionID),
+ zap.String("trigger_id", key.TriggerID),
+ zap.Int("batch_size", len(events)),
+ zap.String("reason", string(reason)),
+ )
+
+ flushFn := entry.flushFn
+ timeout := a.flushTimeout
+ delete(a.buffers, key)
+
+ a.inflight.Add(1)
+ go func() {
+ defer a.inflight.Done()
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ defer cancel()
+ flushFn(ctx, payload, reason)
+ }()
+}
+
+// Shutdown drains all non-empty buffers and waits for the resulting flush
+// invocations to finish, bounded by `wait`. Callers should pass a wait
+// long enough to cover one function invocation (e.g. 5–10 seconds) but
+// short enough that a misbehaving function can't delay process exit.
+//
+// Returns true if all in-flight flushes completed before the deadline,
+// false on timeout (in which case some events are effectively lost).
+func (a *Aggregator) Shutdown(wait time.Duration) bool {
+ a.mu.Lock()
+ keys := make([]bufferKey, 0, len(a.buffers))
+ for k := range a.buffers {
+ keys = append(keys, k)
+ }
+ for _, k := range keys {
+ entry := a.buffers[k]
+ if entry == nil {
+ continue
+ }
+ if entry.timer != nil {
+ entry.timer.Stop()
+ entry.timer = nil
+ }
+ a.flushLocked(k, entry, FlushReasonShutdown)
+ }
+ a.mu.Unlock()
+
+ if wait <= 0 {
+ return true
+ }
+ done := make(chan struct{})
+ go func() {
+ a.inflight.Wait()
+ close(done)
+ }()
+ select {
+ case <-done:
+ return true
+ case <-time.After(wait):
+ a.logger.Warn("aggregator shutdown timed out; some buffered events may be lost")
+ return false
+ }
+}
+
+// Stats reports the current number of buffered events across all keys.
+// Useful for metrics.
+func (a *Aggregator) Stats() (numBuffers, totalEvents int) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ numBuffers = len(a.buffers)
+ for _, e := range a.buffers {
+ totalEvents += len(e.events)
+ }
+ return
+}
diff --git a/core/pkg/serverless/aggregator/aggregator_test.go b/core/pkg/serverless/aggregator/aggregator_test.go
new file mode 100644
index 0000000..523b049
--- /dev/null
+++ b/core/pkg/serverless/aggregator/aggregator_test.go
@@ -0,0 +1,307 @@
+package aggregator
+
+import (
+ "context"
+ "encoding/json"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "go.uber.org/zap"
+)
+
+func TestBuffer_panics_on_zero_window(t *testing.T) {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Fatal("expected panic when WindowMs <= 0")
+ }
+ }()
+ a := New(zap.NewNop(), time.Second)
+ a.Buffer(BufferRequest{
+ Namespace: "ns",
+ FunctionID: "fn",
+ TriggerID: "tr",
+ WindowMs: 0,
+ FlushFn: func(ctx context.Context, payload []byte) {},
+ Event: Event{Topic: "t"},
+ })
+}
+
+func TestBuffer_flushes_on_timer(t *testing.T) {
+ a := New(zap.NewNop(), 5*time.Second)
+
+ var (
+ got []Event
+ gotMu sync.Mutex
+ done = make(chan struct{})
+ )
+
+ flush := func(ctx context.Context, payload []byte) {
+ var p BatchedPayload
+ if err := json.Unmarshal(payload, &p); err != nil {
+ t.Errorf("unmarshal: %v", err)
+ }
+ gotMu.Lock()
+ got = append(got, p.Events...)
+ gotMu.Unlock()
+ close(done)
+ }
+
+ for i := 0; i < 3; i++ {
+ a.Buffer(BufferRequest{
+ Namespace: "ns",
+ FunctionID: "fn",
+ TriggerID: "tr",
+ WindowMs: 100, // short window so test runs fast
+ FlushFn: flush,
+ Event: Event{Topic: "presence:user", Data: json.RawMessage(`"e"`)},
+ })
+ }
+
+ select {
+ case <-done:
+ case <-time.After(2 * time.Second):
+ t.Fatal("flush did not fire within 2s")
+ }
+
+ gotMu.Lock()
+ defer gotMu.Unlock()
+ if len(got) != 3 {
+ t.Errorf("expected 3 buffered events, got %d", len(got))
+ }
+}
+
+func TestBuffer_flushes_on_max_batch_size(t *testing.T) {
+ a := New(zap.NewNop(), 5*time.Second)
+
+ var (
+ flushCount int32
+ flushSize int32
+ done = make(chan struct{})
+ )
+
+ flush := func(ctx context.Context, payload []byte) {
+ var p BatchedPayload
+ _ = json.Unmarshal(payload, &p)
+ atomic.AddInt32(&flushCount, 1)
+ atomic.StoreInt32(&flushSize, int32(len(p.Events)))
+ select {
+ case <-done:
+ default:
+ close(done)
+ }
+ }
+
+ for i := 0; i < 5; i++ {
+ a.Buffer(BufferRequest{
+ Namespace: "ns",
+ FunctionID: "fn",
+ TriggerID: "tr",
+ WindowMs: 30_000, // long enough that the timer won't fire
+ MaxBatchSize: 5,
+ FlushFn: flush,
+ Event: Event{Topic: "t"},
+ })
+ }
+
+ select {
+ case <-done:
+ case <-time.After(2 * time.Second):
+ t.Fatal("max-batch flush did not fire")
+ }
+
+ if atomic.LoadInt32(&flushCount) != 1 {
+ t.Errorf("expected 1 flush, got %d", flushCount)
+ }
+ if atomic.LoadInt32(&flushSize) != 5 {
+ t.Errorf("expected batch size 5, got %d", flushSize)
+ }
+}
+
+func TestBuffer_separate_keys_independent(t *testing.T) {
+ a := New(zap.NewNop(), 5*time.Second)
+
+ var (
+ mu sync.Mutex
+ counts = map[string]int{}
+ flush = func(name string) FlushFn {
+ return func(ctx context.Context, payload []byte) {
+ var p BatchedPayload
+ _ = json.Unmarshal(payload, &p)
+ mu.Lock()
+ counts[name] += len(p.Events)
+ mu.Unlock()
+ }
+ }
+ )
+
+ a.Buffer(BufferRequest{
+ Namespace: "ns", FunctionID: "fn", TriggerID: "tr-A",
+ WindowMs: 100, FlushFn: flush("A"),
+ Event: Event{Topic: "a"},
+ })
+ a.Buffer(BufferRequest{
+ Namespace: "ns", FunctionID: "fn", TriggerID: "tr-B",
+ WindowMs: 100, FlushFn: flush("B"),
+ Event: Event{Topic: "b"},
+ })
+ a.Buffer(BufferRequest{
+ Namespace: "ns", FunctionID: "fn", TriggerID: "tr-A",
+ WindowMs: 100, FlushFn: flush("A"),
+ Event: Event{Topic: "a2"},
+ })
+
+ time.Sleep(500 * time.Millisecond)
+
+ mu.Lock()
+ defer mu.Unlock()
+ if counts["A"] != 2 {
+ t.Errorf("A: expected 2 events, got %d", counts["A"])
+ }
+ if counts["B"] != 1 {
+ t.Errorf("B: expected 1 event, got %d", counts["B"])
+ }
+}
+
+func TestShutdown_flushes_all_buffers(t *testing.T) {
+ a := New(zap.NewNop(), 2*time.Second)
+
+ var flushed int32
+ flush := func(ctx context.Context, payload []byte) {
+ atomic.AddInt32(&flushed, 1)
+ }
+
+ for i := 0; i < 4; i++ {
+ a.Buffer(BufferRequest{
+ Namespace: "ns", FunctionID: "fn", TriggerID: "tr",
+ WindowMs: 30_000,
+ FlushFn: flush,
+ Event: Event{Topic: "t"},
+ })
+ }
+ // Different trigger key — should produce a separate flush.
+ a.Buffer(BufferRequest{
+ Namespace: "ns", FunctionID: "fn", TriggerID: "other",
+ WindowMs: 30_000,
+ FlushFn: flush,
+ Event: Event{Topic: "t2"},
+ })
+
+ a.Shutdown(2*time.Second)
+
+ deadline := time.Now().Add(2 * time.Second)
+ for atomic.LoadInt32(&flushed) < 2 {
+ if time.Now().After(deadline) {
+ t.Fatalf("expected 2 flushes from Shutdown, got %d", flushed)
+ }
+ time.Sleep(10 * time.Millisecond)
+ }
+}
+
+func TestShutdown_skips_empty_buffers(t *testing.T) {
+ a := New(zap.NewNop(), 2*time.Second)
+
+ var flushed int32
+ flush := func(ctx context.Context, payload []byte) {
+ atomic.AddInt32(&flushed, 1)
+ }
+
+ // Add an event to create the buffer entry, then drain via size flush.
+ a.Buffer(BufferRequest{
+ Namespace: "ns", FunctionID: "fn", TriggerID: "tr",
+ WindowMs: 30_000, MaxBatchSize: 1,
+ FlushFn: flush, Event: Event{Topic: "t"},
+ })
+
+ // Wait for the size-triggered flush to drain.
+ deadline := time.Now().Add(2 * time.Second)
+ for atomic.LoadInt32(&flushed) < 1 {
+ if time.Now().After(deadline) {
+ t.Fatal("size flush didn't fire")
+ }
+ time.Sleep(5 * time.Millisecond)
+ }
+
+ // Now the buffer is empty. Shutdown should not flush again.
+ a.Shutdown(2*time.Second)
+ time.Sleep(200 * time.Millisecond)
+ if atomic.LoadInt32(&flushed) != 1 {
+ t.Errorf("Shutdown flushed an empty buffer: total flushes %d", flushed)
+ }
+}
+
+func TestStats_reports_buffered_state(t *testing.T) {
+ a := New(zap.NewNop(), 2*time.Second)
+ flush := func(ctx context.Context, payload []byte) {}
+
+ a.Buffer(BufferRequest{Namespace: "ns", FunctionID: "fn", TriggerID: "a", WindowMs: 30_000, FlushFn: flush, Event: Event{Topic: "t"}})
+ a.Buffer(BufferRequest{Namespace: "ns", FunctionID: "fn", TriggerID: "a", WindowMs: 30_000, FlushFn: flush, Event: Event{Topic: "t"}})
+ a.Buffer(BufferRequest{Namespace: "ns", FunctionID: "fn", TriggerID: "b", WindowMs: 30_000, FlushFn: flush, Event: Event{Topic: "t"}})
+
+ bufs, evs := a.Stats()
+ if bufs != 2 {
+ t.Errorf("expected 2 buffers, got %d", bufs)
+ }
+ if evs != 3 {
+ t.Errorf("expected 3 buffered events, got %d", evs)
+ }
+}
+
+func TestBuffer_concurrent_writes_no_race(t *testing.T) {
+ // Run with -race: this should not detect any data races.
+ a := New(zap.NewNop(), 2*time.Second)
+ flush := func(ctx context.Context, payload []byte) {}
+
+ var wg sync.WaitGroup
+ for g := 0; g < 8; g++ {
+ wg.Add(1)
+ go func(gid int) {
+ defer wg.Done()
+ for i := 0; i < 50; i++ {
+ a.Buffer(BufferRequest{
+ Namespace: "ns",
+ FunctionID: "fn",
+ TriggerID: "tr",
+ WindowMs: 200,
+ FlushFn: flush,
+ Event: Event{Topic: "t"},
+ })
+ }
+ }(g)
+ }
+ wg.Wait()
+ // Drain.
+ a.Shutdown(2*time.Second)
+}
+
+func TestBuffer_payload_includes_batched_true_and_topic(t *testing.T) {
+ a := New(zap.NewNop(), 2*time.Second)
+
+ got := make(chan BatchedPayload, 1)
+ flush := func(ctx context.Context, payload []byte) {
+ var p BatchedPayload
+ if err := json.Unmarshal(payload, &p); err != nil {
+ t.Errorf("unmarshal: %v", err)
+ }
+ got <- p
+ }
+
+ a.Buffer(BufferRequest{
+ Namespace: "ns", FunctionID: "fn", TriggerID: "tr",
+ WindowMs: 50, FlushFn: flush,
+ Event: Event{Topic: "presence:user-1", Data: json.RawMessage(`{"x":1}`)},
+ })
+
+ select {
+ case p := <-got:
+ if !p.Batched {
+ t.Error("payload should have Batched=true")
+ }
+ if len(p.Events) != 1 || p.Events[0].Topic != "presence:user-1" {
+ t.Errorf("unexpected events: %+v", p.Events)
+ }
+ case <-time.After(2 * time.Second):
+ t.Fatal("flush did not fire")
+ }
+}
diff --git a/core/pkg/serverless/engine.go b/core/pkg/serverless/engine.go
index aeddc8c..2a9c3e4 100644
--- a/core/pkg/serverless/engine.go
+++ b/core/pkg/serverless/engine.go
@@ -328,6 +328,8 @@ func (e *Engine) registerHostModule(ctx context.Context) error {
NewFunctionBuilder().WithFunc(e.hCacheIncrBy).Export("cache_incr_by").
NewFunctionBuilder().WithFunc(e.hHTTPFetch).Export("http_fetch").
NewFunctionBuilder().WithFunc(e.hPubSubPublish).Export("pubsub_publish").
+ NewFunctionBuilder().WithFunc(e.hPubSubPublishBatch).Export("pubsub_publish_batch").
+ NewFunctionBuilder().WithFunc(e.hPushSend).Export("push_send").
NewFunctionBuilder().WithFunc(e.hLogInfo).Export("log_info").
NewFunctionBuilder().WithFunc(e.hLogError).Export("log_error").
Instantiate(ctx)
@@ -517,6 +519,46 @@ func (e *Engine) hPubSubPublish(ctx context.Context, mod api.Module, topicPtr, t
return 1 // Success
}
+// hPubSubPublishBatch is the WASM-callable wrapper for PubSubPublishBatch.
+// Input: pointer/length of a JSON array of {topic, data_base64}.
+// Returns 1 on success, 0 on error.
+func (e *Engine) hPubSubPublishBatch(ctx context.Context, mod api.Module, msgsPtr, msgsLen uint32) uint32 {
+ msgsJSON, ok := e.executor.ReadFromGuest(mod, msgsPtr, msgsLen)
+ if !ok {
+ return 0
+ }
+ if err := e.hostServices.PubSubPublishBatch(ctx, msgsJSON); err != nil {
+ e.logger.Error("host function pubsub_publish_batch failed", zap.Error(err))
+ return 0
+ }
+ return 1
+}
+
+// hPushSend is the WASM-callable wrapper for PushSend.
+// Inputs:
+// userIDPtr/userIDLen — UTF-8 user ID to push to (within the function's
+// own namespace; the namespace is server-side trusted)
+// msgPtr/msgLen — JSON payload matching hostfunctions.PushSendArgs
+// Returns 1 on success, 0 on error.
+func (e *Engine) hPushSend(ctx context.Context, mod api.Module,
+ userIDPtr, userIDLen, msgPtr, msgLen uint32) uint32 {
+ userID, ok := e.executor.ReadFromGuest(mod, userIDPtr, userIDLen)
+ if !ok {
+ return 0
+ }
+ msgJSON, ok := e.executor.ReadFromGuest(mod, msgPtr, msgLen)
+ if !ok {
+ return 0
+ }
+ if err := e.hostServices.PushSend(ctx, string(userID), msgJSON); err != nil {
+ e.logger.Error("host function push_send failed",
+ zap.String("user_id", string(userID)),
+ zap.Error(err))
+ return 0
+ }
+ return 1
+}
+
func (e *Engine) hLogInfo(ctx context.Context, mod api.Module, ptr, size uint32) {
msg, ok := e.executor.ReadFromGuest(mod, ptr, size)
if ok {
diff --git a/core/pkg/serverless/hostfuncs_test.go b/core/pkg/serverless/hostfuncs_test.go
index cbdbb32..93e363f 100644
--- a/core/pkg/serverless/hostfuncs_test.go
+++ b/core/pkg/serverless/hostfuncs_test.go
@@ -90,6 +90,14 @@ func (m *mockHostServices) PubSubPublish(ctx context.Context, topic string, data
return nil
}
+func (m *mockHostServices) PubSubPublishBatch(ctx context.Context, msgsJSON []byte) error {
+ return nil
+}
+
+func (m *mockHostServices) PushSend(ctx context.Context, userID string, msgJSON []byte) error {
+ return nil
+}
+
func (m *mockHostServices) WSSend(ctx context.Context, clientID string, data []byte) error {
return nil
}
diff --git a/core/pkg/serverless/hostfunctions/host_services.go b/core/pkg/serverless/hostfunctions/host_services.go
index 64f6878..069adcf 100644
--- a/core/pkg/serverless/hostfunctions/host_services.go
+++ b/core/pkg/serverless/hostfunctions/host_services.go
@@ -5,6 +5,7 @@ import (
"github.com/DeBrosOfficial/network/pkg/ipfs"
"github.com/DeBrosOfficial/network/pkg/pubsub"
+ "github.com/DeBrosOfficial/network/pkg/push"
"github.com/DeBrosOfficial/network/pkg/rqlite"
"github.com/DeBrosOfficial/network/pkg/serverless"
"github.com/DeBrosOfficial/network/pkg/tlsutil"
@@ -13,6 +14,10 @@ import (
)
// NewHostFunctions creates a new HostFunctions instance.
+//
+// pushDispatcher may be nil when push isn't configured on this gateway —
+// in that case PushSend hostfunc returns nil (silent no-op) so functions
+// remain portable across deployments with/without push.
func NewHostFunctions(
db rqlite.Client,
cacheClient olriclib.Client,
@@ -20,6 +25,7 @@ func NewHostFunctions(
pubsubAdapter *pubsub.ClientAdapter,
wsManager serverless.WebSocketManager,
secrets serverless.SecretsManager,
+ pushDispatcher *push.PushDispatcher,
cfg HostFunctionsConfig,
logger *zap.Logger,
) *HostFunctions {
@@ -29,15 +35,16 @@ func NewHostFunctions(
}
return &HostFunctions{
- db: db,
- cacheClient: cacheClient,
- storage: storage,
- ipfsAPIURL: cfg.IPFSAPIURL,
- pubsub: pubsubAdapter,
- wsManager: wsManager,
- secrets: secrets,
- httpClient: tlsutil.NewHTTPClient(httpTimeout),
- logger: logger,
- logs: make([]serverless.LogEntry, 0),
+ db: db,
+ cacheClient: cacheClient,
+ storage: storage,
+ ipfsAPIURL: cfg.IPFSAPIURL,
+ pubsub: pubsubAdapter,
+ wsManager: wsManager,
+ secrets: secrets,
+ pushDispatcher: pushDispatcher,
+ httpClient: tlsutil.NewHTTPClient(httpTimeout),
+ logger: logger,
+ logs: make([]serverless.LogEntry, 0),
}
}
diff --git a/core/pkg/serverless/hostfunctions/pubsub.go b/core/pkg/serverless/hostfunctions/pubsub.go
index 82394c1..7e9b570 100644
--- a/core/pkg/serverless/hostfunctions/pubsub.go
+++ b/core/pkg/serverless/hostfunctions/pubsub.go
@@ -2,8 +2,11 @@ package hostfunctions
import (
"context"
+ "encoding/base64"
+ "encoding/json"
"fmt"
+ "github.com/DeBrosOfficial/network/pkg/pubsub"
"github.com/DeBrosOfficial/network/pkg/serverless"
)
@@ -21,6 +24,64 @@ func (h *HostFunctions) PubSubPublish(ctx context.Context, topic string, data []
return nil
}
+// pubSubBatchEntry mirrors the JSON shape accepted by PubSubPublishBatch.
+type pubSubBatchEntry struct {
+ Topic string `json:"topic"`
+ DataB64 string `json:"data_base64"`
+}
+
+// PubSubPublishBatch publishes multiple messages in parallel.
+//
+// Input is JSON: [{"topic":"...","data_base64":"..."}, ...]
+// Up to pubsub.MaxBatchSize entries per call.
+//
+// Default behavior is fail-fast (first publish error is returned). The
+// host function does not currently expose a best-effort flag — WASM
+// callers that need it should call this function multiple times in
+// chunks they're willing to retry independently.
+func (h *HostFunctions) PubSubPublishBatch(ctx context.Context, msgsJSON []byte) error {
+ if h.pubsub == nil {
+ return &serverless.HostFunctionError{Function: "pubsub_publish_batch", Cause: fmt.Errorf("pubsub not available")}
+ }
+
+ var entries []pubSubBatchEntry
+ if err := json.Unmarshal(msgsJSON, &entries); err != nil {
+ return &serverless.HostFunctionError{Function: "pubsub_publish_batch", Cause: fmt.Errorf("invalid json: %w", err)}
+ }
+ if len(entries) == 0 {
+ return nil
+ }
+ if len(entries) > pubsub.MaxBatchSize {
+ return &serverless.HostFunctionError{
+ Function: "pubsub_publish_batch",
+ Cause: fmt.Errorf("too many messages: max %d per batch", pubsub.MaxBatchSize),
+ }
+ }
+
+ msgs := make([]pubsub.TopicMessage, 0, len(entries))
+ for i, e := range entries {
+ if e.Topic == "" {
+ return &serverless.HostFunctionError{
+ Function: "pubsub_publish_batch",
+ Cause: fmt.Errorf("entry %d: empty topic", i),
+ }
+ }
+ data, err := base64.StdEncoding.DecodeString(e.DataB64)
+ if err != nil {
+ return &serverless.HostFunctionError{
+ Function: "pubsub_publish_batch",
+ Cause: fmt.Errorf("entry %d (topic %q): bad base64: %w", i, e.Topic, err),
+ }
+ }
+ msgs = append(msgs, pubsub.TopicMessage{Topic: e.Topic, Data: data})
+ }
+
+ if err := h.pubsub.PublishBatch(ctx, msgs, pubsub.PublishBatchOptions{}); err != nil {
+ return &serverless.HostFunctionError{Function: "pubsub_publish_batch", Cause: err}
+ }
+ return nil
+}
+
// WSSend sends data to a specific WebSocket client.
func (h *HostFunctions) WSSend(ctx context.Context, clientID string, data []byte) error {
if h.wsManager == nil {
diff --git a/core/pkg/serverless/hostfunctions/push.go b/core/pkg/serverless/hostfunctions/push.go
new file mode 100644
index 0000000..dfd2638
--- /dev/null
+++ b/core/pkg/serverless/hostfunctions/push.go
@@ -0,0 +1,103 @@
+package hostfunctions
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+
+ "github.com/DeBrosOfficial/network/pkg/push"
+ "github.com/DeBrosOfficial/network/pkg/serverless"
+)
+
+// PushSendArgs is the JSON payload format the WASM caller marshals into
+// the `msgJSON` argument of PushSend. Mirrors push.PushMessage minus the
+// device-token (which is filled in per-device by the dispatcher).
+type PushSendArgs struct {
+ Title string `json:"title,omitempty"`
+ Body string `json:"body,omitempty"`
+ Channel string `json:"channel,omitempty"`
+ Priority string `json:"priority,omitempty"` // "high" | "normal" | ""
+ Badge int `json:"badge,omitempty"`
+ Sound string `json:"sound,omitempty"`
+ Data map[string]interface{} `json:"data,omitempty"`
+}
+
+// MaxPushSendArgsBytes caps the JSON arg size to a few KB. Push payloads
+// are small by construction (APNs caps at 4KB, ntfy/Expo similar).
+const MaxPushSendArgsBytes = 16 * 1024
+
+// PushSend implements serverless.HostServices.PushSend.
+//
+// Sends a push notification to all devices the user has registered in the
+// function's namespace. The caller can only target users in their own
+// namespace — the dispatcher reads the namespace from the invocation
+// context (set by the engine before invoking).
+//
+// If push is not configured on this gateway (no dispatcher), this returns
+// nil (silent no-op) so functions remain portable across environments.
+func (h *HostFunctions) PushSend(ctx context.Context, userID string, msgJSON []byte) error {
+ if h.pushDispatcher == nil {
+ // Silent no-op — push isn't configured on this gateway.
+ return nil
+ }
+ if userID == "" {
+ return &serverless.HostFunctionError{
+ Function: "push_send",
+ Cause: fmt.Errorf("user_id required"),
+ }
+ }
+ if len(msgJSON) > MaxPushSendArgsBytes {
+ return &serverless.HostFunctionError{
+ Function: "push_send",
+ Cause: fmt.Errorf("msg too large: max %d bytes", MaxPushSendArgsBytes),
+ }
+ }
+
+ var args PushSendArgs
+ if err := json.Unmarshal(msgJSON, &args); err != nil {
+ return &serverless.HostFunctionError{
+ Function: "push_send",
+ Cause: fmt.Errorf("invalid json: %w", err),
+ }
+ }
+
+ // Resolve namespace from the current invocation context. A function
+ // can NEVER push to another namespace's users — the namespace is
+ // trusted server-side, not from the WASM input.
+ h.invCtxLock.RLock()
+ var namespace string
+ if h.invCtx != nil {
+ namespace = h.invCtx.Namespace
+ }
+ h.invCtxLock.RUnlock()
+
+ if namespace == "" {
+ return &serverless.HostFunctionError{
+ Function: "push_send",
+ Cause: fmt.Errorf("no namespace in invocation context"),
+ }
+ }
+
+ priority := push.PriorityNormal
+ switch args.Priority {
+ case "high":
+ priority = push.PriorityHigh
+ case "normal", "":
+ priority = push.PriorityNormal
+ }
+
+ msg := push.PushMessage{
+ Title: args.Title,
+ Body: args.Body,
+ Channel: args.Channel,
+ Priority: priority,
+ Badge: args.Badge,
+ Sound: args.Sound,
+ Data: args.Data,
+ }
+
+ if err := h.pushDispatcher.SendToUser(ctx, namespace, userID, msg); err != nil {
+ return &serverless.HostFunctionError{Function: "push_send", Cause: err}
+ }
+ return nil
+}
diff --git a/core/pkg/serverless/hostfunctions/types.go b/core/pkg/serverless/hostfunctions/types.go
index 3df7406..28e1aea 100644
--- a/core/pkg/serverless/hostfunctions/types.go
+++ b/core/pkg/serverless/hostfunctions/types.go
@@ -7,6 +7,7 @@ import (
"github.com/DeBrosOfficial/network/pkg/ipfs"
"github.com/DeBrosOfficial/network/pkg/pubsub"
+ "github.com/DeBrosOfficial/network/pkg/push"
"github.com/DeBrosOfficial/network/pkg/rqlite"
"github.com/DeBrosOfficial/network/pkg/serverless"
olriclib "github.com/olric-data/olric"
@@ -32,6 +33,10 @@ type HostFunctions struct {
httpClient *http.Client
logger *zap.Logger
+ // pushDispatcher may be nil when push isn't configured for this gateway.
+ // In that case PushSend returns nil silently — see hostfunctions/push.go.
+ pushDispatcher *push.PushDispatcher
+
// Current invocation context (set per-execution)
invCtx *serverless.InvocationContext
invCtxLock sync.RWMutex
diff --git a/core/pkg/serverless/mocks_test.go b/core/pkg/serverless/mocks_test.go
index 2146358..94fffba 100644
--- a/core/pkg/serverless/mocks_test.go
+++ b/core/pkg/serverless/mocks_test.go
@@ -176,6 +176,14 @@ func (m *MockHostServices) PubSubPublish(ctx context.Context, topic string, data
return nil
}
+func (m *MockHostServices) PubSubPublishBatch(ctx context.Context, msgsJSON []byte) error {
+ return nil
+}
+
+func (m *MockHostServices) PushSend(ctx context.Context, userID string, msgJSON []byte) error {
+ return nil
+}
+
func (m *MockHostServices) WSSend(ctx context.Context, clientID string, data []byte) error {
return nil
}
diff --git a/core/pkg/serverless/triggers/dispatcher.go b/core/pkg/serverless/triggers/dispatcher.go
index 94e5d55..d004003 100644
--- a/core/pkg/serverless/triggers/dispatcher.go
+++ b/core/pkg/serverless/triggers/dispatcher.go
@@ -6,14 +6,12 @@ import (
"time"
"github.com/DeBrosOfficial/network/pkg/serverless"
+ "github.com/DeBrosOfficial/network/pkg/serverless/aggregator"
olriclib "github.com/olric-data/olric"
"go.uber.org/zap"
)
const (
- // triggerCacheDMap is the Olric DMap name for caching trigger lookups.
- triggerCacheDMap = "pubsub_triggers"
-
// maxTriggerDepth prevents infinite loops when triggered functions publish
// back to the same topic via the HTTP API.
maxTriggerDepth = 5
@@ -37,6 +35,7 @@ type PubSubDispatcher struct {
store *PubSubTriggerStore
invoker *serverless.Invoker
olricClient olriclib.Client // may be nil (cache disabled)
+ aggregator *aggregator.Aggregator
logger *zap.Logger
}
@@ -51,10 +50,17 @@ func NewPubSubDispatcher(
store: store,
invoker: invoker,
olricClient: olricClient,
+ aggregator: aggregator.New(logger, dispatchTimeout),
logger: logger,
}
}
+// Aggregator exposes the underlying aggregator so callers (gateway lifecycle)
+// can flush pending buffers on shutdown.
+func (d *PubSubDispatcher) Aggregator() *aggregator.Aggregator {
+ return d.aggregator
+}
+
// Dispatch looks up all triggers registered for the given topic+namespace and
// invokes matching functions asynchronously. Each invocation runs in its own
// goroutine and does not block the caller.
@@ -82,7 +88,7 @@ func (d *PubSubDispatcher) Dispatch(ctx context.Context, namespace, topic string
return
}
- // Build the event payload once for all invocations
+ // Build the per-event payload once for non-aggregating dispatches.
event := PubSubEvent{
Topic: topic,
Data: json.RawMessage(data),
@@ -90,11 +96,6 @@ func (d *PubSubDispatcher) Dispatch(ctx context.Context, namespace, topic string
TriggerDepth: depth + 1,
Timestamp: time.Now().Unix(),
}
- eventJSON, err := json.Marshal(event)
- if err != nil {
- d.logger.Error("Failed to marshal PubSub event", zap.Error(err))
- return
- }
d.logger.Debug("Dispatching PubSub triggers",
zap.String("namespace", namespace),
@@ -103,94 +104,81 @@ func (d *PubSubDispatcher) Dispatch(ctx context.Context, namespace, topic string
zap.Int("depth", depth),
)
+ var (
+ eventJSON []byte
+ marshalErr error
+ )
+
for _, match := range matches {
+ if match.AggregationWindowMs > 0 {
+ d.bufferEvent(match, event)
+ continue
+ }
+ // Lazily marshal — non-aggregating triggers need eventJSON.
+ if eventJSON == nil && marshalErr == nil {
+ eventJSON, marshalErr = json.Marshal(event)
+ if marshalErr != nil {
+ d.logger.Error("Failed to marshal PubSub event", zap.Error(marshalErr))
+ continue
+ }
+ }
+ if marshalErr != nil {
+ continue
+ }
go d.invokeFunction(match, eventJSON)
}
}
-// InvalidateCache removes the cached trigger lookup for a namespace+topic.
-// Call this when triggers are added or removed.
-func (d *PubSubDispatcher) InvalidateCache(ctx context.Context, namespace, topic string) {
- if d.olricClient == nil {
- return
- }
-
- dm, err := d.olricClient.NewDMap(triggerCacheDMap)
- if err != nil {
- d.logger.Debug("Failed to get trigger cache DMap for invalidation", zap.Error(err))
- return
- }
-
- key := cacheKey(namespace, topic)
- if _, err := dm.Delete(ctx, key); err != nil {
- d.logger.Debug("Failed to invalidate trigger cache", zap.String("key", key), zap.Error(err))
- }
+// bufferEvent routes an event through the aggregator. The flush callback
+// invokes the function with the batched payload.
+func (d *PubSubDispatcher) bufferEvent(match TriggerMatch, event PubSubEvent) {
+ d.aggregator.Buffer(aggregator.BufferRequest{
+ Namespace: match.Namespace,
+ FunctionID: match.FunctionID,
+ TriggerID: match.TriggerID,
+ WindowMs: match.AggregationWindowMs,
+ MaxBatchSize: match.AggregationMaxBatchSize,
+ Event: aggregator.Event{
+ Topic: event.Topic,
+ Data: event.Data,
+ Namespace: event.Namespace,
+ TriggerDepth: event.TriggerDepth,
+ Timestamp: event.Timestamp,
+ },
+ FlushFn: func(ctx context.Context, payload []byte) {
+ req := &serverless.InvokeRequest{
+ Namespace: match.Namespace,
+ FunctionName: match.FunctionName,
+ Input: payload,
+ TriggerType: serverless.TriggerTypePubSub,
+ }
+ if _, err := d.invoker.Invoke(ctx, req); err != nil {
+ d.logger.Warn("Aggregated PubSub invocation failed",
+ zap.String("function", match.FunctionName),
+ zap.String("trigger_id", match.TriggerID),
+ zap.Error(err),
+ )
+ }
+ },
+ })
}
-// getMatches returns the trigger matches for a topic+namespace, using Olric cache when available.
+// InvalidateCache is now a no-op — the dispatcher no longer caches lookups.
+// Kept on the type so callers who used it still compile.
+func (d *PubSubDispatcher) InvalidateCache(ctx context.Context, namespace, topic string) {}
+
+// getMatches returns the trigger matches for a topic+namespace.
+//
+// Caching note: an earlier revision cached results in Olric keyed by
+// (namespace, topic). With wildcard triggers the cache becomes
+// inconsistent — a single trigger Add/Remove invalidates an unbounded
+// number of resolved-topic keys. The cache was removed; re-introducing
+// it requires a generation-counter (or TTL) scheme that handles
+// wildcard pattern changes.
func (d *PubSubDispatcher) getMatches(ctx context.Context, namespace, topic string) ([]TriggerMatch, error) {
- // Try cache first
- if d.olricClient != nil {
- if matches, ok := d.getCached(ctx, namespace, topic); ok {
- return matches, nil
- }
- }
-
- // Cache miss — query database
- matches, err := d.store.GetByTopicAndNamespace(ctx, topic, namespace)
- if err != nil {
- return nil, err
- }
-
- // Populate cache
- if d.olricClient != nil && matches != nil {
- d.setCache(ctx, namespace, topic, matches)
- }
-
- return matches, nil
+ return d.store.GetByTopicAndNamespace(ctx, topic, namespace)
}
-// getCached attempts to retrieve trigger matches from Olric cache.
-func (d *PubSubDispatcher) getCached(ctx context.Context, namespace, topic string) ([]TriggerMatch, bool) {
- dm, err := d.olricClient.NewDMap(triggerCacheDMap)
- if err != nil {
- return nil, false
- }
-
- key := cacheKey(namespace, topic)
- result, err := dm.Get(ctx, key)
- if err != nil {
- return nil, false
- }
-
- data, err := result.Byte()
- if err != nil {
- return nil, false
- }
-
- var matches []TriggerMatch
- if err := json.Unmarshal(data, &matches); err != nil {
- return nil, false
- }
-
- return matches, true
-}
-
-// setCache stores trigger matches in Olric cache.
-func (d *PubSubDispatcher) setCache(ctx context.Context, namespace, topic string, matches []TriggerMatch) {
- dm, err := d.olricClient.NewDMap(triggerCacheDMap)
- if err != nil {
- return
- }
-
- data, err := json.Marshal(matches)
- if err != nil {
- return
- }
-
- key := cacheKey(namespace, topic)
- _ = dm.Put(ctx, key, data)
-}
// invokeFunction invokes a single function for a trigger match.
func (d *PubSubDispatcher) invokeFunction(match TriggerMatch, eventJSON []byte) {
@@ -224,7 +212,3 @@ func (d *PubSubDispatcher) invokeFunction(match TriggerMatch, eventJSON []byte)
)
}
-// cacheKey returns the Olric cache key for a namespace+topic pair.
-func cacheKey(namespace, topic string) string {
- return "triggers:" + namespace + ":" + topic
-}
diff --git a/core/pkg/serverless/triggers/pattern.go b/core/pkg/serverless/triggers/pattern.go
new file mode 100644
index 0000000..79f2fae
--- /dev/null
+++ b/core/pkg/serverless/triggers/pattern.go
@@ -0,0 +1,138 @@
+package triggers
+
+import (
+ "fmt"
+ "strings"
+)
+
+// MaxPatternLength is the maximum allowed glob pattern length.
+const MaxPatternLength = 256
+
+// ValidatePattern checks that a glob pattern is well-formed.
+// Empty patterns and unbalanced character classes return an error.
+// Patterns longer than MaxPatternLength are rejected to keep DB scans bounded.
+func ValidatePattern(p string) error {
+ if p == "" {
+ return fmt.Errorf("empty pattern")
+ }
+ if len(p) > MaxPatternLength {
+ return fmt.Errorf("pattern too long: %d > %d", len(p), MaxPatternLength)
+ }
+ open := 0
+ for i, c := range p {
+ switch c {
+ case '[':
+ open++
+ case ']':
+ open--
+ if open < 0 {
+ return fmt.Errorf("unmatched ']' at position %d", i)
+ }
+ }
+ }
+ if open != 0 {
+ return fmt.Errorf("unmatched '['")
+ }
+ return nil
+}
+
+// IsWildcard reports whether the pattern contains any glob metacharacter.
+// Useful for choosing between exact-match cache keys and wildcard scans.
+func IsWildcard(p string) bool {
+ return strings.ContainsAny(p, "*?[")
+}
+
+// PatternMatches returns true when topic matches pattern under Orama's glob
+// semantics:
+// - '*' matches zero or more characters EXCEPT ':'
+// - '**' matches zero or more characters INCLUDING ':' (deep wildcard)
+// - '?' matches exactly one character (any)
+// - '[abc]' / '[!abc]' character classes
+//
+// SQLite's GLOB is the first-pass filter (in pubsub_store.go); this
+// post-filter enforces segment boundaries for single-'*' patterns since
+// SQLite GLOB treats '*' as "any chars including separators".
+func PatternMatches(pattern, topic string) bool {
+ if strings.Contains(pattern, "**") {
+ // Deep wildcards already accept across segment boundaries — SQLite GLOB
+ // already accepted this row. No further filtering needed.
+ return true
+ }
+ return strictGlobMatch(pattern, topic)
+}
+
+// strictGlobMatch implements glob matching where '*' does NOT cross ':'.
+// Recursive backtracking matcher; bounded length keeps it cheap.
+func strictGlobMatch(pattern, s string) bool {
+ pi, si := 0, 0
+ starPi, starSi := -1, -1
+
+ for si < len(s) {
+ if pi < len(pattern) {
+ pc := pattern[pi]
+ switch pc {
+ case '?':
+ pi++
+ si++
+ continue
+ case '*':
+ // Remember position so we can backtrack.
+ starPi = pi
+ starSi = si
+ pi++
+ continue
+ case '[':
+ end := strings.IndexByte(pattern[pi+1:], ']')
+ if end < 0 {
+ return false
+ }
+ class := pattern[pi+1 : pi+1+end]
+ if matchClass(class, s[si]) {
+ pi += end + 2
+ si++
+ continue
+ }
+ default:
+ if pc == s[si] {
+ pi++
+ si++
+ continue
+ }
+ }
+ }
+ // No match at this position — try to extend the last '*' if any,
+ // but '*' must not cross a ':' segment separator.
+ if starPi >= 0 && s[starSi] != ':' {
+ starSi++
+ pi = starPi + 1
+ si = starSi
+ continue
+ }
+ return false
+ }
+
+ // Consume any trailing '*' in the pattern.
+ for pi < len(pattern) && pattern[pi] == '*' {
+ pi++
+ }
+ return pi == len(pattern)
+}
+
+// matchClass reports whether c matches the SQLite-style character class body
+// (between '[' and ']'). Supports negation with leading '!'.
+func matchClass(class string, c byte) bool {
+ if class == "" {
+ return false
+ }
+ negate := false
+ if class[0] == '!' {
+ negate = true
+ class = class[1:]
+ }
+ for i := 0; i < len(class); i++ {
+ if class[i] == c {
+ return !negate
+ }
+ }
+ return negate
+}
diff --git a/core/pkg/serverless/triggers/pattern_test.go b/core/pkg/serverless/triggers/pattern_test.go
new file mode 100644
index 0000000..84e25c8
--- /dev/null
+++ b/core/pkg/serverless/triggers/pattern_test.go
@@ -0,0 +1,162 @@
+package triggers
+
+import (
+ "strings"
+ "testing"
+)
+
+func TestValidatePattern_empty_returns_error(t *testing.T) {
+ if err := ValidatePattern(""); err == nil {
+ t.Error("expected error for empty pattern")
+ }
+}
+
+func TestValidatePattern_too_long_returns_error(t *testing.T) {
+ long := strings.Repeat("a", MaxPatternLength+1)
+ if err := ValidatePattern(long); err == nil {
+ t.Error("expected error for over-long pattern")
+ }
+}
+
+func TestValidatePattern_unbalanced_brackets_returns_error(t *testing.T) {
+ cases := []string{"a[b", "a]b", "[a[b]", "a]"}
+ for _, c := range cases {
+ if err := ValidatePattern(c); err == nil {
+ t.Errorf("expected error for %q", c)
+ }
+ }
+}
+
+func TestValidatePattern_valid_patterns_no_error(t *testing.T) {
+ cases := []string{"foo", "foo:*", "foo:**", "*.bar", "[abc]xyz", "[!a]b", "?abc"}
+ for _, c := range cases {
+ if err := ValidatePattern(c); err != nil {
+ t.Errorf("expected no error for %q, got: %v", c, err)
+ }
+ }
+}
+
+func TestIsWildcard(t *testing.T) {
+ cases := map[string]bool{
+ "foo": false,
+ "foo:bar": false,
+ "foo:*": true,
+ "foo?bar": true,
+ "[abc]xyz": true,
+ "foo:**": true,
+ "a:b:c:d:e:f": false,
+ }
+ for in, want := range cases {
+ if got := IsWildcard(in); got != want {
+ t.Errorf("IsWildcard(%q) = %v, want %v", in, got, want)
+ }
+ }
+}
+
+func TestPatternMatches_exact(t *testing.T) {
+ cases := []struct {
+ pattern, topic string
+ want bool
+ }{
+ {"foo", "foo", true},
+ {"foo", "bar", false},
+ {"foo:bar", "foo:bar", true},
+ {"foo:bar", "foo:baz", false},
+ }
+ for _, c := range cases {
+ if got := PatternMatches(c.pattern, c.topic); got != c.want {
+ t.Errorf("PatternMatches(%q, %q) = %v, want %v", c.pattern, c.topic, got, c.want)
+ }
+ }
+}
+
+func TestPatternMatches_single_star_segment_bounded(t *testing.T) {
+ cases := []struct {
+ pattern, topic string
+ want bool
+ }{
+ // '*' matches within a single segment
+ {"presence:*", "presence:user-1", true},
+ {"presence:*", "presence:user-2", true},
+ {"presence:*", "presence:", true},
+ // '*' does NOT cross ':'
+ {"presence:*", "presence:user:device", false},
+ {"a:*:b", "a:x:b", true},
+ {"a:*:b", "a:x:y:b", false},
+ // Different prefix
+ {"presence:*", "calls:invite", false},
+ }
+ for _, c := range cases {
+ if got := PatternMatches(c.pattern, c.topic); got != c.want {
+ t.Errorf("PatternMatches(%q, %q) = %v, want %v", c.pattern, c.topic, got, c.want)
+ }
+ }
+}
+
+func TestPatternMatches_double_star_crosses_segments(t *testing.T) {
+ cases := []struct {
+ pattern, topic string
+ want bool
+ }{
+ {"notify:**", "notify:user-1", true},
+ {"notify:**", "notify:user:device:1", true},
+ {"**", "anything:goes:here", true},
+ }
+ for _, c := range cases {
+ if got := PatternMatches(c.pattern, c.topic); got != c.want {
+ t.Errorf("PatternMatches(%q, %q) = %v, want %v", c.pattern, c.topic, got, c.want)
+ }
+ }
+}
+
+func TestPatternMatches_question_mark(t *testing.T) {
+ cases := []struct {
+ pattern, topic string
+ want bool
+ }{
+ {"a?c", "abc", true},
+ {"a?c", "axc", true},
+ {"a?c", "ac", false},
+ {"a?c", "abbc", false},
+ }
+ for _, c := range cases {
+ if got := PatternMatches(c.pattern, c.topic); got != c.want {
+ t.Errorf("PatternMatches(%q, %q) = %v, want %v", c.pattern, c.topic, got, c.want)
+ }
+ }
+}
+
+func TestPatternMatches_character_class(t *testing.T) {
+ cases := []struct {
+ pattern, topic string
+ want bool
+ }{
+ {"[abc]xyz", "axyz", true},
+ {"[abc]xyz", "bxyz", true},
+ {"[abc]xyz", "dxyz", false},
+ {"[!a]bc", "xbc", true},
+ {"[!a]bc", "abc", false},
+ }
+ for _, c := range cases {
+ if got := PatternMatches(c.pattern, c.topic); got != c.want {
+ t.Errorf("PatternMatches(%q, %q) = %v, want %v", c.pattern, c.topic, got, c.want)
+ }
+ }
+}
+
+func TestPatternMatches_trailing_star_with_remaining_chars(t *testing.T) {
+ // '*' can match zero characters at end.
+ cases := []struct {
+ pattern, topic string
+ want bool
+ }{
+ {"foo*", "foo", true},
+ {"foo*", "foobar", true},
+ {"foo*", "foobar:baz", false}, // ':' breaks single '*'
+ }
+ for _, c := range cases {
+ if got := PatternMatches(c.pattern, c.topic); got != c.want {
+ t.Errorf("PatternMatches(%q, %q) = %v, want %v", c.pattern, c.topic, got, c.want)
+ }
+ }
+}
diff --git a/core/pkg/serverless/triggers/pubsub_store.go b/core/pkg/serverless/triggers/pubsub_store.go
index 7ee14fb..6125339 100644
--- a/core/pkg/serverless/triggers/pubsub_store.go
+++ b/core/pkg/serverless/triggers/pubsub_store.go
@@ -16,30 +16,43 @@ import (
// TriggerMatch contains the fields needed to dispatch a trigger invocation.
// It's the result of JOINing function_pubsub_triggers with functions.
+//
+// Topic is the *resolved* topic that the published message was sent to,
+// not the pattern stored in the trigger. This lets aggregating functions
+// see which concrete topic each event came from.
+//
+// AggregationWindowMs > 0 indicates the dispatcher should buffer events
+// instead of invoking the function per event.
type TriggerMatch struct {
- TriggerID string
- FunctionID string
- FunctionName string
- Namespace string
- Topic string
+ TriggerID string
+ FunctionID string
+ FunctionName string
+ Namespace string
+ Topic string
+ AggregationWindowMs int
+ AggregationMaxBatchSize int
}
// triggerRow maps to the function_pubsub_triggers table for query scanning.
type triggerRow struct {
- ID string
- FunctionID string
- Topic string
- Enabled bool
- CreatedAt time.Time
+ ID string
+ FunctionID string
+ TopicPattern string
+ Enabled bool
+ CreatedAt time.Time
+ AggregationWindowMs int
+ AggregationMaxBatchSize int
}
// triggerMatchRow maps to the JOIN query result for scanning.
type triggerMatchRow struct {
- TriggerID string
- FunctionID string
- FunctionName string
- Namespace string
- Topic string
+ TriggerID string
+ FunctionID string
+ FunctionName string
+ Namespace string
+ TopicPattern string
+ AggregationWindowMs int
+ AggregationMaxBatchSize int
}
// PubSubTriggerStore manages PubSub trigger persistence in RQLite.
@@ -57,30 +70,60 @@ func NewPubSubTriggerStore(db rqlite.Client, logger *zap.Logger) *PubSubTriggerS
}
// Add registers a new PubSub trigger for a function.
+// `topicPattern` may be an exact topic or a SQLite GLOB pattern (e.g. "presence:*").
// Returns the trigger ID.
-func (s *PubSubTriggerStore) Add(ctx context.Context, functionID, topic string) (string, error) {
+//
+// For backward compatibility, aggregation defaults to disabled (windowMs=0).
+// Use AddWithAggregation to opt in.
+func (s *PubSubTriggerStore) Add(ctx context.Context, functionID, topicPattern string) (string, error) {
+ return s.AddWithAggregation(ctx, functionID, topicPattern, 0, 0)
+}
+
+// AddWithAggregation registers a trigger with optional aggregation.
+// - aggregationWindowMs = 0 disables aggregation (per-event invocation, default).
+// - aggregationMaxBatchSize = 0 uses the default (100) when aggregation is enabled.
+func (s *PubSubTriggerStore) AddWithAggregation(
+ ctx context.Context,
+ functionID, topicPattern string,
+ aggregationWindowMs, aggregationMaxBatchSize int,
+) (string, error) {
if functionID == "" {
return "", fmt.Errorf("function ID required")
}
- if topic == "" {
- return "", fmt.Errorf("topic required")
+ if err := ValidatePattern(topicPattern); err != nil {
+ return "", fmt.Errorf("invalid topic pattern: %w", err)
+ }
+ if aggregationWindowMs < 0 || aggregationWindowMs > 60_000 {
+ return "", fmt.Errorf("aggregation_window_ms must be between 0 and 60000")
+ }
+ if aggregationMaxBatchSize < 0 || aggregationMaxBatchSize > 1000 {
+ return "", fmt.Errorf("aggregation_max_batch_size must be between 0 and 1000")
+ }
+ if aggregationWindowMs > 0 && aggregationMaxBatchSize == 0 {
+ aggregationMaxBatchSize = 100
}
id := uuid.New().String()
now := time.Now()
+ // Write both `topic` (legacy) and `topic_pattern` (new). Keeping `topic`
+ // populated lets old binaries running concurrently during a rolling
+ // upgrade continue reading triggers. A future migration drops `topic`.
query := `
- INSERT INTO function_pubsub_triggers (id, function_id, topic, enabled, created_at)
- VALUES (?, ?, ?, TRUE, ?)
+ INSERT INTO function_pubsub_triggers (id, function_id, topic, topic_pattern, enabled, created_at, aggregation_window_ms, aggregation_max_batch_size)
+ VALUES (?, ?, ?, ?, TRUE, ?, ?, ?)
`
- if _, err := s.db.Exec(ctx, query, id, functionID, topic, now); err != nil {
+ if _, err := s.db.Exec(ctx, query, id, functionID, topicPattern, topicPattern, now, aggregationWindowMs, aggregationMaxBatchSize); err != nil {
return "", fmt.Errorf("failed to add pubsub trigger: %w", err)
}
s.logger.Info("PubSub trigger added",
zap.String("trigger_id", id),
zap.String("function_id", functionID),
- zap.String("topic", topic),
+ zap.String("topic_pattern", topicPattern),
+ zap.Bool("wildcard", IsWildcard(topicPattern)),
+ zap.Int("aggregation_window_ms", aggregationWindowMs),
+ zap.Int("aggregation_max_batch_size", aggregationMaxBatchSize),
)
return id, nil
@@ -129,7 +172,7 @@ func (s *PubSubTriggerStore) ListByFunction(ctx context.Context, functionID stri
}
query := `
- SELECT id, function_id, topic, enabled, created_at
+ SELECT id, function_id, topic_pattern, enabled, created_at, aggregation_window_ms, aggregation_max_batch_size
FROM function_pubsub_triggers
WHERE function_id = ?
`
@@ -142,18 +185,22 @@ func (s *PubSubTriggerStore) ListByFunction(ctx context.Context, functionID stri
triggers := make([]serverless.PubSubTrigger, len(rows))
for i, row := range rows {
triggers[i] = serverless.PubSubTrigger{
- ID: row.ID,
- FunctionID: row.FunctionID,
- Topic: row.Topic,
- Enabled: row.Enabled,
+ ID: row.ID,
+ FunctionID: row.FunctionID,
+ Topic: row.TopicPattern,
+ Enabled: row.Enabled,
+ AggregationWindowMs: row.AggregationWindowMs,
+ AggregationMaxBatchSize: row.AggregationMaxBatchSize,
}
}
return triggers, nil
}
-// GetByTopicAndNamespace returns all enabled triggers for a topic within a namespace.
-// Only returns triggers for active functions.
+// GetByTopicAndNamespace returns all enabled triggers whose topic_pattern
+// matches `topic` within the namespace. Patterns are SQLite GLOB; the
+// post-filter enforces stricter segment-aware semantics.
+// Only triggers for active functions are returned.
func (s *PubSubTriggerStore) GetByTopicAndNamespace(ctx context.Context, topic, namespace string) ([]TriggerMatch, error) {
if topic == "" || namespace == "" {
return nil, nil
@@ -161,10 +208,12 @@ func (s *PubSubTriggerStore) GetByTopicAndNamespace(ctx context.Context, topic,
query := `
SELECT t.id AS trigger_id, t.function_id AS function_id,
- f.name AS function_name, f.namespace AS namespace, t.topic AS topic
+ f.name AS function_name, f.namespace AS namespace, t.topic_pattern AS topic_pattern,
+ t.aggregation_window_ms AS aggregation_window_ms,
+ t.aggregation_max_batch_size AS aggregation_max_batch_size
FROM function_pubsub_triggers t
JOIN functions f ON t.function_id = f.id
- WHERE t.topic = ? AND f.namespace = ? AND t.enabled = TRUE AND f.status = 'active'
+ WHERE ? GLOB t.topic_pattern AND f.namespace = ? AND t.enabled = TRUE AND f.status = 'active'
`
var rows []triggerMatchRow
@@ -172,15 +221,21 @@ func (s *PubSubTriggerStore) GetByTopicAndNamespace(ctx context.Context, topic,
return nil, fmt.Errorf("failed to query triggers for topic: %w", err)
}
- matches := make([]TriggerMatch, len(rows))
- for i, row := range rows {
- matches[i] = TriggerMatch{
- TriggerID: row.TriggerID,
- FunctionID: row.FunctionID,
- FunctionName: row.FunctionName,
- Namespace: row.Namespace,
- Topic: row.Topic,
+ matches := make([]TriggerMatch, 0, len(rows))
+ for _, row := range rows {
+ // Post-filter to enforce strict segment boundaries on '*'.
+ if !PatternMatches(row.TopicPattern, topic) {
+ continue
}
+ matches = append(matches, TriggerMatch{
+ TriggerID: row.TriggerID,
+ FunctionID: row.FunctionID,
+ FunctionName: row.FunctionName,
+ Namespace: row.Namespace,
+ Topic: topic, // resolved topic, not the pattern
+ AggregationWindowMs: row.AggregationWindowMs,
+ AggregationMaxBatchSize: row.AggregationMaxBatchSize,
+ })
}
return matches, nil
diff --git a/core/pkg/serverless/triggers/triggers_test.go b/core/pkg/serverless/triggers/triggers_test.go
index a9822cc..e2662f4 100644
--- a/core/pkg/serverless/triggers/triggers_test.go
+++ b/core/pkg/serverless/triggers/triggers_test.go
@@ -40,13 +40,6 @@ func TestDispatcher_DepthLimit(t *testing.T) {
d.Dispatch(context.Background(), "ns", "topic", []byte("data"), maxTriggerDepth+1)
}
-func TestCacheKey(t *testing.T) {
- key := cacheKey("my-namespace", "my-topic")
- if key != "triggers:my-namespace:my-topic" {
- t.Errorf("unexpected cache key: %s", key)
- }
-}
-
func TestPubSubEvent_Marshal(t *testing.T) {
event := PubSubEvent{
Topic: "chat",
diff --git a/core/pkg/serverless/types.go b/core/pkg/serverless/types.go
index 66a13f7..716f297 100644
--- a/core/pkg/serverless/types.go
+++ b/core/pkg/serverless/types.go
@@ -288,11 +288,21 @@ type DBTrigger struct {
}
// PubSubTrigger represents a pubsub trigger.
+//
+// Topic may be an exact topic name or a SQLite GLOB pattern (e.g.
+// "presence:*"). See pkg/serverless/triggers/pattern.go for matching rules.
+//
+// AggregationWindowMs > 0 enables event buffering: the dispatcher accumulates
+// events for at most that many milliseconds (or until AggregationMaxBatchSize
+// events have been collected, whichever comes first), then invokes the
+// function once with a batched payload of type BatchedPubSubEvent.
type PubSubTrigger struct {
- ID string `json:"id"`
- FunctionID string `json:"function_id"`
- Topic string `json:"topic"`
- Enabled bool `json:"enabled"`
+ ID string `json:"id"`
+ FunctionID string `json:"function_id"`
+ Topic string `json:"topic"`
+ Enabled bool `json:"enabled"`
+ AggregationWindowMs int `json:"aggregation_window_ms,omitempty"`
+ AggregationMaxBatchSize int `json:"aggregation_max_batch_size,omitempty"`
}
// Timer represents a one-time scheduled execution.
@@ -337,6 +347,14 @@ type HostServices interface {
// PubSub operations
PubSubPublish(ctx context.Context, topic string, data []byte) error
+ PubSubPublishBatch(ctx context.Context, msgsJSON []byte) error
+
+ // Push notifications. Sends to all of `userID`'s registered devices in
+ // the function's namespace. `msgJSON` is the JSON-encoded PushSendArgs
+ // shape (see hostfunctions.PushSend). Returns nil if push is not
+ // configured (silent no-op) so functions can be portable across
+ // namespaces with/without push enabled.
+ PushSend(ctx context.Context, userID string, msgJSON []byte) error
// WebSocket operations (only valid in WS context)
WSSend(ctx context.Context, clientID string, data []byte) error
diff --git a/core/pkg/sniproxy/router.go b/core/pkg/sniproxy/router.go
new file mode 100644
index 0000000..a2f521b
--- /dev/null
+++ b/core/pkg/sniproxy/router.go
@@ -0,0 +1,106 @@
+package sniproxy
+
+import (
+ "strings"
+ "sync"
+)
+
+// Backend describes where to forward a connection.
+type Backend struct {
+ // Name is for logs/metrics only. Optional.
+ Name string
+ // Network is the dial network ("tcp", "tcp4", "tcp6"). Default "tcp".
+ Network string
+ // Addr is the dial target ("127.0.0.1:5349").
+ Addr string
+}
+
+// Route maps an SNI value (or wildcard pattern) to a Backend.
+//
+// Match semantics:
+// - "example.com" matches exactly "example.com"
+// - "*.example.com" matches any single-label subdomain ("a.example.com"
+// but not "a.b.example.com" — single-label like DNS wildcards)
+type Route struct {
+ Match string
+ Backend Backend
+}
+
+// Router atomically swaps a routing table while concurrent reads are in
+// flight. Reads are lock-free after the slice is published.
+type Router struct {
+ mu sync.RWMutex
+ routes []Route
+ fallback Backend
+}
+
+// NewRouter creates a router with no routes and the given fallback.
+func NewRouter(fallback Backend) *Router {
+ return &Router{fallback: fallback}
+}
+
+// Pick returns the matching backend for an SNI value, or the fallback if
+// no route matches (or if sni is empty).
+func (r *Router) Pick(sni string) Backend {
+ if sni == "" {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ return r.fallback
+ }
+ sni = strings.ToLower(sni)
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ for _, route := range r.routes {
+ if matchSNI(route.Match, sni) {
+ return route.Backend
+ }
+ }
+ return r.fallback
+}
+
+// Replace atomically swaps the routing table. The new routes replace the
+// old ones in their entirety; partial updates are not supported.
+func (r *Router) Replace(routes []Route, fallback Backend) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ r.routes = routes
+ r.fallback = fallback
+}
+
+// Routes returns a defensive copy of the current routes. For introspection.
+func (r *Router) Routes() []Route {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ out := make([]Route, len(r.routes))
+ copy(out, r.routes)
+ return out
+}
+
+// Fallback returns the current fallback backend.
+func (r *Router) Fallback() Backend {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ return r.fallback
+}
+
+// matchSNI implements the Match semantics documented on Route.
+func matchSNI(pattern, sni string) bool {
+ pattern = strings.ToLower(pattern)
+ if pattern == sni {
+ return true
+ }
+ // "*.example.com" matches ".example.com".
+ if strings.HasPrefix(pattern, "*.") {
+ suffix := pattern[1:] // ".example.com"
+ if !strings.HasSuffix(sni, suffix) {
+ return false
+ }
+ labelEnd := len(sni) - len(suffix)
+ if labelEnd <= 0 {
+ return false
+ }
+ // No additional dots in the wildcard label.
+ return !strings.Contains(sni[:labelEnd], ".")
+ }
+ return false
+}
diff --git a/core/pkg/sniproxy/router_test.go b/core/pkg/sniproxy/router_test.go
new file mode 100644
index 0000000..bece8d2
--- /dev/null
+++ b/core/pkg/sniproxy/router_test.go
@@ -0,0 +1,113 @@
+package sniproxy
+
+import (
+ "sync"
+ "testing"
+)
+
+func TestRouter_pick_exact_match(t *testing.T) {
+ fb := Backend{Name: "fallback", Addr: "127.0.0.1:9000"}
+ r := NewRouter(fb)
+ r.Replace([]Route{
+ {Match: "turn.example.com", Backend: Backend{Name: "turn", Addr: "127.0.0.1:5349"}},
+ }, fb)
+
+ got := r.Pick("turn.example.com")
+ if got.Addr != "127.0.0.1:5349" {
+ t.Errorf("expected turn backend, got %+v", got)
+ }
+}
+
+func TestRouter_pick_unmatched_returns_fallback(t *testing.T) {
+ fb := Backend{Name: "caddy", Addr: "127.0.0.1:8443"}
+ r := NewRouter(fb)
+ r.Replace([]Route{
+ {Match: "turn.example.com", Backend: Backend{Addr: "127.0.0.1:5349"}},
+ }, fb)
+
+ if got := r.Pick("api.example.com"); got != fb {
+ t.Errorf("expected fallback, got %+v", got)
+ }
+ if got := r.Pick(""); got != fb {
+ t.Errorf("expected fallback for empty SNI, got %+v", got)
+ }
+}
+
+func TestRouter_pick_case_insensitive(t *testing.T) {
+ fb := Backend{Addr: "127.0.0.1:8443"}
+ r := NewRouter(fb)
+ r.Replace([]Route{
+ {Match: "Turn.Example.Com", Backend: Backend{Addr: "127.0.0.1:5349"}},
+ }, fb)
+
+ if got := r.Pick("turn.example.com"); got.Addr != "127.0.0.1:5349" {
+ t.Errorf("expected case-insensitive match, got %+v", got)
+ }
+}
+
+func TestRouter_pick_wildcard_subdomain(t *testing.T) {
+ fb := Backend{Addr: "127.0.0.1:8443"}
+ r := NewRouter(fb)
+ r.Replace([]Route{
+ {Match: "*.example.com", Backend: Backend{Name: "wild", Addr: "127.0.0.1:5349"}},
+ }, fb)
+
+ cases := map[string]bool{
+ "a.example.com": true,
+ "foo.example.com": true,
+ "a.b.example.com": false, // multi-label not allowed
+ "example.com": false, // bare domain doesn't match *.example.com
+ "other.com": false,
+ }
+ for sni, want := range cases {
+ got := r.Pick(sni) == Backend{Name: "wild", Addr: "127.0.0.1:5349"}
+ if got != want {
+ t.Errorf("Pick(%q): want match=%v, got match=%v", sni, want, got)
+ }
+ }
+}
+
+func TestRouter_replace_atomic(t *testing.T) {
+ // Many concurrent reads against many concurrent Replace calls — should
+ // never observe partial state. Run with -race.
+ fb := Backend{Addr: "fb"}
+ r := NewRouter(fb)
+ r.Replace([]Route{{Match: "a.com", Backend: Backend{Addr: "1"}}}, fb)
+
+ var wg sync.WaitGroup
+ stop := make(chan struct{})
+
+ // Readers
+ for i := 0; i < 4; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for {
+ select {
+ case <-stop:
+ return
+ default:
+ _ = r.Pick("a.com")
+ }
+ }
+ }()
+ }
+
+ // Writers
+ for i := 0; i < 200; i++ {
+ r.Replace([]Route{{Match: "a.com", Backend: Backend{Addr: "x"}}}, fb)
+ }
+ close(stop)
+ wg.Wait()
+}
+
+func TestRouter_routes_returns_copy(t *testing.T) {
+ r := NewRouter(Backend{})
+ original := []Route{{Match: "a", Backend: Backend{Addr: "1"}}}
+ r.Replace(original, Backend{})
+ got := r.Routes()
+ got[0].Match = "mutated"
+ if r.Routes()[0].Match != "a" {
+ t.Error("Routes() should return a defensive copy")
+ }
+}
diff --git a/core/pkg/sniproxy/server.go b/core/pkg/sniproxy/server.go
new file mode 100644
index 0000000..4b9607e
--- /dev/null
+++ b/core/pkg/sniproxy/server.go
@@ -0,0 +1,177 @@
+package sniproxy
+
+import (
+ "context"
+ "errors"
+ "io"
+ "net"
+ "sync"
+ "time"
+
+ "go.uber.org/zap"
+)
+
+// Config tunes the proxy server.
+type Config struct {
+ // ClientHelloTimeout bounds the wait for a parseable ClientHello.
+ // 0 selects 5 seconds.
+ ClientHelloTimeout time.Duration
+ // BackendDialTimeout bounds backend connect time. 0 selects 5 seconds.
+ BackendDialTimeout time.Duration
+ // MaxConcurrentConns caps total in-flight connections to prevent
+ // resource exhaustion. 0 selects 10000.
+ MaxConcurrentConns int
+}
+
+// Server is a TCP-level SNI router. Create via NewServer, then call
+// Serve(listener) in a goroutine. Close cancels in-flight connections.
+type Server struct {
+ router *Router
+ cfg Config
+ logger *zap.Logger
+
+ gate chan struct{} // bounded semaphore for concurrent connections
+ wg sync.WaitGroup
+ ctx context.Context
+ cancel context.CancelFunc
+}
+
+// NewServer constructs a Server with the given router and config.
+func NewServer(router *Router, cfg Config, logger *zap.Logger) *Server {
+ if logger == nil {
+ logger = zap.NewNop()
+ }
+ if cfg.ClientHelloTimeout <= 0 {
+ cfg.ClientHelloTimeout = 5 * time.Second
+ }
+ if cfg.BackendDialTimeout <= 0 {
+ cfg.BackendDialTimeout = 5 * time.Second
+ }
+ if cfg.MaxConcurrentConns <= 0 {
+ cfg.MaxConcurrentConns = 10000
+ }
+ ctx, cancel := context.WithCancel(context.Background())
+ return &Server{
+ router: router,
+ cfg: cfg,
+ logger: logger.Named("sniproxy"),
+ gate: make(chan struct{}, cfg.MaxConcurrentConns),
+ ctx: ctx,
+ cancel: cancel,
+ }
+}
+
+// Serve accepts connections from ln until ln.Accept returns a permanent
+// error or Close is called. Serve always returns a non-nil error.
+func (s *Server) Serve(ln net.Listener) error {
+ for {
+ conn, err := ln.Accept()
+ if err != nil {
+ // Check for shutdown via cancelled ctx.
+ if s.ctx.Err() != nil {
+ return s.ctx.Err()
+ }
+ // Net errors temporarily? Backoff briefly so we don't busy-loop.
+ var ne net.Error
+ if errors.As(err, &ne) && ne.Timeout() {
+ time.Sleep(50 * time.Millisecond)
+ continue
+ }
+ return err
+ }
+ select {
+ case s.gate <- struct{}{}:
+ default:
+ s.logger.Warn("max concurrent connections reached, dropping",
+ zap.Int("limit", s.cfg.MaxConcurrentConns),
+ zap.String("remote", conn.RemoteAddr().String()),
+ )
+ conn.Close()
+ continue
+ }
+ s.wg.Add(1)
+ go func(c net.Conn) {
+ defer s.wg.Done()
+ defer func() { <-s.gate }()
+ s.handle(c)
+ }(conn)
+ }
+}
+
+// Close cancels in-flight connections and waits for handlers to drain.
+func (s *Server) Close() {
+ s.cancel()
+ s.wg.Wait()
+}
+
+// handle processes a single accepted connection: peek SNI, dial backend,
+// replay peeked bytes, then bidirectional copy.
+func (s *Server) handle(conn net.Conn) {
+ defer conn.Close()
+
+ sni, peeked, err := PeekClientHello(conn, s.cfg.ClientHelloTimeout)
+ if err != nil {
+ s.logger.Debug("ClientHello peek failed",
+ zap.String("remote", conn.RemoteAddr().String()),
+ zap.Error(err),
+ )
+ return
+ }
+
+ backend := s.router.Pick(sni)
+ if backend.Addr == "" {
+ s.logger.Warn("no backend for SNI",
+ zap.String("sni", sni),
+ zap.String("remote", conn.RemoteAddr().String()),
+ )
+ return
+ }
+
+ network := backend.Network
+ if network == "" {
+ network = "tcp"
+ }
+
+ upstream, err := net.DialTimeout(network, backend.Addr, s.cfg.BackendDialTimeout)
+ if err != nil {
+ s.logger.Warn("backend dial failed",
+ zap.String("sni", sni),
+ zap.String("backend", backend.Addr),
+ zap.Error(err),
+ )
+ return
+ }
+ defer upstream.Close()
+
+ // Replay peeked bytes (the ClientHello + anything else buffered).
+ if len(peeked) > 0 {
+ if _, err := upstream.Write(peeked); err != nil {
+ s.logger.Debug("replay to backend failed",
+ zap.String("sni", sni),
+ zap.Error(err),
+ )
+ return
+ }
+ }
+
+ // Bidirectional copy. We close both connections when either side
+ // finishes OR when the server is shutting down, so handle() can't
+ // hang forever on a half-stuck peer.
+ done := make(chan struct{}, 2)
+ go func() {
+ _, _ = io.Copy(upstream, conn)
+ done <- struct{}{}
+ }()
+ go func() {
+ _, _ = io.Copy(conn, upstream)
+ done <- struct{}{}
+ }()
+ select {
+ case <-done:
+ case <-s.ctx.Done():
+ }
+ // Force both sides closed; second copy will exit immediately.
+ upstream.Close()
+ conn.Close()
+ <-done // drain the second goroutine
+}
diff --git a/core/pkg/sniproxy/server_test.go b/core/pkg/sniproxy/server_test.go
new file mode 100644
index 0000000..bca9d60
--- /dev/null
+++ b/core/pkg/sniproxy/server_test.go
@@ -0,0 +1,143 @@
+package sniproxy
+
+import (
+ "bufio"
+ "crypto/tls"
+ "errors"
+ "io"
+ "net"
+ "testing"
+ "time"
+
+ "go.uber.org/zap"
+)
+
+// startEchoBackend creates a TCP server that echoes the first 1024 bytes
+// it reads, then closes. Returns the listener and a chan that receives
+// the bytes the server saw.
+func startEchoBackend(t *testing.T) (net.Listener, <-chan []byte) {
+ t.Helper()
+ ln, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ got := make(chan []byte, 4)
+ go func() {
+ for {
+ conn, err := ln.Accept()
+ if err != nil {
+ return
+ }
+ go func(c net.Conn) {
+ defer c.Close()
+ _ = c.SetReadDeadline(time.Now().Add(2 * time.Second))
+ buf := make([]byte, 1024)
+ n, _ := c.Read(buf)
+ got <- append([]byte(nil), buf[:n]...)
+ }(conn)
+ }
+ }()
+ return ln, got
+}
+
+func TestServer_routes_TLS_to_correct_backend(t *testing.T) {
+ turnLn, turnGot := startEchoBackend(t)
+ defer turnLn.Close()
+ caddyLn, caddyGot := startEchoBackend(t)
+ defer caddyLn.Close()
+
+ router := NewRouter(Backend{Network: "tcp", Addr: caddyLn.Addr().String()})
+ router.Replace([]Route{
+ {Match: "turn.example.com", Backend: Backend{Network: "tcp", Addr: turnLn.Addr().String()}},
+ }, router.Fallback())
+
+ srv := NewServer(router, Config{}, zap.NewNop())
+ defer srv.Close()
+
+ frontLn, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer frontLn.Close()
+
+ go func() { _ = srv.Serve(frontLn) }()
+
+ // Client A: SNI=turn.example.com -> goes to turnLn
+ dialAndStartTLS(t, frontLn.Addr().String(), "turn.example.com")
+
+ // Client B: SNI=other.example.com -> falls through to caddyLn
+ dialAndStartTLS(t, frontLn.Addr().String(), "other.example.com")
+
+ select {
+ case b := <-turnGot:
+ if len(b) == 0 {
+ t.Error("turn backend received empty bytes")
+ }
+ case <-time.After(3 * time.Second):
+ t.Error("turn backend did not receive bytes")
+ }
+
+ select {
+ case b := <-caddyGot:
+ if len(b) == 0 {
+ t.Error("caddy fallback received empty bytes")
+ }
+ case <-time.After(3 * time.Second):
+ t.Error("caddy fallback did not receive bytes")
+ }
+}
+
+// dialAndStartTLS opens a TLS handshake (which produces a ClientHello)
+// against the given address with the given SNI. Returns immediately —
+// the test only needs the proxy to forward the bytes; it doesn't
+// require handshake completion.
+func dialAndStartTLS(t *testing.T, addr, sni string) {
+ t.Helper()
+ conn, err := net.Dial("tcp", addr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ go func() {
+ defer conn.Close()
+ _ = conn.SetDeadline(time.Now().Add(2 * time.Second))
+ c := tls.Client(conn, &tls.Config{ServerName: sni, InsecureSkipVerify: true})
+ _ = c.Handshake() // expected to fail (echo backend isn't TLS)
+ }()
+}
+
+func TestServer_no_backend_drops_connection(t *testing.T) {
+ router := NewRouter(Backend{}) // empty fallback, empty Addr -> dropped
+ srv := NewServer(router, Config{}, zap.NewNop())
+ defer srv.Close()
+
+ frontLn, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer frontLn.Close()
+
+ go func() { _ = srv.Serve(frontLn) }()
+
+ conn, err := net.Dial("tcp", frontLn.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+
+ c := tls.Client(conn, &tls.Config{ServerName: "x.example.com", InsecureSkipVerify: true})
+ // Handshake should fail because connection is closed by proxy.
+ go func() { _ = c.Handshake() }()
+
+ // Reader should see EOF quickly.
+ _ = conn.SetDeadline(time.Now().Add(2 * time.Second))
+ br := bufio.NewReader(conn)
+ _, err = br.ReadByte()
+ if err == nil {
+ t.Error("expected connection drop")
+ }
+ if !errors.Is(err, io.EOF) {
+ // "use of closed network connection" is also fine.
+ t.Logf("acceptable read error: %v", err)
+ }
+}
+
diff --git a/core/pkg/sniproxy/sni.go b/core/pkg/sniproxy/sni.go
new file mode 100644
index 0000000..cf349c1
--- /dev/null
+++ b/core/pkg/sniproxy/sni.go
@@ -0,0 +1,235 @@
+// Package sniproxy provides a TCP-level Server Name Indication (SNI) router.
+//
+// The router peeks at the unencrypted TLS ClientHello on each accepted
+// connection, extracts the SNI host name, and forwards the raw stream to
+// a backend. It does NOT terminate TLS — encrypted bytes pass through
+// verbatim. This lets one TCP port serve multiple TLS-speaking backends
+// (HTTPS for the gateway, TURNS for stealth WebRTC, etc.) without
+// sharing private keys with the proxy.
+//
+// Design goals:
+// - Zero TLS material on the proxy
+// - Bounded ClientHello read (no slowloris)
+// - Backend dial timeout
+// - Per-IP rate limiting
+//
+// SNI parsing follows RFC 5246 §7.4.1.2 (TLS record + ClientHello) and
+// RFC 6066 §3 (server_name extension).
+package sniproxy
+
+import (
+ "bufio"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "strings"
+ "time"
+)
+
+// ErrNoSNI is returned when the ClientHello has no server_name extension.
+var ErrNoSNI = errors.New("sniproxy: ClientHello has no SNI")
+
+// MaxClientHelloBytes bounds how many bytes we'll read while looking for
+// the SNI. TLS ClientHello records are typically 200–500 bytes; this is
+// a generous cap that still defends against memory abuse.
+const MaxClientHelloBytes = 16 * 1024
+
+// PeekClientHello reads bytes from conn until a TLS ClientHello has
+// been parsed (or MaxClientHelloBytes is exceeded). Returns the SNI
+// hostname (lowercased), the bytes consumed (must be replayed to the
+// backend), and any error.
+//
+// readTimeout bounds the wait — slowloris-style stalls return an error
+// quickly without holding the goroutine indefinitely.
+func PeekClientHello(conn net.Conn, readTimeout time.Duration) (string, []byte, error) {
+ if readTimeout > 0 {
+ _ = conn.SetReadDeadline(time.Now().Add(readTimeout))
+ defer conn.SetReadDeadline(time.Time{})
+ }
+
+ br := bufio.NewReaderSize(conn, MaxClientHelloBytes)
+
+ // Peek the TLS record header (5 bytes): content_type, version (2),
+ // length (2). content_type for ClientHello is 22 (handshake).
+ header, err := br.Peek(5)
+ if err != nil {
+ return "", nil, fmt.Errorf("read tls record header: %w", err)
+ }
+ if header[0] != 22 {
+ return "", nil, fmt.Errorf("not a TLS handshake record (type=%d)", header[0])
+ }
+ recLen := int(binary.BigEndian.Uint16(header[3:5]))
+ if recLen <= 0 || 5+recLen > MaxClientHelloBytes {
+ return "", nil, fmt.Errorf("invalid record length %d", recLen)
+ }
+
+ full, err := br.Peek(5 + recLen)
+ if err != nil {
+ return "", nil, fmt.Errorf("read tls record body: %w", err)
+ }
+
+ sni, err := parseSNI(full[5:])
+ if err != nil {
+ return "", nil, err
+ }
+
+ // We've only peeked — drain the buffer to capture the bytes for replay.
+ consumed := make([]byte, br.Buffered())
+ if _, err := io.ReadFull(br, consumed); err != nil {
+ return "", nil, fmt.Errorf("drain peeked bytes: %w", err)
+ }
+
+ return sni, consumed, nil
+}
+
+// parseSNI parses a TLS ClientHello body (without the 5-byte record
+// header) and returns the server_name extension value if present.
+func parseSNI(body []byte) (string, error) {
+ r := newReader(body)
+
+ // Handshake type (1 byte) — must be 1 (ClientHello).
+ hsType, err := r.readByte()
+ if err != nil {
+ return "", err
+ }
+ if hsType != 1 {
+ return "", fmt.Errorf("not a ClientHello (handshake type %d)", hsType)
+ }
+ // Handshake length (3 bytes).
+ if _, err := r.readBytes(3); err != nil {
+ return "", err
+ }
+ // client_version (2) + random (32).
+ if _, err := r.readBytes(2 + 32); err != nil {
+ return "", err
+ }
+ // session_id.
+ sidLen, err := r.readByte()
+ if err != nil {
+ return "", err
+ }
+ if _, err := r.readBytes(int(sidLen)); err != nil {
+ return "", err
+ }
+ // cipher_suites length (2).
+ csLen, err := r.readUint16()
+ if err != nil {
+ return "", err
+ }
+ if _, err := r.readBytes(int(csLen)); err != nil {
+ return "", err
+ }
+ // compression_methods length (1).
+ cmLen, err := r.readByte()
+ if err != nil {
+ return "", err
+ }
+ if _, err := r.readBytes(int(cmLen)); err != nil {
+ return "", err
+ }
+ // Extensions length (2). Optional — TLS 1.0 ClientHello can skip it.
+ if r.remaining() < 2 {
+ return "", ErrNoSNI
+ }
+ extTotalLen, err := r.readUint16()
+ if err != nil {
+ return "", err
+ }
+ if int(extTotalLen) > r.remaining() {
+ return "", fmt.Errorf("extensions truncated")
+ }
+
+ end := r.pos + int(extTotalLen)
+ for r.pos < end {
+ extType, err := r.readUint16()
+ if err != nil {
+ return "", err
+ }
+ extLen, err := r.readUint16()
+ if err != nil {
+ return "", err
+ }
+ extData, err := r.readBytes(int(extLen))
+ if err != nil {
+ return "", err
+ }
+ // server_name extension is type 0.
+ if extType != 0 {
+ continue
+ }
+ return parseServerName(extData)
+ }
+ return "", ErrNoSNI
+}
+
+// parseServerName parses the body of a server_name extension and returns
+// the first host_name (type 0) entry.
+func parseServerName(data []byte) (string, error) {
+ r := newReader(data)
+ // server_name_list length (2).
+ listLen, err := r.readUint16()
+ if err != nil {
+ return "", err
+ }
+ if int(listLen) > r.remaining() {
+ return "", fmt.Errorf("server_name list truncated")
+ }
+ end := r.pos + int(listLen)
+ for r.pos < end {
+ nameType, err := r.readByte()
+ if err != nil {
+ return "", err
+ }
+ nameLen, err := r.readUint16()
+ if err != nil {
+ return "", err
+ }
+ nameBytes, err := r.readBytes(int(nameLen))
+ if err != nil {
+ return "", err
+ }
+ if nameType == 0 { // host_name
+ return strings.ToLower(string(nameBytes)), nil
+ }
+ }
+ return "", ErrNoSNI
+}
+
+// reader is a tiny byte-slice cursor used by parseSNI/parseServerName.
+type reader struct {
+ buf []byte
+ pos int
+}
+
+func newReader(buf []byte) *reader { return &reader{buf: buf} }
+
+func (r *reader) remaining() int { return len(r.buf) - r.pos }
+
+func (r *reader) readByte() (byte, error) {
+ if r.pos >= len(r.buf) {
+ return 0, io.ErrUnexpectedEOF
+ }
+ b := r.buf[r.pos]
+ r.pos++
+ return b, nil
+}
+
+func (r *reader) readUint16() (uint16, error) {
+ if r.pos+2 > len(r.buf) {
+ return 0, io.ErrUnexpectedEOF
+ }
+ v := binary.BigEndian.Uint16(r.buf[r.pos : r.pos+2])
+ r.pos += 2
+ return v, nil
+}
+
+func (r *reader) readBytes(n int) ([]byte, error) {
+ if r.pos+n > len(r.buf) {
+ return nil, io.ErrUnexpectedEOF
+ }
+ b := r.buf[r.pos : r.pos+n]
+ r.pos += n
+ return b, nil
+}
diff --git a/core/pkg/sniproxy/sni_test.go b/core/pkg/sniproxy/sni_test.go
new file mode 100644
index 0000000..581dc44
--- /dev/null
+++ b/core/pkg/sniproxy/sni_test.go
@@ -0,0 +1,173 @@
+package sniproxy
+
+import (
+ "crypto/tls"
+ "errors"
+ "io"
+ "net"
+ "sync"
+ "testing"
+ "time"
+)
+
+// dialAndPeek dials a TLS handshake to the given listener and returns
+// what PeekClientHello on the server side parsed.
+func dialAndPeek(t *testing.T, ln net.Listener, sni string) (string, []byte, error) {
+ t.Helper()
+
+ type result struct {
+ sni string
+ peeked []byte
+ err error
+ }
+ resCh := make(chan result, 1)
+
+ // Server side: accept once, peek SNI.
+ go func() {
+ conn, err := ln.Accept()
+ if err != nil {
+ resCh <- result{err: err}
+ return
+ }
+ defer conn.Close()
+ s, p, err := PeekClientHello(conn, 2*time.Second)
+ resCh <- result{sni: s, peeked: p, err: err}
+ }()
+
+ // Client side: kick off a TLS handshake. We don't care if it
+ // completes (no server cert) — we only need ClientHello to be sent.
+ // Use a goroutine so the test doesn't deadlock waiting on Handshake.
+ go func() {
+ conn, err := net.Dial("tcp", ln.Addr().String())
+ if err != nil {
+ return
+ }
+ defer conn.Close()
+ _ = conn.SetDeadline(time.Now().Add(2 * time.Second))
+ c := tls.Client(conn, &tls.Config{
+ ServerName: sni,
+ InsecureSkipVerify: true,
+ })
+ _ = c.Handshake() // expected to fail; we only needed the ClientHello
+ }()
+
+ select {
+ case r := <-resCh:
+ return r.sni, r.peeked, r.err
+ case <-time.After(5 * time.Second):
+ return "", nil, errors.New("test timeout")
+ }
+}
+
+func TestPeekClientHello_returns_sni(t *testing.T) {
+ ln, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ln.Close()
+
+ sni, peeked, err := dialAndPeek(t, ln, "example.com")
+ if err != nil {
+ t.Fatalf("PeekClientHello: %v", err)
+ }
+ if sni != "example.com" {
+ t.Errorf("expected sni=example.com, got %q", sni)
+ }
+ if len(peeked) == 0 {
+ t.Error("expected non-empty peeked bytes")
+ }
+}
+
+func TestPeekClientHello_lowercases_sni(t *testing.T) {
+ ln, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ln.Close()
+
+ sni, _, err := dialAndPeek(t, ln, "Example.COM")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if sni != "example.com" {
+ t.Errorf("expected lowercase, got %q", sni)
+ }
+}
+
+func TestPeekClientHello_non_tls_returns_error(t *testing.T) {
+ a, b := net.Pipe()
+ defer a.Close()
+ defer b.Close()
+
+ go func() {
+ // Send something that isn't a TLS handshake record.
+ _, _ = a.Write([]byte("GET / HTTP/1.1\r\n\r\n"))
+ _ = a.Close()
+ }()
+
+ _, _, err := PeekClientHello(b, 1*time.Second)
+ if err == nil {
+ t.Fatal("expected error for non-TLS bytes")
+ }
+}
+
+func TestPeekClientHello_short_record_returns_error(t *testing.T) {
+ a, b := net.Pipe()
+ defer a.Close()
+ defer b.Close()
+
+ go func() {
+ // One byte, then close — too short for record header.
+ _, _ = a.Write([]byte{22})
+ _ = a.Close()
+ }()
+
+ _, _, err := PeekClientHello(b, 1*time.Second)
+ if err == nil {
+ t.Fatal("expected error for short record")
+ }
+ // EOF or read error is acceptable.
+ if !errors.Is(err, io.EOF) && err.Error() == "" {
+ t.Logf("error: %v", err)
+ }
+}
+
+func TestPeekClientHello_concurrent_safe(t *testing.T) {
+ // Verify no shared state leaks between PeekClientHello calls.
+ ln, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ln.Close()
+
+ var wg sync.WaitGroup
+ for i := 0; i < 4; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ conn, err := net.Dial("tcp", ln.Addr().String())
+ if err != nil {
+ return
+ }
+ defer conn.Close()
+ _ = conn.SetDeadline(time.Now().Add(2 * time.Second))
+ c := tls.Client(conn, &tls.Config{ServerName: "x.example.com", InsecureSkipVerify: true})
+ _ = c.Handshake()
+ }()
+ }
+ for i := 0; i < 4; i++ {
+ conn, err := ln.Accept()
+ if err != nil {
+ t.Fatal(err)
+ }
+ sni, _, err := PeekClientHello(conn, 2*time.Second)
+ conn.Close()
+ if err != nil {
+ t.Errorf("peek %d: %v", i, err)
+ }
+ if sni != "x.example.com" {
+ t.Errorf("peek %d: got %q", i, sni)
+ }
+ }
+ wg.Wait()
+}
diff --git a/core/sni-router b/core/sni-router
new file mode 100755
index 0000000..44aa374
Binary files /dev/null and b/core/sni-router differ
diff --git a/core/systemd/orama-sni-router.service b/core/systemd/orama-sni-router.service
new file mode 100644
index 0000000..d41837a
--- /dev/null
+++ b/core/systemd/orama-sni-router.service
@@ -0,0 +1,38 @@
+[Unit]
+Description=Orama SNI Router (TLS-level :443 → backend forwarder)
+Documentation=https://github.com/DeBrosOfficial/network
+After=network.target
+Before=caddy.service
+PartOf=orama-node.service
+
+[Service]
+Type=simple
+WorkingDirectory=/opt/orama
+EnvironmentFile=-/opt/orama/.orama/data/sni-router.env
+ExecStart=/opt/orama/bin/orama-sni-router --config sni-router.yaml
+
+# Bind privileged ports (:80, :443) without running as root.
+AmbientCapabilities=CAP_NET_BIND_SERVICE
+CapabilityBoundingSet=CAP_NET_BIND_SERVICE
+
+User=orama
+Group=orama
+NoNewPrivileges=yes
+ProtectSystem=strict
+ProtectHome=yes
+PrivateTmp=yes
+LimitNOFILE=65536
+
+TimeoutStopSec=15s
+KillMode=mixed
+KillSignal=SIGTERM
+
+Restart=on-failure
+RestartSec=5s
+
+StandardOutput=journal
+StandardError=journal
+SyslogIdentifier=orama-sni-router
+
+[Install]
+WantedBy=multi-user.target