feat(core): implement sni-router for stealth turn

- add `orama-sni-router` binary to build process
- introduce `cmd/sni-router` for TLS-level SNI routing
- add documentation for stealth turn deployment architecture
This commit is contained in:
anonpenguin23 2026-05-03 18:20:21 +03:00
parent 54852076f9
commit 0379dc39f1
59 changed files with 5568 additions and 192 deletions

View File

@ -63,7 +63,7 @@ test-e2e-quick:
.PHONY: build clean test deps tidy fmt vet lint install-hooks push-devnet push-testnet rollout-devnet rollout-testnet release
VERSION := 0.120.0
VERSION := 0.121.0
COMMIT ?= $(shell git rev-parse --short HEAD 2>/dev/null || echo unknown)
DATE ?= $(shell date -u +%Y-%m-%dT%H:%M:%SZ)
LDFLAGS := -X 'main.version=$(VERSION)' -X 'main.commit=$(COMMIT)' -X 'main.date=$(DATE)'
@ -80,6 +80,7 @@ build: deps
go build -ldflags "$(LDFLAGS) -X 'github.com/DeBrosOfficial/network/pkg/gateway.BuildVersion=$(VERSION)' -X 'github.com/DeBrosOfficial/network/pkg/gateway.BuildCommit=$(COMMIT)' -X 'github.com/DeBrosOfficial/network/pkg/gateway.BuildTime=$(DATE)'" -o bin/gateway ./cmd/gateway
go build -ldflags "$(LDFLAGS)" -o bin/sfu ./cmd/sfu
go build -ldflags "$(LDFLAGS)" -o bin/turn ./cmd/turn
go build -ldflags "$(LDFLAGS)" -o bin/orama-sni-router ./cmd/sni-router
@echo "Build complete! Run ./bin/orama version"
# Cross-compile CLI for Linux (only binary needed locally; VPS builds everything else from source)

242
core/cmd/sni-router/main.go Normal file
View File

@ -0,0 +1,242 @@
// Command sni-router is a TLS-level Server Name Indication router.
//
// It listens on a public TCP port (typically :443), peeks at the TLS
// ClientHello SNI on each connection, and forwards the raw stream to
// a configured backend. It does NOT terminate TLS — encrypted bytes
// pass through verbatim. This lets one port serve multiple TLS-speaking
// backends (HTTPS for the gateway, TURN-over-TLS for stealth WebRTC).
//
// See pkg/sniproxy for the underlying library.
//
// Configuration: YAML file at --config (defaults to ~/.orama/sni-router.yaml).
//
// Example sni-router.yaml:
//
// listen: ":443"
// client_hello_timeout: 5s
// backend_dial_timeout: 5s
// max_concurrent_conns: 10000
// fallback:
// name: caddy
// addr: "127.0.0.1:8443"
// routes:
// - match: "cdn.example.com"
// backend:
// name: turn-tls
// addr: "127.0.0.1:5349"
// - match: "turn.example.com"
// backend:
// name: turn-tls
// addr: "127.0.0.1:5349"
// - match: "*.ns-myapp.example.com"
// backend:
// name: gateway
// addr: "127.0.0.1:8443"
package main
import (
"flag"
"fmt"
"net"
"os"
"os/signal"
"path/filepath"
"strings"
"syscall"
"time"
"github.com/DeBrosOfficial/network/pkg/config"
"github.com/DeBrosOfficial/network/pkg/logging"
"github.com/DeBrosOfficial/network/pkg/sniproxy"
"go.uber.org/zap"
)
// Build metadata. The Makefile passes -X 'main.version=...' and
// -X 'main.commit=...' in LDFLAGS when building this binary, which overrides
// the values below at link time; they are the fallbacks for a plain `go build`.
var (
version = "dev"
commit = "unknown"
)
// yamlBackend mirrors sniproxy.Backend for YAML decoding.
//
// toBackend converts it into a sniproxy.Backend, substituting "tcp" when
// Network is left empty.
type yamlBackend struct {
// Name is copied verbatim into sniproxy.Backend.Name.
Name string `yaml:"name"`
// Network is optional; toBackend replaces an empty value with "tcp".
Network string `yaml:"network"`
// Addr is the backend address. validateConfig rejects an empty value for
// the fallback and for every route backend.
Addr string `yaml:"addr"`
}
// yamlRoute mirrors sniproxy.Route for YAML decoding.
type yamlRoute struct {
// Match is the SNI pattern for this route; validateConfig rejects an empty
// value. The package example uses exact hostnames and a "*." prefix form.
// NOTE(review): the matching rules themselves live in pkg/sniproxy — confirm there.
Match string `yaml:"match"`
// Backend is the destination paired with Match; it is converted with toBackend.
Backend yamlBackend `yaml:"backend"`
}
// yamlConfig is the on-disk configuration shape.
//
// parseConfig decodes it from the YAML file and validateConfig checks the
// required fields before main uses it.
type yamlConfig struct {
// Listen is the address handed to net.Listen("tcp", ...) in main; required.
Listen string `yaml:"listen"`
// ClientHelloTimeout is passed unchanged to sniproxy.Config.ClientHelloTimeout.
ClientHelloTimeout time.Duration `yaml:"client_hello_timeout"`
// BackendDialTimeout is passed unchanged to sniproxy.Config.BackendDialTimeout.
BackendDialTimeout time.Duration `yaml:"backend_dial_timeout"`
// MaxConcurrentConns is passed unchanged to sniproxy.Config.MaxConcurrentConns.
MaxConcurrentConns int `yaml:"max_concurrent_conns"`
// Fallback is the backend given to the router as its fallback; its Addr is required.
Fallback yamlBackend `yaml:"fallback"`
// Routes is the route list, converted with toRoutes and installed on the router.
Routes []yamlRoute `yaml:"routes"`
}
// main wires up the SNI router: it initializes logging, loads and validates
// the YAML configuration, builds the route table, and serves on the
// configured listen address until a termination signal arrives or Serve
// returns by itself.
//
// Exit status is 0 after a signal-initiated shutdown and 1 when startup fails
// or when Serve stops without being asked to, so a process supervisor can
// distinguish a crash from a requested stop.
func main() {
	logger, err := logging.NewColoredLogger(logging.ComponentSNI, true)
	if err != nil {
		fmt.Fprintf(os.Stderr, "failed to init logger: %v\n", err)
		os.Exit(1)
	}

	logger.ComponentInfo(logging.ComponentSNI, "Starting SNI router",
		zap.String("version", version),
		zap.String("commit", commit))

	cfg := parseConfig(logger)

	router := sniproxy.NewRouter(toBackend(cfg.Fallback))
	router.Replace(toRoutes(cfg.Routes), toBackend(cfg.Fallback))

	srv := sniproxy.NewServer(router, sniproxy.Config{
		ClientHelloTimeout: cfg.ClientHelloTimeout,
		BackendDialTimeout: cfg.BackendDialTimeout,
		MaxConcurrentConns: cfg.MaxConcurrentConns,
	}, logger.Logger)

	ln, err := net.Listen("tcp", cfg.Listen)
	if err != nil {
		logger.ComponentError(logging.ComponentSNI, "Failed to listen",
			zap.String("addr", cfg.Listen), zap.Error(err))
		os.Exit(1)
	}

	logger.ComponentInfo(logging.ComponentSNI, "SNI router listening",
		zap.String("addr", cfg.Listen),
		zap.Int("routes", len(cfg.Routes)),
		zap.String("fallback", cfg.Fallback.Addr),
	)

	// Run Serve in a goroutine so the main goroutine can wait on signals.
	// The channel is buffered so the goroutine can always deliver its result
	// and exit, even when main has already left the select below.
	serveErrCh := make(chan error, 1)
	go func() {
		serveErrCh <- srv.Serve(ln)
	}()

	// Wait for termination signal or unrecoverable Serve error.
	quit := make(chan os.Signal, 1)
	signal.Notify(quit, os.Interrupt, syscall.SIGTERM)

	exitCode := 0
	select {
	case sig := <-quit:
		logger.ComponentInfo(logging.ComponentSNI, "Shutdown signal received",
			zap.String("signal", sig.String()))
	case err := <-serveErrCh:
		logger.ComponentError(logging.ComponentSNI, "Serve returned",
			zap.Error(err))
		// Serve stopping without a shutdown request means the router is no
		// longer serving traffic; report that as a failure.
		exitCode = 1
	}

	// Restore default signal handling so a second SIGINT/SIGTERM during the
	// drain terminates the process instead of being silently swallowed.
	signal.Stop(quit)

	// Stop accepting new connections, then drain in-flight ones.
	// The Close error is ignored on purpose: the listener may already have
	// been closed as part of Serve returning.
	_ = ln.Close()
	srv.Close()

	logger.ComponentInfo(logging.ComponentSNI, "SNI router shutdown complete")

	if exitCode != 0 {
		os.Exit(exitCode)
	}
}
// parseConfig parses the --config flag, resolves the configuration file path,
// then reads, strictly decodes, and validates the YAML configuration.
//
// Path resolution: an absolute --config value is used as-is; a relative value
// is resolved through config.DefaultPath; when the flag is omitted the name
// "sni-router.yaml" is resolved the same way.
//
// Every failure is reported (structured log and/or stderr) and terminates the
// process with exit status 1, so the function only ever returns a
// configuration that passed validateConfig.
func parseConfig(logger *logging.ColoredLogger) yamlConfig {
	configFlag := flag.String("config", "", "Config file path (absolute or filename in ~/.orama)")
	flag.Parse()

	name := *configFlag
	if name == "" {
		name = "sni-router.yaml"
	}

	configPath := name
	if !filepath.IsAbs(name) {
		resolved, err := config.DefaultPath(name)
		if err != nil {
			logger.ComponentError(logging.ComponentSNI, "Failed to determine config path",
				zap.Error(err))
			os.Exit(1)
		}
		configPath = resolved
	}

	data, err := os.ReadFile(configPath)
	if err != nil {
		// Only claim "not found" when that is what happened; permission
		// problems or a directory at the path need a different fix.
		if os.IsNotExist(err) {
			logger.ComponentError(logging.ComponentSNI, "Config file not found",
				zap.String("path", configPath), zap.Error(err))
			fmt.Fprintf(os.Stderr, "\nConfig file not found at %s\n", configPath)
		} else {
			logger.ComponentError(logging.ComponentSNI, "Failed to read config file",
				zap.String("path", configPath), zap.Error(err))
			fmt.Fprintf(os.Stderr, "\nFailed to read config file %s: %v\n", configPath, err)
		}
		os.Exit(1)
	}

	var y yamlConfig
	if err := config.DecodeStrict(strings.NewReader(string(data)), &y); err != nil {
		logger.ComponentError(logging.ComponentSNI, "Failed to parse SNI router config",
			zap.Error(err))
		fmt.Fprintf(os.Stderr, "Configuration parse error: %v\n", err)
		os.Exit(1)
	}

	if errs := validateConfig(&y); len(errs) > 0 {
		fmt.Fprintf(os.Stderr, "\nSNI router configuration errors (%d):\n", len(errs))
		for _, e := range errs {
			fmt.Fprintf(os.Stderr, " - %s\n", e)
		}
		fmt.Fprintf(os.Stderr, "\nPlease fix the configuration and try again.\n")
		os.Exit(1)
	}

	logger.ComponentInfo(logging.ComponentSNI, "Loaded SNI router configuration",
		zap.String("path", configPath),
	)
	return y
}
// validateConfig checks a decoded configuration and returns one
// human-readable message per problem found. A nil/empty result means the
// configuration is acceptable.
//
// Checks performed:
//   - listen and fallback.addr must be set;
//   - client_hello_timeout, backend_dial_timeout and max_concurrent_conns
//     must not be negative (zero is passed through to sniproxy untouched);
//   - every route needs a non-empty match and backend.addr.
func validateConfig(y *yamlConfig) []string {
	var errs []string

	if y.Listen == "" {
		errs = append(errs, "listen: required (e.g. \":443\")")
	}

	// Negative values are never meaningful here and would otherwise reach the
	// server unchecked, so reject them at startup.
	if y.ClientHelloTimeout < 0 {
		errs = append(errs, fmt.Sprintf("client_hello_timeout: must not be negative (got %s)", y.ClientHelloTimeout))
	}
	if y.BackendDialTimeout < 0 {
		errs = append(errs, fmt.Sprintf("backend_dial_timeout: must not be negative (got %s)", y.BackendDialTimeout))
	}
	if y.MaxConcurrentConns < 0 {
		errs = append(errs, fmt.Sprintf("max_concurrent_conns: must not be negative (got %d)", y.MaxConcurrentConns))
	}

	if y.Fallback.Addr == "" {
		errs = append(errs, "fallback.addr: required (where to send unmatched SNIs, typically Caddy)")
	}

	for i, r := range y.Routes {
		if r.Match == "" {
			errs = append(errs, fmt.Sprintf("routes[%d].match: required", i))
		}
		if r.Backend.Addr == "" {
			errs = append(errs, fmt.Sprintf("routes[%d].backend.addr: required", i))
		}
	}
	return errs
}
// toBackend converts a decoded yamlBackend into a sniproxy.Backend, using
// "tcp" as the network when the YAML left it empty.
func toBackend(b yamlBackend) sniproxy.Backend {
	converted := sniproxy.Backend{
		Name:    b.Name,
		Network: b.Network,
		Addr:    b.Addr,
	}
	if converted.Network == "" {
		converted.Network = "tcp"
	}
	return converted
}
// toRoutes converts the decoded YAML routes into sniproxy routes, preserving
// their order and running each backend through toBackend. The result is
// always a non-nil slice with one entry per input route.
func toRoutes(in []yamlRoute) []sniproxy.Route {
	routes := make([]sniproxy.Route, 0, len(in))
	for _, src := range in {
		routes = append(routes, sniproxy.Route{
			Match:   src.Match,
			Backend: toBackend(src.Backend),
		})
	}
	return routes
}

187
core/docs/STEALTH_TURN.md Normal file
View File

@ -0,0 +1,187 @@
# Stealth TURN Deployment Guide
## What this is
A TLS-level SNI router that lets Orama serve TURN-over-TLS on `:443`,
sharing the port with Caddy HTTPS. From a network observer's
perspective, TURN traffic is indistinguishable from ordinary HTTPS —
useful for users in regions that block standard VoIP ports (UAE, Saudi
Arabia, China, Iran).
## Architecture
```
Internet
TCP :443
┌─────────┴─────────┐
│ orama-sni-router │ peeks SNI, forwards bytes
└─────────┬─────────┘
┌───────────────┼────────────────┐
▼ ▼
cdn.<base> *.<base>, <base>
turn.<base> (everything else)
│ │
▼ ▼
Pion TURN-TLS Caddy
127.0.0.1:5349 127.0.0.1:8443
(existing) (moved from :443)
```
The router does **not** terminate TLS. It reads the unencrypted TLS
ClientHello (first ~5 KB), inspects the SNI extension, and dials the
matching backend. Encrypted bytes pass through verbatim.
## Components
- **Library:** `pkg/sniproxy/` — ClientHello parser, route table, TCP server
- **Binary:** `cmd/sni-router/` (built as `bin/orama-sni-router`)
- **Systemd unit:** `systemd/orama-sni-router.service`
- **Config:** `~/.orama/sni-router.yaml`
## Deployment cutover
⚠️ **This change touches production `:443`. Stage on one node first, watch for 24h, then roll out.**
### 1. Reconfigure Caddy to listen on `:8443`
Update wherever the Caddy config is generated (`pkg/environments/production/installers/caddy.go`)
so Caddy binds `:8443` (HTTPS) and `:8080` (HTTP) instead of `:443` and `:80`.
Drop `CAP_NET_BIND_SERVICE` from Caddy's systemd unit — it no longer needs privileged ports.
### 2. Provision the cert SAN for `cdn.<base-domain>`
Caddy's automatic Let's Encrypt flow needs to issue a cert covering
`cdn.<base-domain>` and `cdn.ns-*.<base-domain>` so Pion TURN can read it
on startup. Add these names to Caddy's TLS config block.
### 3. Install the `sni-router.yaml` config
Place the file at `~/.orama/sni-router.yaml`, the router's default config path.
Example for a single-namespace node:
```yaml
listen: ":443"
client_hello_timeout: 5s
backend_dial_timeout: 5s
max_concurrent_conns: 10000
fallback:
name: caddy
addr: "127.0.0.1:8443"
routes:
- match: "cdn.example.com"
backend:
name: turn-tls
addr: "127.0.0.1:5349"
- match: "turn.example.com"
backend:
name: turn-tls
addr: "127.0.0.1:5349"
```
For multi-namespace, add per-namespace TURN backends (each namespace's
TURN-TLS port is allocated by `pkg/namespace`):
```yaml
- match: "cdn.ns-myapp.example.com"
backend: { name: "turn-myapp", addr: "127.0.0.1:5349" }
- match: "cdn.ns-other.example.com"
backend: { name: "turn-other", addr: "127.0.0.1:5350" }
```
### 4. Deploy + start in order
```bash
# Install binary
sudo cp bin-linux/orama-sni-router /opt/orama/bin/
# Install service
sudo cp systemd/orama-sni-router.service /etc/systemd/system/
sudo systemctl daemon-reload
# Stop Caddy briefly (it's about to lose :443)
sudo systemctl stop caddy
# Start the SNI router (it takes :443)
sudo systemctl enable --now orama-sni-router
# Restart Caddy on its new port
sudo systemctl start caddy
# Verify
curl -v https://cdn.<base>:443 # should hit TURN backend (TLS handshake will fail; that's fine)
curl -v https://<base>:443 # should hit Caddy (normal HTTPS response)
```
### 5. Enable stealth in the gateway
Once the SNI router is live, tell the gateway to advertise the stealth URI:
```go
// in gateway dependencies / startup
webrtcHandlers.SetStealthCDNDomain("cdn.<base-domain>")
```
The credentials handler will start including `turns:cdn.<base-domain>:443`
in `POST /v1/webrtc/turn/credentials` responses automatically.
### 6. Monitor
```bash
journalctl -u orama-sni-router.service -f
journalctl -u caddy.service -f
```
Watch for:
- `Connection limit reached` warnings (bump `max_concurrent_conns`)
- `backend dial failed` warnings (Caddy isn't listening on `:8443`, or TURN isn't on `:5349`)
- `ClientHello peek failed` debug-level messages (clients sending non-TLS traffic to `:443` — usually port scanners)
## Rollback
If anything is wrong:
```bash
# Stop the router AND disable it, so it cannot reclaim :443 on the next boot
sudo systemctl disable --now orama-sni-router
# Reconfigure Caddy back to :443 and restart
sudo systemctl restart caddy
```
Caddy reclaiming `:443` from the disabled router is the fastest way back to
the previous topology.
## Known gaps
- **Dynamic route source:** today's router reads YAML once at startup. To
pick up new namespaces without restart, implement a `RouteSource` that
polls `pkg/namespace` for active TURN deployments. The library is
already designed for `Router.Replace` to be called concurrently.
- **TLS cert hot-reload:** Pion TURN reads the cert once at startup. When
Caddy renews `cdn.<base-domain>`, Pion needs to be restarted to pick up
the new cert. A small file-watcher service (or a periodic restart in
off-peak hours) handles this for now.
## What clients see
Once enabled, the credentials response gains one entry:
```json
{
"username": "...",
"password": "...",
"ttl": 600,
"uris": [
"turn:turn.example.com:3478?transport=udp",
"turn:turn.example.com:3478?transport=tcp",
"turns:turn.example.com:5349",
"turns:cdn.example.com:443"
]
}
```
Browsers iterate ICE candidates; users in restricted regions will silently
succeed via the `:443` URI when others fail. No client-side change is
required.

View File

@ -0,0 +1,28 @@
-- =============================================================================
-- 021_pubsub_trigger_patterns.sql
--
-- Add `topic_pattern` column alongside the existing `topic` column to
-- function_pubsub_triggers. The new column may contain SQLite GLOB
-- patterns (e.g. "presence:*") in addition to exact topic names.
--
-- This is intentionally ADDITIVE rather than a column rename to remain
-- safe under rolling upgrades:
-- - Old binaries continue reading `topic` and keep working.
-- - New binaries read `topic_pattern` (which is back-filled from
-- `topic` for existing rows) and write BOTH columns.
-- A future migration can DROP COLUMN topic once every node is on the
-- new release.
-- =============================================================================
ALTER TABLE function_pubsub_triggers
ADD COLUMN topic_pattern TEXT NOT NULL DEFAULT '';
UPDATE function_pubsub_triggers
SET topic_pattern = topic
WHERE topic_pattern = '';
CREATE INDEX IF NOT EXISTS idx_function_pubsub_triggers_function
ON function_pubsub_triggers(function_id);
CREATE INDEX IF NOT EXISTS idx_function_pubsub_triggers_enabled
ON function_pubsub_triggers(enabled);

View File

@ -0,0 +1,20 @@
-- =============================================================================
-- 022_aggregation_windows.sql
--
-- Add per-trigger aggregation parameters to function_pubsub_triggers.
--
-- aggregation_window_ms = 0 means "no aggregation, invoke once per event"
-- (the existing behaviour). Any positive value enables buffering of events
-- in-memory on the dispatching node; the function is invoked once per
-- window with a batched payload.
--
-- aggregation_max_batch_size caps the per-window batch. When the buffer
-- reaches this size, the dispatcher flushes immediately even if the
-- window timer hasn't fired yet.
-- =============================================================================
ALTER TABLE function_pubsub_triggers
ADD COLUMN aggregation_window_ms INTEGER NOT NULL DEFAULT 0;
ALTER TABLE function_pubsub_triggers
ADD COLUMN aggregation_max_batch_size INTEGER NOT NULL DEFAULT 100;

View File

@ -0,0 +1,33 @@
-- =============================================================================
-- 023_push_devices.sql
--
-- Per-namespace, per-user push notification device registry.
--
-- token_encrypted is AES-256-GCM ciphertext (prefix 'enc:') derived via
-- pkg/secrets. Tokens are sensitive — they let the holder spam a user's
-- device — so they are never returned via any API or written to logs.
--
-- provider matches a registered push.PushProvider name:
-- 'ntfy', 'expo', 'apns', 'fcm' (future), ...
-- =============================================================================
CREATE TABLE IF NOT EXISTS push_devices (
id TEXT PRIMARY KEY,
namespace TEXT NOT NULL,
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
provider TEXT NOT NULL,
token_encrypted TEXT NOT NULL,
platform TEXT,
app_version TEXT,
created_at INTEGER NOT NULL,
updated_at INTEGER NOT NULL,
last_seen INTEGER,
UNIQUE(namespace, user_id, device_id)
);
CREATE INDEX IF NOT EXISTS idx_push_devices_user
ON push_devices(namespace, user_id);
CREATE INDEX IF NOT EXISTS idx_push_devices_provider
ON push_devices(provider);

View File

@ -157,6 +157,7 @@ func (b *Builder) buildOramaBinaries() error {
{Name: "identity", Package: "./cmd/identity/"},
{Name: "sfu", Package: "./cmd/sfu/"},
{Name: "turn", Package: "./cmd/turn/"},
{Name: "orama-sni-router", Package: "./cmd/sni-router/"},
}
for _, bin := range binaries {

View File

@ -171,7 +171,7 @@ rm -f /tmp/orama-*.sh /tmp/network-source.tar.gz /tmp/orama-*.tar.gz
# Nuclear: remove binaries
if [ -n "$NUCLEAR" ]; then
rm -f /usr/local/bin/orama /usr/local/bin/orama-node /usr/local/bin/gateway
rm -f /usr/local/bin/identity /usr/local/bin/sfu /usr/local/bin/turn
rm -f /usr/local/bin/identity /usr/local/bin/sfu /usr/local/bin/turn /usr/local/bin/orama-sni-router
rm -f /usr/local/bin/olric-server /usr/local/bin/ipfs /usr/local/bin/ipfs-cluster-service
rm -f /usr/local/bin/rqlited /usr/local/bin/coredns
rm -f /usr/bin/caddy

View File

@ -47,10 +47,29 @@ type DatabaseClient interface {
type PubSubClient interface {
Subscribe(ctx context.Context, topic string, handler MessageHandler) error
Publish(ctx context.Context, topic string, data []byte) error
// PublishBatch publishes multiple messages in parallel, one per topic.
// See pubsub.Manager.PublishBatch for semantics (fail-fast vs. best-effort).
PublishBatch(ctx context.Context, msgs []TopicMessage, opts PublishBatchOptions) error
// PublishSame sends the same payload to every topic in parallel.
PublishSame(ctx context.Context, topics []string, data []byte, opts PublishBatchOptions) error
Unsubscribe(ctx context.Context, topic string) error
ListTopics(ctx context.Context) ([]string, error)
}
// TopicMessage is one entry in a batch publish.
// Mirrors pubsub.TopicMessage to avoid forcing client callers to import pkg/pubsub.
type TopicMessage struct {
// Topic is the topic this entry is published to.
Topic string
// Data is the payload published to Topic.
Data []byte
}
// PublishBatchOptions controls batch publish behavior.
// Mirrors pubsub.PublishBatchOptions.
type PublishBatchOptions struct {
// BestEffort selects between the fail-fast and best-effort semantics
// documented on pubsub.Manager.PublishBatch.
// NOTE(review): presumably true means "keep publishing after a failure" — confirm in pkg/pubsub.
BestEffort bool
// MaxConcurrency is forwarded unchanged to pubsub.PublishBatchOptions.
// NOTE(review): presumably bounds the number of parallel publishes; the
// meaning of zero is defined in pkg/pubsub — confirm.
MaxConcurrency int
}
// NetworkInfo provides network status and peer information
type NetworkInfo interface {
GetPeers(ctx context.Context) ([]PeerInfo, error)

View File

@ -4,13 +4,13 @@ import (
"context"
"fmt"
"github.com/DeBrosOfficial/network/pkg/pubsub"
pkgpubsub "github.com/DeBrosOfficial/network/pkg/pubsub"
)
// pubSubBridge bridges between our PubSubClient interface and the pubsub package
type pubSubBridge struct {
client *Client
adapter *pubsub.ClientAdapter
adapter *pkgpubsub.ClientAdapter
}
func (p *pubSubBridge) Subscribe(ctx context.Context, topic string, handler MessageHandler) error {
@ -31,6 +31,26 @@ func (p *pubSubBridge) Publish(ctx context.Context, topic string, data []byte) e
return p.adapter.Publish(ctx, topic, data)
}
// PublishBatch verifies access, translates the client-level messages and
// options into their pkg/pubsub equivalents, and delegates to the adapter.
func (p *pubSubBridge) PublishBatch(ctx context.Context, msgs []TopicMessage, opts PublishBatchOptions) error {
	if authErr := p.client.requireAccess(ctx); authErr != nil {
		return fmt.Errorf("authentication required: %w - run CLI commands to authenticate automatically", authErr)
	}

	// Field-by-field copy: the two TopicMessage types are distinct on purpose
	// so callers of this package need not import pkg/pubsub.
	converted := make([]pkgpubsub.TopicMessage, 0, len(msgs))
	for _, msg := range msgs {
		converted = append(converted, pkgpubsub.TopicMessage{Topic: msg.Topic, Data: msg.Data})
	}

	return p.adapter.PublishBatch(ctx, converted, pkgpubsub.PublishBatchOptions{
		BestEffort:     opts.BestEffort,
		MaxConcurrency: opts.MaxConcurrency,
	})
}
// PublishSame verifies access, then hands the shared payload and topic list
// to the adapter with the options translated to their pkg/pubsub form.
func (p *pubSubBridge) PublishSame(ctx context.Context, topics []string, data []byte, opts PublishBatchOptions) error {
	if authErr := p.client.requireAccess(ctx); authErr != nil {
		return fmt.Errorf("authentication required: %w - run CLI commands to authenticate automatically", authErr)
	}

	return p.adapter.PublishSame(ctx, topics, data, pkgpubsub.PublishBatchOptions{
		BestEffort:     opts.BestEffort,
		MaxConcurrency: opts.MaxConcurrency,
	})
}
func (p *pubSubBridge) Unsubscribe(ctx context.Context, topic string) error {
if err := p.client.requireAccess(ctx); err != nil {
return fmt.Errorf("authentication required: %w - run CLI commands to authenticate automatically", err)

View File

@ -56,4 +56,15 @@ type Config struct {
SFUPort int // Local SFU signaling port to proxy WebSocket connections to
TURNDomain string // TURN server domain for credential generation
TURNSecret string // HMAC-SHA1 shared secret for TURN credential generation
// StealthCDNDomain, when set, makes the WebRTC credentials handler
// advertise turns:<StealthCDNDomain>:443 (served by the SNI router).
StealthCDNDomain string
// Push notification configuration. Push is enabled when at least one
// provider URL/token is set. Tokens stored in the push_devices table
// are encrypted at rest via pkg/secrets using the cluster secret.
NtfyBaseURL string // ntfy server URL (e.g. "http://localhost:8080")
NtfyAuthToken string // optional bearer token for ntfy
ExpoAccessToken string // optional Expo access token
}

View File

@ -19,6 +19,9 @@ import (
"github.com/DeBrosOfficial/network/pkg/logging"
"github.com/DeBrosOfficial/network/pkg/olric"
"github.com/DeBrosOfficial/network/pkg/pubsub"
"github.com/DeBrosOfficial/network/pkg/push"
pushexpo "github.com/DeBrosOfficial/network/pkg/push/providers/expo"
pushntfy "github.com/DeBrosOfficial/network/pkg/push/providers/ntfy"
"github.com/DeBrosOfficial/network/pkg/rqlite"
"github.com/DeBrosOfficial/network/pkg/serverless"
"github.com/DeBrosOfficial/network/pkg/serverless/hostfunctions"
@ -63,6 +66,11 @@ type Dependencies struct {
// PubSub trigger dispatcher (used to wire into PubSubHandlers)
PubSubDispatcher *triggers.PubSubDispatcher
// Push notification dispatcher (nil when push isn't configured —
// hostfunc + HTTP handlers degrade to no-op / 503).
PushDispatcher *push.PushDispatcher
PushDeviceStore push.PushDeviceStore
// Authentication service
AuthService *auth.Service
}
@ -412,6 +420,19 @@ func initializeServerless(logger *logging.ColoredLogger, cfg *Config, deps *Depe
secretsMgr = smImpl
}
// Initialize push notification dispatcher if any provider is configured.
// Devices are stored encrypted in RQLite (see migration 023). Providers
// are registered based on gateway config; missing config = provider absent.
pushDispatcher, pushStore, err := buildPushDispatcher(cfg, deps.ORMClient, logger)
if err != nil {
// Non-fatal: log and continue. Functions calling push_send will get nil
// (silent no-op) and HTTP /v1/push/* endpoints return 503.
logger.ComponentWarn(logging.ComponentGeneral,
"push notifications disabled (init failed)", zap.Error(err))
}
deps.PushDispatcher = pushDispatcher
deps.PushDeviceStore = pushStore
// Create host functions provider (allows functions to call Orama services)
hostFuncsCfg := hostfunctions.HostFunctionsConfig{
IPFSAPIURL: cfg.IPFSAPIURL,
@ -424,6 +445,7 @@ func initializeServerless(logger *logging.ColoredLogger, cfg *Config, deps *Depe
pubsubAdapter, // pubsub adapter for serverless functions
deps.ServerlessWSMgr,
secretsMgr,
pushDispatcher, // may be nil — PushSend hostfunc handles that
hostFuncsCfg,
logger.Logger,
)
@ -686,3 +708,40 @@ func injectRQLiteAuth(dsn, username, password string) string {
}
return dsn
}
// buildPushDispatcher assembles the push dispatcher and its device store from
// the gateway configuration.
//
// Results:
//   - (nil, nil, nil) when neither NtfyBaseURL nor ExpoAccessToken is set:
//     push is off, which is a supported state rather than an error.
//   - (nil, nil, err) when push is requested but cannot be initialized: the
//     ClusterSecret is empty or the device store fails to construct.
//   - (dispatcher, store, nil) otherwise, with one provider registered for
//     each configured backend.
func buildPushDispatcher(cfg *Config, db rqlite.Client, logger *logging.ColoredLogger) (*push.PushDispatcher, push.PushDeviceStore, error) {
	ntfyEnabled := cfg.NtfyBaseURL != ""
	expoEnabled := cfg.ExpoAccessToken != ""

	switch {
	case !ntfyEnabled && !expoEnabled:
		// No providers configured — push is disabled.
		return nil, nil, nil
	case cfg.ClusterSecret == "":
		// The device store is keyed off the cluster secret; refuse to run
		// without it rather than store anything unprotected.
		return nil, nil, fmt.Errorf("push enabled but ClusterSecret is empty")
	}

	deviceStore, storeErr := push.NewRqliteDeviceStore(db, cfg.ClusterSecret, logger.Logger)
	if storeErr != nil {
		return nil, nil, fmt.Errorf("init push device store: %w", storeErr)
	}

	dispatcher := push.New(deviceStore, logger.Logger)

	if ntfyEnabled {
		ntfyCfg := pushntfy.Config{
			BaseURL:   cfg.NtfyBaseURL,
			AuthToken: cfg.NtfyAuthToken,
		}
		dispatcher.Register(pushntfy.New(ntfyCfg, logger.Logger))
		logger.ComponentInfo(logging.ComponentGeneral, "push provider registered: ntfy",
			zap.String("base_url", cfg.NtfyBaseURL))
	}

	if expoEnabled {
		expoCfg := pushexpo.Config{AccessToken: cfg.ExpoAccessToken}
		dispatcher.Register(pushexpo.New(expoCfg, logger.Logger))
		logger.ComponentInfo(logging.ComponentGeneral, "push provider registered: expo")
	}

	return dispatcher, deviceStore, nil
}

View File

@ -28,6 +28,7 @@ import (
"github.com/DeBrosOfficial/network/pkg/gateway/handlers/cache"
deploymentshandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/deployments"
pubsubhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/pubsub"
pushhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/push"
serverlesshandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/serverless"
enrollhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/enroll"
joinhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/join"
@ -43,6 +44,7 @@ import (
"github.com/DeBrosOfficial/network/pkg/olric"
"github.com/DeBrosOfficial/network/pkg/rqlite"
"github.com/DeBrosOfficial/network/pkg/serverless"
"github.com/DeBrosOfficial/network/pkg/serverless/triggers"
_ "github.com/mattn/go-sqlite3"
"go.uber.org/zap"
)
@ -83,13 +85,15 @@ type Gateway struct {
mu sync.RWMutex
presenceMu sync.RWMutex
pubsubHandlers *pubsubhandlers.PubSubHandlers
pushHandlers *pushhandlers.Handlers
// Serverless function engine
serverlessEngine *serverless.Engine
serverlessRegistry *serverless.Registry
serverlessInvoker *serverless.Invoker
serverlessWSMgr *serverless.WSManager
serverlessHandlers *serverlesshandlers.ServerlessHandlers
serverlessEngine *serverless.Engine
serverlessRegistry *serverless.Registry
serverlessInvoker *serverless.Invoker
serverlessWSMgr *serverless.WSManager
serverlessHandlers *serverlesshandlers.ServerlessHandlers
pubsubDispatcher *triggers.PubSubDispatcher
// Authentication service
authService *auth.Service
@ -342,11 +346,20 @@ func New(logger *logging.ColoredLogger, cfg *Config) (*Gateway, error) {
// Wire PubSub trigger dispatch if serverless is available
if deps.PubSubDispatcher != nil {
gw.pubsubDispatcher = deps.PubSubDispatcher
gw.pubsubHandlers.SetOnPublish(func(ctx context.Context, namespace, topic string, data []byte) {
deps.PubSubDispatcher.Dispatch(ctx, namespace, topic, data, 0)
})
}
// Push notification handlers — disabled when no provider is configured.
// The handlers themselves return 503 if dispatcher/store is nil; we
// register them unconditionally so the routes always exist with a
// predictable shape.
if deps.PushDispatcher != nil {
gw.pushHandlers = pushhandlers.NewHandlers(deps.PushDispatcher, deps.PushDeviceStore, logger)
}
if cfg.WebRTCEnabled && cfg.SFUPort > 0 {
gw.webrtcHandlers = webrtchandlers.NewWebRTCHandlers(
logger,

View File

@ -21,10 +21,12 @@ import (
// mockPubSubClient implements client.PubSubClient for testing
type mockPubSubClient struct {
PublishFunc func(ctx context.Context, topic string, data []byte) error
SubscribeFunc func(ctx context.Context, topic string, handler client.MessageHandler) error
UnsubscribeFunc func(ctx context.Context, topic string) error
ListTopicsFunc func(ctx context.Context) ([]string, error)
PublishFunc func(ctx context.Context, topic string, data []byte) error
PublishBatchFunc func(ctx context.Context, msgs []client.TopicMessage, opts client.PublishBatchOptions) error
PublishSameFunc func(ctx context.Context, topics []string, data []byte, opts client.PublishBatchOptions) error
SubscribeFunc func(ctx context.Context, topic string, handler client.MessageHandler) error
UnsubscribeFunc func(ctx context.Context, topic string) error
ListTopicsFunc func(ctx context.Context) ([]string, error)
}
func (m *mockPubSubClient) Publish(ctx context.Context, topic string, data []byte) error {
@ -34,6 +36,20 @@ func (m *mockPubSubClient) Publish(ctx context.Context, topic string, data []byt
return nil
}
// PublishBatch forwards to PublishBatchFunc when the test supplied one and
// otherwise reports success.
func (m *mockPubSubClient) PublishBatch(ctx context.Context, msgs []client.TopicMessage, opts client.PublishBatchOptions) error {
	if m.PublishBatchFunc == nil {
		return nil
	}
	return m.PublishBatchFunc(ctx, msgs, opts)
}
// PublishSame forwards to PublishSameFunc when the test supplied one and
// otherwise reports success.
func (m *mockPubSubClient) PublishSame(ctx context.Context, topics []string, data []byte, opts client.PublishBatchOptions) error {
	if m.PublishSameFunc == nil {
		return nil
	}
	return m.PublishSameFunc(ctx, topics, data, opts)
}
func (m *mockPubSubClient) Subscribe(ctx context.Context, topic string, handler client.MessageHandler) error {
if m.SubscribeFunc != nil {
return m.SubscribeFunc(ctx, topic, handler)

View File

@ -0,0 +1,156 @@
package pubsub
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/DeBrosOfficial/network/pkg/client"
)
// TestPublishBatchHandler_invalid_method checks that a GET request to the
// batch publish endpoint is answered with 405.
func TestPublishBatchHandler_invalid_method(t *testing.T) {
	handlers := newTestHandlers(&mockNetworkClient{pubsub: &mockPubSubClient{}})
	getReq := httptest.NewRequest(http.MethodGet, "/v1/pubsub/publish-batch", nil)
	rec := httptest.NewRecorder()

	handlers.PublishBatchHandler(rec, withNamespace(getReq, "ns"))

	if got := rec.Code; got != http.StatusMethodNotAllowed {
		t.Errorf("expected 405, got %d", got)
	}
}
// TestPublishBatchHandler_missing_namespace checks that a well-formed POST
// that was not passed through withNamespace is answered with 403.
func TestPublishBatchHandler_missing_namespace(t *testing.T) {
	handlers := newTestHandlers(&mockNetworkClient{pubsub: &mockPubSubClient{}})
	payload, _ := json.Marshal(PublishBatchRequest{
		Messages: []PublishBatchEntry{{Topic: "a", DataB64: "AA=="}},
	})
	postReq := httptest.NewRequest(http.MethodPost, "/v1/pubsub/publish-batch", bytes.NewReader(payload))
	rec := httptest.NewRecorder()

	handlers.PublishBatchHandler(rec, postReq)

	if got := rec.Code; got != http.StatusForbidden {
		t.Errorf("expected 403, got %d (body: %s)", got, rec.Body.String())
	}
}
// TestPublishBatchHandler_empty_messages_rejected checks that a batch with
// zero messages is answered with 400.
func TestPublishBatchHandler_empty_messages_rejected(t *testing.T) {
	handlers := newTestHandlers(&mockNetworkClient{pubsub: &mockPubSubClient{}})
	payload, _ := json.Marshal(PublishBatchRequest{Messages: []PublishBatchEntry{}})
	postReq := httptest.NewRequest(http.MethodPost, "/v1/pubsub/publish-batch", bytes.NewReader(payload))
	rec := httptest.NewRecorder()

	handlers.PublishBatchHandler(rec, withNamespace(postReq, "ns"))

	if got := rec.Code; got != http.StatusBadRequest {
		t.Errorf("expected 400 for empty messages, got %d", got)
	}
}
// TestPublishBatchHandler_oversize_batch_rejected checks that a batch holding
// one entry more than MaxPublishBatchSize is answered with 400.
func TestPublishBatchHandler_oversize_batch_rejected(t *testing.T) {
	handlers := newTestHandlers(&mockNetworkClient{pubsub: &mockPubSubClient{}})

	oversized := make([]PublishBatchEntry, 0, MaxPublishBatchSize+1)
	for i := 0; i < cap(oversized); i++ {
		oversized = append(oversized, PublishBatchEntry{Topic: "t", DataB64: "AA=="})
	}
	payload, _ := json.Marshal(PublishBatchRequest{Messages: oversized})
	postReq := httptest.NewRequest(http.MethodPost, "/v1/pubsub/publish-batch", bytes.NewReader(payload))
	rec := httptest.NewRecorder()

	handlers.PublishBatchHandler(rec, withNamespace(postReq, "ns"))

	if got := rec.Code; got != http.StatusBadRequest {
		t.Errorf("expected 400 for oversize batch, got %d", got)
	}
}
// TestPublishBatchHandler_invalid_base64_rejected checks that a batch
// containing one valid entry and one entry with undecodable base64 is
// answered with 400.
func TestPublishBatchHandler_invalid_base64_rejected(t *testing.T) {
	handlers := newTestHandlers(&mockNetworkClient{pubsub: &mockPubSubClient{}})

	validEntry := PublishBatchEntry{Topic: "good", DataB64: base64.StdEncoding.EncodeToString([]byte("ok"))}
	brokenEntry := PublishBatchEntry{Topic: "bad", DataB64: "!!!not-base64"}
	payload, _ := json.Marshal(PublishBatchRequest{Messages: []PublishBatchEntry{validEntry, brokenEntry}})
	postReq := httptest.NewRequest(http.MethodPost, "/v1/pubsub/publish-batch", bytes.NewReader(payload))
	rec := httptest.NewRecorder()

	handlers.PublishBatchHandler(rec, withNamespace(postReq, "ns"))

	if got := rec.Code; got != http.StatusBadRequest {
		t.Errorf("expected 400 for invalid base64, got %d", got)
	}
}
// TestPublishBatchHandler_missing_topic_rejected verifies that an entry with
// an empty topic is answered with 400 Bad Request.
func TestPublishBatchHandler_missing_topic_rejected(t *testing.T) {
	handlers := newTestHandlers(&mockNetworkClient{pubsub: &mockPubSubClient{}})

	topicless := PublishBatchEntry{Topic: "", DataB64: base64.StdEncoding.EncodeToString([]byte("x"))}
	payload, _ := json.Marshal(PublishBatchRequest{Messages: []PublishBatchEntry{topicless}})

	raw := httptest.NewRequest(http.MethodPost, "/v1/pubsub/publish-batch", bytes.NewReader(payload))
	rec := httptest.NewRecorder()

	handlers.PublishBatchHandler(rec, withNamespace(raw, "ns"))

	if got := rec.Code; got != http.StatusBadRequest {
		t.Errorf("expected 400 for missing topic, got %d", got)
	}
}
// TestPublishBatchHandler_happy_calls_PublishBatch verifies the success path:
// the handler answers 200 and forwards every decoded message, in order, to
// the pubsub client's PublishBatch.
func TestPublishBatchHandler_happy_calls_PublishBatch(t *testing.T) {
	var (
		calls     int32
		mu        sync.Mutex
		forwarded []client.TopicMessage
	)
	// record captures whatever the handler hands to PublishBatch.
	record := func(ctx context.Context, msgs []client.TopicMessage, opts client.PublishBatchOptions) error {
		atomic.AddInt32(&calls, 1)
		mu.Lock()
		defer mu.Unlock()
		forwarded = append(forwarded, msgs...)
		return nil
	}
	handlers := newTestHandlers(&mockNetworkClient{pubsub: &mockPubSubClient{PublishBatchFunc: record}})

	wanted := []struct {
		ordinal, topic, data string
	}{
		{"first", "a", "data-a"},
		{"second", "b", "data-b"},
	}
	batch := make([]PublishBatchEntry, 0, len(wanted))
	for _, w := range wanted {
		batch = append(batch, PublishBatchEntry{
			Topic:   w.topic,
			DataB64: base64.StdEncoding.EncodeToString([]byte(w.data)),
		})
	}

	payload, _ := json.Marshal(PublishBatchRequest{Messages: batch})
	raw := httptest.NewRequest(http.MethodPost, "/v1/pubsub/publish-batch", bytes.NewReader(payload))
	rec := httptest.NewRecorder()
	handlers.PublishBatchHandler(rec, withNamespace(raw, "test-ns"))

	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d (body: %s)", rec.Code, rec.Body.String())
	}

	// PublishBatch is invoked from a goroutine; poll briefly until it has run.
	for deadline := time.Now().Add(2 * time.Second); atomic.LoadInt32(&calls) == 0; time.Sleep(10 * time.Millisecond) {
		if time.Now().After(deadline) {
			t.Fatal("PublishBatch was not called within 2s")
		}
	}

	mu.Lock()
	defer mu.Unlock()
	if len(forwarded) != 2 {
		t.Fatalf("expected 2 messages forwarded, got %d", len(forwarded))
	}
	for idx, w := range wanted {
		msg := forwarded[idx]
		if msg.Topic == w.topic && string(msg.Data) == w.data {
			continue
		}
		t.Errorf("unexpected %s message: %+v", w.ordinal, msg)
	}
}

View File

@ -5,6 +5,7 @@ import (
"encoding/base64"
"encoding/json"
"net/http"
"strconv"
"time"
"github.com/DeBrosOfficial/network/pkg/client"
@ -12,6 +13,10 @@ import (
"go.uber.org/zap"
)
// MaxPublishBatchSize is the maximum number of messages allowed in a single
// /v1/pubsub/publish-batch request. Mirrors pubsub.MaxBatchSize.
//
// PublishBatchHandler answers 400 Bad Request when a request carries more
// entries than this.
const MaxPublishBatchSize = pubsub.MaxBatchSize
// PublishHandler handles POST /v1/pubsub/publish {topic, data_base64}
func (p *PubSubHandlers) PublishHandler(w http.ResponseWriter, r *http.Request) {
if p.client == nil {
@ -39,9 +44,133 @@ func (p *PubSubHandlers) PublishHandler(w http.ResponseWriter, r *http.Request)
return
}
// Check for local websocket subscribers FIRST and deliver directly
p.deliverLocal(ns, body.Topic, data)
// Publish to libp2p asynchronously for cross-node delivery.
go func() {
publishCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
ctx := pubsub.WithNamespace(client.WithInternalAuth(publishCtx), ns)
if err := p.client.PubSub().Publish(ctx, body.Topic, data); err != nil {
p.logger.ComponentWarn("gateway", "async libp2p publish failed",
zap.String("topic", body.Topic),
zap.Error(err))
}
}()
writeJSON(w, http.StatusOK, map[string]any{"status": "ok"})
}
// PublishBatchRequest is the request body for POST /v1/pubsub/publish-batch.
//
// Messages must contain between 1 and MaxPublishBatchSize entries; the
// handler rejects anything else with 400 Bad Request.
type PublishBatchRequest struct {
// Messages is the ordered list of messages to publish.
Messages []PublishBatchEntry `json:"messages"`
// BestEffort is forwarded unchanged as PublishBatchOptions.BestEffort.
// NOTE(review): its exact semantics are defined by the pubsub package — confirm there.
BestEffort bool `json:"best_effort,omitempty"`
}
// PublishBatchEntry is one message in a batch publish request.
type PublishBatchEntry struct {
// Topic is required; an empty topic fails the whole request with 400.
Topic string `json:"topic"`
// DataB64 is the payload in standard (base64.StdEncoding) base64. The
// decoded payload may be at most MaxPerMessageBytes long.
DataB64 string `json:"data_base64"`
}
// PublishBatchResponse is the response body for /v1/pubsub/publish-batch.
//
// PublishBatchHandler sends it with HTTP 200 once every entry has been
// validated and handed to local delivery.
//
// libp2p delivery is asynchronous and not awaited here, mirroring the
// single-publish handler's fire-and-forget contract. Per-topic failures
// are not surfaced via this response — operators should consult logs /
// metrics for delivery health.
type PublishBatchResponse struct {
Status string `json:"status"` // always "ok" — request was accepted
}
// MaxPerMessageBytes caps an individual message payload inside a batch.
// Mirrors the 1MB cap on /v1/pubsub/publish.
//
// The limit (1<<20 bytes = 1 MiB) is checked against the decoded payload,
// i.e. after base64 decoding, not against the data_base64 string length.
const MaxPerMessageBytes = 1 << 20
// PublishBatchHandler handles POST /v1/pubsub/publish-batch.
//
// It accepts between 1 and MaxPublishBatchSize messages in the caller's
// namespace. The whole batch is validated before anything is delivered, so a
// single bad entry rejects the request without side effects.
//
// Responses:
//   - 503 if the network client is not initialized
//   - 405 for any method other than POST
//   - 403 if no namespace could be resolved from the request
//   - 400 for an undecodable body, an empty or oversize batch, a missing
//     topic, invalid base64, or a decoded payload above MaxPerMessageBytes
//   - 200 with PublishBatchResponse once the batch has been accepted
//
// Local subscribers receive messages synchronously; libp2p delivery is
// handed to PubSub().PublishBatch in a background goroutine and its outcome
// is only logged, never reported to the HTTP caller.
func (p *PubSubHandlers) PublishBatchHandler(w http.ResponseWriter, r *http.Request) {
	if p.client == nil {
		writeError(w, http.StatusServiceUnavailable, "client not initialized")
		return
	}
	if r.Method != http.MethodPost {
		writeError(w, http.StatusMethodNotAllowed, "method not allowed")
		return
	}
	ns := resolveNamespaceFromRequest(r)
	if ns == "" {
		writeError(w, http.StatusForbidden, "namespace not resolved")
		return
	}

	// A full batch of maximum-size messages could be far larger than this;
	// cap the body conservatively at 16MB to discourage huge payloads.
	r.Body = http.MaxBytesReader(w, r.Body, 16<<20)
	var body PublishBatchRequest
	if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
		writeError(w, http.StatusBadRequest, "invalid body: expected {messages:[{topic,data_base64}]}")
		return
	}
	if len(body.Messages) == 0 {
		writeError(w, http.StatusBadRequest, "messages required")
		return
	}
	if len(body.Messages) > MaxPublishBatchSize {
		// Derive the number from the constant so the message cannot drift
		// from the limit that is actually enforced.
		writeError(w, http.StatusBadRequest,
			"too many messages: max is "+strconv.Itoa(MaxPublishBatchSize)+" per batch")
		return
	}

	// Decode all messages up-front so we can fail fast on bad input.
	decoded := make([]pubsub.TopicMessage, 0, len(body.Messages))
	for i, m := range body.Messages {
		if m.Topic == "" {
			writeError(w, http.StatusBadRequest, "message missing topic at index "+strconv.Itoa(i))
			return
		}
		data, err := base64.StdEncoding.DecodeString(m.DataB64)
		if err != nil {
			writeError(w, http.StatusBadRequest, "invalid base64 data at index "+strconv.Itoa(i))
			return
		}
		if len(data) > MaxPerMessageBytes {
			writeError(w, http.StatusBadRequest, "message too large at index "+strconv.Itoa(i))
			return
		}
		decoded = append(decoded, pubsub.TopicMessage{Topic: m.Topic, Data: data})
	}

	// Deliver locally + dispatch triggers per topic synchronously (fast in-process).
	for _, msg := range decoded {
		p.deliverLocal(ns, msg.Topic, msg.Data)
	}

	// Async libp2p batch publish, similar to PublishHandler's approach. The
	// goroutine uses a fresh background context so it is not cancelled when
	// the HTTP request finishes; the 30s timeout bounds its lifetime.
	go func() {
		publishCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
		defer cancel()
		ctx := pubsub.WithNamespace(client.WithInternalAuth(publishCtx), ns)
		opts := pubsub.PublishBatchOptions{BestEffort: body.BestEffort}
		err := p.client.PubSub().PublishBatch(ctx, toClientMessages(decoded), clientOpts(opts))
		if err != nil {
			p.logger.ComponentWarn("gateway", "async libp2p batch publish failed",
				zap.Int("messages", len(decoded)),
				zap.Error(err))
		}
	}()

	writeJSON(w, http.StatusOK, PublishBatchResponse{Status: "ok"})
}
// deliverLocal handles local-subscriber delivery and fires PubSub triggers.
// It does NOT publish to libp2p — callers handle that themselves (single
// or batched) so this helper stays focused on in-process fan-out.
func (p *PubSubHandlers) deliverLocal(ns, topic string, data []byte) {
p.mu.RLock()
localSubs := p.getLocalSubscribers(body.Topic, ns)
localSubs := p.getLocalSubscribers(topic, ns)
p.mu.RUnlock()
localDeliveryCount := 0
@ -50,48 +179,38 @@ func (p *PubSubHandlers) PublishHandler(w http.ResponseWriter, r *http.Request)
select {
case sub.msgChan <- data:
localDeliveryCount++
p.logger.ComponentDebug("gateway", "delivered to local subscriber",
zap.String("topic", body.Topic))
default:
// Drop if buffer full
p.logger.ComponentWarn("gateway", "local subscriber buffer full, dropping message",
zap.String("topic", body.Topic))
zap.String("topic", topic))
}
}
}
p.logger.ComponentInfo("gateway", "pubsub publish: processing message",
zap.String("topic", body.Topic),
zap.String("topic", topic),
zap.String("namespace", ns),
zap.Int("data_len", len(data)),
zap.Int("local_subscribers", len(localSubs)),
zap.Int("local_delivered", localDeliveryCount))
// Fire PubSub triggers for serverless functions (non-blocking)
// Fire PubSub triggers for serverless functions (non-blocking).
if p.onPublish != nil {
go p.onPublish(context.Background(), ns, body.Topic, data)
go p.onPublish(context.Background(), ns, topic, data)
}
}
// Publish to libp2p asynchronously for cross-node delivery
// This prevents blocking the HTTP response if libp2p network is slow
go func() {
publishCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// toClientMessages copies each pubsub.TopicMessage into the equivalent
// client.TopicMessage so the batch can cross the PubSubClient interface.
// The result always has the same length and order as msgs.
func toClientMessages(msgs []pubsub.TopicMessage) []client.TopicMessage {
	converted := make([]client.TopicMessage, 0, len(msgs))
	for idx := range msgs {
		converted = append(converted, client.TopicMessage{
			Topic: msgs[idx].Topic,
			Data:  msgs[idx].Data,
		})
	}
	return converted
}
ctx := pubsub.WithNamespace(client.WithInternalAuth(publishCtx), ns)
if err := p.client.PubSub().Publish(ctx, body.Topic, data); err != nil {
p.logger.ComponentWarn("gateway", "async libp2p publish failed",
zap.String("topic", body.Topic),
zap.Error(err))
} else {
p.logger.ComponentDebug("gateway", "async libp2p publish succeeded",
zap.String("topic", body.Topic))
}
}()
// Return immediately after local delivery
// Local WebSocket subscribers already received the message
writeJSON(w, http.StatusOK, map[string]any{"status": "ok"})
// clientOpts translates pubsub.PublishBatchOptions into the client package's
// option type, carrying over BestEffort and MaxConcurrency unchanged.
func clientOpts(o pubsub.PublishBatchOptions) client.PublishBatchOptions {
	var translated client.PublishBatchOptions
	translated.BestEffort = o.BestEffort
	translated.MaxConcurrency = o.MaxConcurrency
	return translated
}
// TopicsHandler lists topics within the caller's namespace

View File

@ -0,0 +1,291 @@
package push
import (
"encoding/json"
"net/http"
"strings"
"time"
"github.com/DeBrosOfficial/network/pkg/push"
"go.uber.org/zap"
)
// validProviders is the allowlist for the `provider` field on RegisterDevice.
// Keep in sync with what the dispatcher actually has registered at startup.
//
// RegisterDeviceHandler looks up the whitespace-trimmed provider string as
// an exact key here and answers 400 Bad Request for anything not listed.
var validProviders = map[string]struct{}{
"ntfy": {},
"expo": {},
"apns": {}, // future — accepted at registration so apps can pre-flight
}
// MaxTokenBytes caps the device-token length to prevent abuse.
// Real ntfy topic paths and Expo tokens are well under this.
//
// The limit is applied to the byte length of the token after surrounding
// whitespace has been trimmed; longer tokens are rejected with 400.
const MaxTokenBytes = 512
// RegisterDeviceHandler handles POST /v1/push/devices.
//
// The device is bound to the caller's namespace and to the caller's JWT
// subject, which becomes the stored user_id. Requests without JWT claims
// (for example API-key-only callers) are rejected with 401.
//
// Responses:
//   - 503 if no device store is configured
//   - 405 for any method other than POST
//   - 403 if no namespace is present on the request
//   - 401 if no JWT subject is present
//   - 400 for an undecodable body, a missing device_id, an unknown
//     provider, or a missing / oversize token
//   - 500 if the store rejects the upsert
//   - 200 with RegisterDeviceResponse on success
func (h *Handlers) RegisterDeviceHandler(w http.ResponseWriter, r *http.Request) {
	switch {
	case h.store == nil:
		writeError(w, http.StatusServiceUnavailable, "push: device store not configured")
		return
	case r.Method != http.MethodPost:
		writeError(w, http.StatusMethodNotAllowed, "method not allowed")
		return
	}

	namespace := resolveNamespace(r)
	if namespace == "" {
		writeError(w, http.StatusForbidden, "namespace not resolved")
		return
	}

	// A JWT-authenticated user is required to bind the device to;
	// API-key-only callers cannot register devices on behalf of users.
	caller := resolveCallerUserID(r)
	if caller == "" {
		writeError(w, http.StatusUnauthorized, "user authentication required")
		return
	}

	r.Body = http.MaxBytesReader(w, r.Body, 4096)
	var in RegisterDeviceRequest
	if err := json.NewDecoder(r.Body).Decode(&in); err != nil {
		writeError(w, http.StatusBadRequest, "invalid body")
		return
	}

	deviceID := strings.TrimSpace(in.DeviceID)
	provider := strings.TrimSpace(in.Provider)
	token := strings.TrimSpace(in.Token)

	// Validate in a fixed order so the first failing field decides the message.
	_, known := validProviders[provider]
	var problem string
	switch {
	case deviceID == "":
		problem = "device_id required"
	case !known:
		problem = "unknown provider: " + provider
	case token == "":
		problem = "token required"
	case len(token) > MaxTokenBytes:
		problem = "token too long"
	}
	if problem != "" {
		writeError(w, http.StatusBadRequest, problem)
		return
	}

	record := push.PushDevice{
		Namespace: namespace,
		UserID:    caller,
		DeviceID:  deviceID,
		Provider:  provider,
		Token:     token,
		Platform:  in.Platform,
		AppVer:    in.AppVersion,
		LastSeen:  time.Now().Unix(),
	}
	if err := h.store.Upsert(boundCtx(r), record); err != nil {
		h.logger.ComponentWarn("push", "device upsert failed",
			zap.String("namespace", namespace),
			zap.String("user_id", caller),
			zap.Error(err))
		writeError(w, http.StatusInternalServerError, "registration failed")
		return
	}

	writeJSON(w, http.StatusOK, RegisterDeviceResponse{Status: "ok"})
}
// ListDevicesHandler handles GET /v1/push/devices.
//
// Returns the caller's own devices; tokens are NEVER included in the
// response. Other namespaces / other users are inaccessible because the
// store is queried with the caller's namespace and JWT subject only.
//
// Responses:
//   - 503 if no device store is configured
//   - 405 for any method other than GET
//   - 403 if no namespace is present on the request
//   - 401 if no JWT subject is present
//   - 500 if the store lookup fails (the error is logged)
//   - 200 with {"devices": [...PushDeviceView]} on success
func (h *Handlers) ListDevicesHandler(w http.ResponseWriter, r *http.Request) {
	if h.store == nil {
		writeError(w, http.StatusServiceUnavailable, "push: device store not configured")
		return
	}
	if r.Method != http.MethodGet {
		writeError(w, http.StatusMethodNotAllowed, "method not allowed")
		return
	}
	ns := resolveNamespace(r)
	if ns == "" {
		writeError(w, http.StatusForbidden, "namespace not resolved")
		return
	}
	userID := resolveCallerUserID(r)
	if userID == "" {
		writeError(w, http.StatusUnauthorized, "user authentication required")
		return
	}

	devs, err := h.store.ListForUser(boundCtx(r), ns, userID)
	if err != nil {
		// Log the cause like the other handlers do; the client only sees a
		// generic message.
		h.logger.ComponentWarn("push", "device list failed",
			zap.String("namespace", ns),
			zap.String("user_id", userID),
			zap.Error(err))
		writeError(w, http.StatusInternalServerError, "list failed")
		return
	}

	// Copy into the token-free view type; sized with len so an empty result
	// still encodes as [] rather than null.
	views := make([]PushDeviceView, len(devs))
	for i, d := range devs {
		views[i] = PushDeviceView{
			ID:         d.ID,
			DeviceID:   d.DeviceID,
			Provider:   d.Provider,
			Platform:   d.Platform,
			AppVersion: d.AppVer,
			CreatedAt:  d.CreatedAt,
			UpdatedAt:  d.UpdatedAt,
			LastSeen:   d.LastSeen,
		}
	}
	writeJSON(w, http.StatusOK, map[string]interface{}{"devices": views})
}
// DeleteDeviceHandler handles DELETE /v1/push/devices/{id}.
//
// `{id}` is the database row ID as reported in the `id` field of the
// ListDevices response. Only devices belonging to the caller (matched by
// namespace + user_id via a ListForUser lookup) can be deleted.
//
// Responses:
//   - 503 if no device store is configured
//   - 405 for any method other than DELETE
//   - 403 if no namespace is present on the request
//   - 401 if no JWT subject is present
//   - 400 if the path carries no device id
//   - 404 if the id is not one of the caller's devices
//   - 500 if the ownership lookup or the delete fails (both are logged)
//   - 200 with {"status":"ok"} on success
func (h *Handlers) DeleteDeviceHandler(w http.ResponseWriter, r *http.Request) {
	if h.store == nil {
		writeError(w, http.StatusServiceUnavailable, "push: device store not configured")
		return
	}
	if r.Method != http.MethodDelete {
		writeError(w, http.StatusMethodNotAllowed, "method not allowed")
		return
	}
	ns := resolveNamespace(r)
	if ns == "" {
		writeError(w, http.StatusForbidden, "namespace not resolved")
		return
	}
	userID := resolveCallerUserID(r)
	if userID == "" {
		writeError(w, http.StatusUnauthorized, "user authentication required")
		return
	}
	id := extractIDFromPath(r.URL.Path, "/v1/push/devices/")
	if id == "" {
		writeError(w, http.StatusBadRequest, "device id required in path")
		return
	}

	// Authorization check: confirm the device belongs to the caller.
	devs, err := h.store.ListForUser(boundCtx(r), ns, userID)
	if err != nil {
		h.logger.ComponentWarn("push", "device ownership lookup failed",
			zap.String("namespace", ns),
			zap.String("user_id", userID),
			zap.Error(err))
		writeError(w, http.StatusInternalServerError, "lookup failed")
		return
	}
	owns := false
	for _, d := range devs {
		if d.ID == id {
			owns = true
			break
		}
	}
	if !owns {
		// 404, not 403 — don't leak whether the ID exists in another scope.
		writeError(w, http.StatusNotFound, "not found")
		return
	}

	if err := h.store.Delete(boundCtx(r), ns, id); err != nil {
		h.logger.ComponentWarn("push", "device delete failed",
			zap.String("namespace", ns),
			zap.String("device_row_id", id),
			zap.Error(err))
		writeError(w, http.StatusInternalServerError, "delete failed")
		return
	}
	writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
}
// SendHandler handles POST /v1/push/send.
//
// SECURITY: this endpoint lets the caller push arbitrary messages to any
// user_id inside the caller's namespace. It MUST be restricted to a small
// set of trusted callers — typically the namespace operator and the
// namespace's own serverless functions (which can use the WASM `push_send`
// hostfunc directly instead of HTTP).
//
// As written, any JWT-authenticated caller in the namespace is accepted.
// **Add an explicit allow-list or admin-scope check before exposing this in
// production.** The WASM hostfunc is not affected because trigger
// registration already gates which functions exist.
//
// Responses:
//   - 503 if no dispatcher is configured
//   - 405 for any method other than POST
//   - 403 if no namespace is present on the request
//   - 401 if no JWT subject is present
//   - 400 for an undecodable body or a missing user_id
//   - 502 if the dispatcher reports an error
//   - 200 with SendResponse on success
func (h *Handlers) SendHandler(w http.ResponseWriter, r *http.Request) {
	switch {
	case h.dispatcher == nil:
		writeError(w, http.StatusServiceUnavailable, "push: dispatcher not configured")
		return
	case r.Method != http.MethodPost:
		writeError(w, http.StatusMethodNotAllowed, "method not allowed")
		return
	}

	namespace := resolveNamespace(r)
	if namespace == "" {
		writeError(w, http.StatusForbidden, "namespace not resolved")
		return
	}
	if caller := resolveCallerUserID(r); caller == "" {
		writeError(w, http.StatusUnauthorized, "user authentication required")
		return
	}

	r.Body = http.MaxBytesReader(w, r.Body, 64*1024) // generous for Data payloads
	var in SendRequest
	if err := json.NewDecoder(r.Body).Decode(&in); err != nil {
		writeError(w, http.StatusBadRequest, "invalid body")
		return
	}

	target := strings.TrimSpace(in.UserID)
	if target == "" {
		writeError(w, http.StatusBadRequest, "user_id required")
		return
	}

	outgoing := push.PushMessage{
		Title:    in.Title,
		Body:     in.Body,
		Channel:  in.Channel,
		Priority: pickPriority(in.Priority),
		Badge:    in.Badge,
		Sound:    in.Sound,
		Data:     in.Data,
	}

	sendErr := h.dispatcher.SendToUser(boundCtx(r), namespace, target, outgoing)
	if sendErr == nil {
		writeJSON(w, http.StatusOK, SendResponse{Status: "ok"})
		return
	}

	// Not treated as fatal: some devices may have failed while others
	// succeeded. 502 signals partial trouble; per-device detail is in logs.
	h.logger.ComponentWarn("push", "send to user partially failed",
		zap.String("namespace", namespace),
		zap.String("user_id", target),
		zap.Error(sendErr))
	writeError(w, http.StatusBadGateway, "one or more devices failed")
}
// extractIDFromPath returns the first path segment that follows prefix in
// urlPath. It returns the empty string when urlPath does not start with
// prefix or when nothing follows it. The segment ends at the first '?', '#'
// or '/' so stray query strings, fragments and sub-paths are dropped.
//
// Needed because the gateway uses the standard `net/http` mux, which does
// not extract path params.
func extractIDFromPath(urlPath, prefix string) string {
	if !strings.HasPrefix(urlPath, prefix) {
		return ""
	}
	tail := strings.TrimPrefix(urlPath, prefix)
	end := strings.IndexAny(tail, "?#/")
	if end < 0 {
		return tail
	}
	return tail[:end]
}

View File

@ -0,0 +1,330 @@
package push
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
authsvc "github.com/DeBrosOfficial/network/pkg/gateway/auth"
"github.com/DeBrosOfficial/network/pkg/gateway/ctxkeys"
"github.com/DeBrosOfficial/network/pkg/logging"
"github.com/DeBrosOfficial/network/pkg/push"
"go.uber.org/zap"
)
// fakeStore is an in-memory PushDeviceStore for tests.
//
// It is not synchronized; its methods read and write devices without any
// locking.
type fakeStore struct {
// devices is the backing list; tests may seed it directly.
devices []push.PushDevice
// upsertFn, when set, replaces Upsert's default behavior.
upsertFn func(push.PushDevice) error
// deleteFn, when set, replaces Delete's default behavior.
deleteFn func(ns, id string) error
// listErr, when set, is returned by every ListForUser call.
listErr error
}
// Upsert stores dev, or delegates to upsertFn when one is set.
//
// A device already registered under the same (Namespace, UserID, DeviceID)
// is replaced in place rather than duplicated, so the fake follows the
// update-on-re-registration behavior its name promises. When dev carries no
// ID, a replaced record keeps its existing ID and a new record gets
// "row-"+DeviceID.
func (s *fakeStore) Upsert(ctx context.Context, dev push.PushDevice) error {
	if s.upsertFn != nil {
		return s.upsertFn(dev)
	}
	for i := range s.devices {
		cur := &s.devices[i]
		if cur.Namespace != dev.Namespace || cur.UserID != dev.UserID || cur.DeviceID != dev.DeviceID {
			continue
		}
		if dev.ID == "" {
			dev.ID = cur.ID
		}
		*cur = dev
		return nil
	}
	if dev.ID == "" {
		dev.ID = "row-" + dev.DeviceID
	}
	s.devices = append(s.devices, dev)
	return nil
}
// Delete removes the first device whose ID and Namespace both match, or
// delegates to deleteFn when one is set. It returns a "not found" error if
// no device matches.
func (s *fakeStore) Delete(ctx context.Context, ns, id string) error {
	if hook := s.deleteFn; hook != nil {
		return hook(ns, id)
	}
	for idx := range s.devices {
		if s.devices[idx].ID != id || s.devices[idx].Namespace != ns {
			continue
		}
		s.devices = append(s.devices[:idx], s.devices[idx+1:]...)
		return nil
	}
	return errors.New("not found")
}
// ListForUser returns the devices stored for userID within ns, in insertion
// order. It returns listErr instead when that field is set. A successful
// result is never nil, even when nothing matches.
func (s *fakeStore) ListForUser(ctx context.Context, ns, userID string) ([]push.PushDevice, error) {
	if err := s.listErr; err != nil {
		return nil, err
	}
	matches := make([]push.PushDevice, 0)
	for idx := range s.devices {
		if dev := s.devices[idx]; dev.Namespace == ns && dev.UserID == userID {
			matches = append(matches, dev)
		}
	}
	return matches, nil
}
// withAuth returns a copy of r whose context carries the given namespace
// and, when userID is non-empty, JWT claims with userID as the subject.
// Empty arguments leave the corresponding context value unset, which lets
// tests simulate unresolved namespaces and unauthenticated callers.
func withAuth(r *http.Request, namespace, userID string) *http.Request {
	scoped := r.Context()
	if len(namespace) > 0 {
		scoped = context.WithValue(scoped, ctxkeys.NamespaceOverride, namespace)
	}
	if len(userID) > 0 {
		claims := &authsvc.JWTClaims{Sub: userID, Namespace: namespace}
		scoped = context.WithValue(scoped, ctxkeys.JWT, claims)
	}
	return r.WithContext(scoped)
}
// newHandlers builds a Handlers for tests with a no-op logger. Note the
// argument order (store first) differs from NewHandlers.
func newHandlers(store push.PushDeviceStore, dispatcher *push.PushDispatcher) *Handlers {
	return NewHandlers(dispatcher, store, &logging.ColoredLogger{Logger: zap.NewNop()})
}
// --- RegisterDeviceHandler ---
// TestRegister_happy_path verifies that a valid registration answers 200 and
// stores exactly one device bound to the caller's namespace and user.
func TestRegister_happy_path(t *testing.T) {
	backing := &fakeStore{}
	handlers := newHandlers(backing, nil)

	payload, _ := json.Marshal(RegisterDeviceRequest{
		DeviceID: "iphone-abc",
		Provider: "ntfy",
		Token:    "ns/myapp/user-1",
		Platform: "ios",
	})
	raw := httptest.NewRequest(http.MethodPost, "/v1/push/devices", bytes.NewReader(payload))
	rec := httptest.NewRecorder()

	handlers.RegisterDeviceHandler(rec, withAuth(raw, "myapp", "user-1"))

	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d (body: %s)", rec.Code, rec.Body.String())
	}
	if stored := len(backing.devices); stored != 1 {
		t.Fatalf("expected 1 device stored, got %d", stored)
	}

	saved := backing.devices[0]
	asExpected := saved.Namespace == "myapp" && saved.UserID == "user-1" && saved.Token == "ns/myapp/user-1"
	if !asExpected {
		t.Errorf("unexpected device: %+v", saved)
	}
}
// TestRegister_unauthenticated_rejected verifies that a request with a
// namespace but no JWT claims is answered with 401 Unauthorized.
func TestRegister_unauthenticated_rejected(t *testing.T) {
	handlers := newHandlers(&fakeStore{}, nil)

	payload, _ := json.Marshal(RegisterDeviceRequest{DeviceID: "x", Provider: "ntfy", Token: "t"})
	raw := httptest.NewRequest(http.MethodPost, "/v1/push/devices", bytes.NewReader(payload))
	rec := httptest.NewRecorder()

	// An empty user ID means withAuth puts no JWT into the context.
	handlers.RegisterDeviceHandler(rec, withAuth(raw, "ns", ""))

	if got := rec.Code; got != http.StatusUnauthorized {
		t.Errorf("expected 401, got %d", got)
	}
}
// TestRegister_unknown_provider_rejected verifies that a provider outside
// the allowlist is answered with 400 Bad Request.
func TestRegister_unknown_provider_rejected(t *testing.T) {
	handlers := newHandlers(&fakeStore{}, nil)

	payload, _ := json.Marshal(RegisterDeviceRequest{DeviceID: "x", Provider: "weirdmail", Token: "t"})
	raw := httptest.NewRequest(http.MethodPost, "/v1/push/devices", bytes.NewReader(payload))
	rec := httptest.NewRecorder()

	handlers.RegisterDeviceHandler(rec, withAuth(raw, "ns", "u"))

	if got := rec.Code; got != http.StatusBadRequest {
		t.Errorf("expected 400, got %d", got)
	}
}
// TestRegister_oversize_token_rejected verifies that a token one byte longer
// than MaxTokenBytes is answered with 400 Bad Request.
func TestRegister_oversize_token_rejected(t *testing.T) {
	handlers := newHandlers(&fakeStore{}, nil)

	oversize := string(bytes.Repeat([]byte{'a'}, MaxTokenBytes+1))
	payload, _ := json.Marshal(RegisterDeviceRequest{DeviceID: "x", Provider: "ntfy", Token: oversize})
	raw := httptest.NewRequest(http.MethodPost, "/v1/push/devices", bytes.NewReader(payload))
	rec := httptest.NewRecorder()

	handlers.RegisterDeviceHandler(rec, withAuth(raw, "ns", "u"))

	if got := rec.Code; got != http.StatusBadRequest {
		t.Errorf("expected 400, got %d", got)
	}
}
// TestRegister_no_store_returns_503 verifies that registration answers 503
// Service Unavailable when the handlers were built without a device store.
func TestRegister_no_store_returns_503(t *testing.T) {
	handlers := newHandlers(nil, nil)

	raw := httptest.NewRequest(http.MethodPost, "/v1/push/devices", bytes.NewReader([]byte(`{}`)))
	rec := httptest.NewRecorder()

	handlers.RegisterDeviceHandler(rec, withAuth(raw, "ns", "u"))

	if got := rec.Code; got != http.StatusServiceUnavailable {
		t.Errorf("expected 503, got %d", got)
	}
}
// --- ListDevicesHandler ---
func TestList_returns_only_callers_devices_without_tokens(t *testing.T) {
store := &fakeStore{
devices: []push.PushDevice{
{ID: "1", Namespace: "myapp", UserID: "u1", DeviceID: "d1", Provider: "ntfy", Token: "secret-token-1"},
{ID: "2", Namespace: "myapp", UserID: "u1", DeviceID: "d2", Provider: "expo", Token: "secret-token-2"},
{ID: "3", Namespace: "myapp", UserID: "u2", DeviceID: "d3", Provider: "ntfy", Token: "secret-token-3"},
{ID: "4", Namespace: "other", UserID: "u1", DeviceID: "d4", Provider: "ntfy", Token: "secret-token-4"},
},
}
h := newHandlers(store, nil)
req := withAuth(httptest.NewRequest(http.MethodGet, "/v1/push/devices", nil), "myapp", "u1")
rr := httptest.NewRecorder()
h.ListDevicesHandler(rr, req)
if rr.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", rr.Code)
}
var resp struct {
Devices []PushDeviceView `json:"devices"`
}
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
t.Fatalf("decode: %v", err)
}
if len(resp.Devices) != 2 {
t.Fatalf("expected 2 devices, got %d", len(resp.Devices))
}
// Tokens must NOT appear in response — they're not even in the struct.
if bytes.Contains(rr.Body.Bytes(), []byte("secret-token")) {
t.Errorf("response leaked a token: %s", rr.Body.String())
}
}
// --- DeleteDeviceHandler ---
// TestDelete_owns_device_succeeds verifies that a caller can delete a device
// registered under their own namespace and user ID.
func TestDelete_owns_device_succeeds(t *testing.T) {
	backing := &fakeStore{devices: []push.PushDevice{
		{ID: "row-d1", Namespace: "myapp", UserID: "u1", DeviceID: "d1"},
	}}
	handlers := newHandlers(backing, nil)

	raw := httptest.NewRequest(http.MethodDelete, "/v1/push/devices/row-d1", nil)
	rec := httptest.NewRecorder()
	handlers.DeleteDeviceHandler(rec, withAuth(raw, "myapp", "u1"))

	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d (body: %s)", rec.Code, rec.Body.String())
	}
	if remaining := len(backing.devices); remaining != 0 {
		t.Errorf("expected device removed")
	}
}
// TestDelete_other_users_device_returns_404 verifies that deleting a row
// owned by a different user answers 404 and leaves the row untouched.
func TestDelete_other_users_device_returns_404(t *testing.T) {
	backing := &fakeStore{devices: []push.PushDevice{
		{ID: "row-d1", Namespace: "myapp", UserID: "other-user", DeviceID: "d1"},
	}}
	handlers := newHandlers(backing, nil)

	raw := httptest.NewRequest(http.MethodDelete, "/v1/push/devices/row-d1", nil)
	rec := httptest.NewRecorder()
	handlers.DeleteDeviceHandler(rec, withAuth(raw, "myapp", "u1"))

	if got := rec.Code; got != http.StatusNotFound {
		t.Errorf("expected 404, got %d", got)
	}
	if remaining := len(backing.devices); remaining != 1 {
		t.Errorf("expected device NOT removed")
	}
}
// TestDelete_missing_id_returns_400 verifies that a DELETE on the bare
// collection path, with no row ID segment, answers 400 Bad Request.
func TestDelete_missing_id_returns_400(t *testing.T) {
	handlers := newHandlers(&fakeStore{}, nil)

	raw := httptest.NewRequest(http.MethodDelete, "/v1/push/devices/", nil)
	rec := httptest.NewRecorder()
	handlers.DeleteDeviceHandler(rec, withAuth(raw, "myapp", "u1"))

	if got := rec.Code; got != http.StatusBadRequest {
		t.Errorf("expected 400, got %d", got)
	}
}
// --- SendHandler ---
// TestSend_dispatcher_called_for_user verifies that a send request reaches
// the provider registered for the target user's device exactly once and
// that the handler answers 200.
func TestSend_dispatcher_called_for_user(t *testing.T) {
	var sent int32
	dispatcher := push.New(&fakeStore{
		devices: []push.PushDevice{
			{ID: "row-1", Namespace: "myapp", UserID: "target-user", Provider: "fake", Token: "tok"},
		},
	}, zap.NewNop())
	dispatcher.Register(&fakePushProvider{
		name: "fake",
		fn:   func(ctx context.Context, msg push.PushMessage) error { atomic.AddInt32(&sent, 1); return nil },
	})
	h := newHandlers(&fakeStore{}, dispatcher)

	body, _ := json.Marshal(SendRequest{
		UserID: "target-user", Title: "hi", Body: "world",
	})
	req := withAuth(httptest.NewRequest(http.MethodPost, "/v1/push/send", bytes.NewReader(body)), "myapp", "u1")
	rr := httptest.NewRecorder()
	h.SendHandler(rr, req)

	if rr.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d (body: %s)", rr.Code, rr.Body.String())
	}
	// Report the atomically loaded value; the counter is written with
	// atomic.AddInt32 and should not be read with a plain load.
	if got := atomic.LoadInt32(&sent); got != 1 {
		t.Errorf("expected provider called once, got %d", got)
	}
}
// TestSend_no_dispatcher_returns_503 verifies that sending answers 503
// Service Unavailable when the handlers were built without a dispatcher.
func TestSend_no_dispatcher_returns_503(t *testing.T) {
	handlers := newHandlers(&fakeStore{}, nil)

	raw := httptest.NewRequest(http.MethodPost, "/v1/push/send", bytes.NewReader([]byte(`{"user_id":"u"}`)))
	rec := httptest.NewRecorder()
	handlers.SendHandler(rec, withAuth(raw, "myapp", "u1"))

	if got := rec.Code; got != http.StatusServiceUnavailable {
		t.Errorf("expected 503, got %d", got)
	}
}
func TestSend_missing_user_id_returns_400(t *testing.T) {
dispatcher := push.New(&fakeStore{}, zap.NewNop())
h := newHandlers(&fakeStore{}, dispatcher)
body, _ := json.Marshal(SendRequest{})
req := withAuth(httptest.NewRequest(http.MethodPost, "/v1/push/send", bytes.NewReader(body)), "myapp", "u1")
rr := httptest.NewRecorder()
h.SendHandler(rr, req)
if rr.Code != http.StatusBadRequest {
t.Errorf("expected 400, got %d", rr.Code)
}
}
// --- helpers ---
// fakePushProvider is a scriptable push provider for tests: Name reports
// the configured name and Send forwards to fn when one is supplied.
type fakePushProvider struct {
	name string
	fn   func(ctx context.Context, msg push.PushMessage) error
}

// Name returns the provider name this fake was constructed with.
func (p *fakePushProvider) Name() string {
	return p.name
}

// Send invokes fn and returns its result; without an fn it succeeds
// without doing anything.
func (p *fakePushProvider) Send(ctx context.Context, msg push.PushMessage) error {
	if p.fn == nil {
		return nil
	}
	return p.fn(ctx, msg)
}
// TestExtractIDFromPath covers the matching, query-string, empty-ID and
// wrong-prefix cases of extractIDFromPath.
func TestExtractIDFromPath(t *testing.T) {
	type scenario struct {
		path, prefix, want string
	}
	scenarios := []scenario{
		{path: "/v1/push/devices/abc", prefix: "/v1/push/devices/", want: "abc"},
		{path: "/v1/push/devices/abc?x=1", prefix: "/v1/push/devices/", want: "abc"},
		{path: "/v1/push/devices/", prefix: "/v1/push/devices/", want: ""},
		{path: "/v1/other/abc", prefix: "/v1/push/devices/", want: ""},
	}
	for idx := range scenarios {
		sc := scenarios[idx]
		got := extractIDFromPath(sc.path, sc.prefix)
		if got == sc.want {
			continue
		}
		t.Errorf("extractIDFromPath(%q, %q) = %q, want %q", sc.path, sc.prefix, got, sc.want)
	}
}

View File

@ -0,0 +1,150 @@
// Package push provides HTTP handlers for managing push-notification
// device registrations and sending pushes.
//
// Endpoints:
//
// GET /v1/push/devices — list caller's registered devices (tokens omitted)
// POST /v1/push/devices — register / update a device
// DELETE /v1/push/devices/{id} — unregister a device
// POST /v1/push/send — send a push to a user (admin/internal scope)
//
// Device tokens are stored AES-256-GCM-encrypted in RQLite via the
// pkg/push.RqliteDeviceStore. Tokens are NEVER returned by any endpoint —
// the GET endpoint omits the token field for safety.
package push
import (
"context"
"encoding/json"
"net/http"
"github.com/DeBrosOfficial/network/pkg/gateway/auth"
"github.com/DeBrosOfficial/network/pkg/gateway/ctxkeys"
"github.com/DeBrosOfficial/network/pkg/logging"
"github.com/DeBrosOfficial/network/pkg/push"
)
// Handlers serves the /v1/push/* HTTP endpoints. Construct via NewHandlers;
// it's safe for concurrent use.
type Handlers struct {
// dispatcher fans a push out to a user's devices; when nil, SendHandler
// answers 503.
dispatcher *push.PushDispatcher
// store persists device registrations; when nil, the device endpoints
// (register, list, delete) answer 503.
store push.PushDeviceStore
// logger receives warnings about store and dispatcher failures.
logger *logging.ColoredLogger
}
// NewHandlers wires a dispatcher, a device store and a logger into a
// Handlers. The dispatcher and the store may each be nil, in which case the
// endpoints that depend on the missing piece answer 503 Service Unavailable.
func NewHandlers(dispatcher *push.PushDispatcher, store push.PushDeviceStore, logger *logging.ColoredLogger) *Handlers {
	h := new(Handlers)
	h.dispatcher = dispatcher
	h.store = store
	h.logger = logger
	return h
}
// RegisterDeviceRequest is the body of POST /v1/push/devices.
//
// `device_id` is an app-supplied stable identifier (e.g. the OS-assigned
// device UUID). Combined with (namespace, user_id) it uniquely identifies
// the registration; re-posting with the same device_id updates the token.
//
// `token` is provider-specific:
// - ntfy: the topic path the device subscribes to (e.g. "ns/myapp/user-1")
// - expo: an ExponentPushToken[...]
// - apns: a hex APNs device token (future)
//
// RegisterDeviceHandler trims surrounding whitespace from device_id,
// provider and token before validating them; platform and app_version are
// stored as sent.
type RegisterDeviceRequest struct {
DeviceID string `json:"device_id"` // required
Provider string `json:"provider"` // "ntfy" | "expo" | "apns"
Token string `json:"token"` // required; at most MaxTokenBytes bytes
Platform string `json:"platform,omitempty"` // "ios" | "android" | "web"
AppVersion string `json:"app_version,omitempty"`
}
// RegisterDeviceResponse is the response body of POST /v1/push/devices.
// It does not echo the stored device or its row ID.
type RegisterDeviceResponse struct {
// Status is "ok" on a successful registration.
Status string `json:"status"`
}
// PushDeviceView is the safe (token-omitting) representation returned
// by GET /v1/push/devices.
type PushDeviceView struct {
// ID is the store's row identifier; DeleteDeviceHandler matches the path
// segment of DELETE /v1/push/devices/{id} against it.
ID string `json:"id"`
// DeviceID is the app-supplied identifier sent at registration.
DeviceID string `json:"device_id"`
Provider string `json:"provider"`
Platform string `json:"platform,omitempty"`
AppVersion string `json:"app_version,omitempty"`
// NOTE(review): CreatedAt/UpdatedAt are populated by the store; presumably
// Unix seconds like LastSeen — confirm against the store implementation.
CreatedAt int64 `json:"created_at"`
UpdatedAt int64 `json:"updated_at"`
// LastSeen is set to the current Unix time (seconds) on registration.
LastSeen int64 `json:"last_seen,omitempty"`
}
// SendRequest is the body of POST /v1/push/send.
//
// The dispatcher fans out to all of `user_id`'s registered devices in
// the caller's namespace. Auth scope: see SendHandler — currently
// requires the caller to act on behalf of their own namespace; finer
// per-user authorization is the app's responsibility.
type SendRequest struct {
// UserID is required; surrounding whitespace is trimmed by SendHandler.
UserID string `json:"user_id"`
Title string `json:"title"`
Body string `json:"body"`
Channel string `json:"channel,omitempty"`
// Only the exact value "high" selects high priority; every other value,
// including unknown strings, is sent as normal priority.
Priority string `json:"priority,omitempty"` // "high" | "normal" | "" (default)
Badge int `json:"badge,omitempty"`
Sound string `json:"sound,omitempty"`
// Data is passed through to the push message unchanged.
Data map[string]interface{} `json:"data,omitempty"`
}
// SendResponse is the response body of POST /v1/push/send.
type SendResponse struct {
// Status is "ok" when the dispatcher reported no error.
Status string `json:"status"`
}
// resolveNamespace reads the namespace stored under
// ctxkeys.NamespaceOverride in the request context. It returns the empty
// string when the value is absent or is not a string.
func resolveNamespace(r *http.Request) string {
	namespace, isString := r.Context().Value(ctxkeys.NamespaceOverride).(string)
	if !isString {
		return ""
	}
	return namespace
}
// resolveCallerUserID returns the Sub field of the JWT claims stored under
// ctxkeys.JWT in the request context (typically the caller's wallet). It
// returns the empty string when no claims are present, e.g. for requests
// authenticated by API key only.
func resolveCallerUserID(r *http.Request) string {
	claims, found := r.Context().Value(ctxkeys.JWT).(*auth.JWTClaims)
	if !found || claims == nil {
		return ""
	}
	return claims.Sub
}
// writeError writes a JSON error response of the form {"error": message}
// with the given HTTP status code and a Content-Type of application/json.
func writeError(w http.ResponseWriter, code int, message string) {
	payload := map[string]string{"error": message}

	// Headers must be set before WriteHeader commits the response.
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(code)

	// The status line is already sent, so an encode failure cannot be
	// reported to the client; it is deliberately ignored.
	enc := json.NewEncoder(w)
	_ = enc.Encode(payload)
}
// writeJSON writes v as a JSON response body with the given HTTP status code
// and a Content-Type of application/json.
func writeJSON(w http.ResponseWriter, code int, v interface{}) {
	hdr := w.Header()
	hdr.Set("Content-Type", "application/json")
	w.WriteHeader(code)

	// The status line is already sent, so an encode failure cannot be
	// reported to the client; it is deliberately ignored.
	enc := json.NewEncoder(w)
	_ = enc.Encode(v)
}
// pickPriority maps the wire-format priority string to the typed enum.
// Only "high" selects push.PriorityHigh; "normal", the empty string and any
// unrecognized value all map to push.PriorityNormal.
func pickPriority(s string) push.PushPriority {
	if s == "high" {
		return push.PriorityHigh
	}
	return push.PriorityNormal
}
// boundCtx returns the request's own context, unchanged. It exists as a
// single seam where future request-scoped wrapping (rate-limit context etc.)
// can be added.
func boundCtx(r *http.Request) context.Context {
	ctx := r.Context()
	return ctx
}

View File

@ -42,6 +42,11 @@ func (h *WebRTCHandlers) CredentialsHandler(w http.ResponseWriter, r *http.Reque
fmt.Sprintf("turns:%s:5349", h.turnDomain),
)
}
// Stealth: TURNS via the SNI router on :443. Looks like ordinary HTTPS
// to a passive observer / DPI; usable in restricted regions.
if h.stealthCDNDomain != "" {
uris = append(uris, fmt.Sprintf("turns:%s:443", h.stealthCDNDomain))
}
h.logger.ComponentInfo(logging.ComponentGeneral, "Issued TURN credentials",
zap.String("namespace", ns),

View File

@ -17,10 +17,21 @@ type WebRTCHandlers struct {
turnDomain string // TURN server domain for building URIs
turnSecret string // HMAC-SHA1 shared secret for TURN credential generation
// stealthCDNDomain, when non-empty, causes CredentialsHandler to also
// advertise turns:<stealthCDNDomain>:443 — the stealth TURN URI served
// via the in-house SNI router. See pkg/sniproxy.
stealthCDNDomain string
// proxyWebSocket is injected from the gateway to reuse its WebSocket proxy logic
proxyWebSocket func(w http.ResponseWriter, r *http.Request, targetHost string) bool
}
// SetStealthCDNDomain enables the stealth TURN URI in CredentialsHandler.
// Pass empty string to disable. Safe to call before serving begins.
//
// The field is assigned without any synchronization, so this should not be
// called concurrently with request handling.
func (h *WebRTCHandlers) SetStealthCDNDomain(domain string) {
h.stealthCDNDomain = domain
}
// NewWebRTCHandlers creates a new WebRTCHandlers instance.
func NewWebRTCHandlers(
logger *logging.ColoredLogger,

View File

@ -12,6 +12,23 @@ import (
// It closes the serverless engine, network client, database connections,
// Olric cache client, and IPFS client in sequence.
func (g *Gateway) Close() {
// Flush PubSub aggregator buffers before tearing down the engine.
// Pending events are dispatched via the invoker which still needs the
// engine to be alive, so this MUST happen before the engine close.
// Aggregator state is local to this node — events not flushed here are
// lost (intended trade-off for high-frequency lossy streams).
if g.pubsubDispatcher != nil {
if agg := g.pubsubDispatcher.Aggregator(); agg != nil {
// 5s budget — same as the engine close timeout below.
// In-flight flushes call back into the invoker which still
// needs the engine to be alive.
if !agg.Shutdown(5 * time.Second) {
g.logger.ComponentWarn(logging.ComponentGeneral,
"PubSub aggregator shutdown timed out; some buffered events may be lost")
}
}
}
// Close serverless engine first
if g.serverlessEngine != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)

View File

@ -643,6 +643,9 @@ func requiresNamespaceOwnership(p string) bool {
if strings.HasPrefix(p, "/v1/webrtc/") {
return true
}
if strings.HasPrefix(p, "/v1/push/") {
return true
}
return false
}

View File

@ -119,10 +119,30 @@ func (g *Gateway) Routes() http.Handler {
if g.pubsubHandlers != nil {
mux.HandleFunc("/v1/pubsub/ws", g.pubsubHandlers.WebsocketHandler)
mux.HandleFunc("/v1/pubsub/publish", g.pubsubHandlers.PublishHandler)
mux.HandleFunc("/v1/pubsub/publish-batch", g.pubsubHandlers.PublishBatchHandler)
mux.HandleFunc("/v1/pubsub/topics", g.pubsubHandlers.TopicsHandler)
mux.HandleFunc("/v1/pubsub/presence", g.pubsubHandlers.PresenceHandler)
}
// push notifications
if g.pushHandlers != nil {
// GET + POST share the path; the handler dispatches by method.
mux.HandleFunc("/v1/push/devices", func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
g.pushHandlers.ListDevicesHandler(w, r)
case http.MethodPost:
g.pushHandlers.RegisterDeviceHandler(w, r)
default:
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
}
})
// DELETE /v1/push/devices/{id} — uses path-prefix routing because
// net/http mux doesn't extract path params; the handler parses {id}.
mux.HandleFunc("/v1/push/devices/", g.pushHandlers.DeleteDeviceHandler)
mux.HandleFunc("/v1/push/send", g.pushHandlers.SendHandler)
}
// operator node management (wallet JWT auth via middleware)
if g.operatorHandler != nil {
mux.HandleFunc("/v1/operator/invite", g.operatorHandler.HandleInvite)

View File

@ -57,6 +57,7 @@ const (
ComponentGateway Component = "GATEWAY"
ComponentSFU Component = "SFU"
ComponentTURN Component = "TURN"
ComponentSNI Component = "SNI"
)
// getComponentColor returns the color for a specific component

View File

@ -29,6 +29,18 @@ func (a *ClientAdapter) Publish(ctx context.Context, topic string, data []byte)
return a.manager.Publish(ctx, topic, data)
}
// PublishBatch publishes multiple messages in parallel.
// It forwards ctx, msgs and opts unchanged to the underlying manager and
// returns the manager's error as-is.
// See Manager.PublishBatch for semantics.
func (a *ClientAdapter) PublishBatch(ctx context.Context, msgs []TopicMessage, opts PublishBatchOptions) error {
return a.manager.PublishBatch(ctx, msgs, opts)
}
// PublishSame sends the same payload to every topic in parallel.
// It forwards ctx, topics, data and opts unchanged to the underlying manager
// and returns the manager's error as-is.
// See Manager.PublishSame for semantics.
func (a *ClientAdapter) PublishSame(ctx context.Context, topics []string, data []byte, opts PublishBatchOptions) error {
return a.manager.PublishSame(ctx, topics, data, opts)
}
// Unsubscribe unsubscribes from a topic
func (a *ClientAdapter) Unsubscribe(ctx context.Context, topic string) error {
return a.manager.Unsubscribe(ctx, topic)

View File

@ -3,9 +3,56 @@ package pubsub
import (
	"context"
	"fmt"
	"sort"
	"strings"
	"sync"
	"time"

	"golang.org/x/sync/errgroup"
)
// defaultBatchConcurrency is the default cap on in-flight publishes within a single batch.
// PublishBatch uses it when PublishBatchOptions.MaxConcurrency is zero or negative.
const defaultBatchConcurrency = 32
// MaxBatchSize is the maximum number of messages allowed per PublishBatch call.
// The HTTP handler enforces this; the Manager itself does not, so internal callers
// (e.g. the SDK) can pass larger batches if they accept the responsibility.
const MaxBatchSize = 100
// TopicMessage is one entry in a batch publish: a payload and the topic it
// is published to.
type TopicMessage struct {
// Topic is the destination topic name.
Topic string
// Data is the raw payload handed to Publish.
Data []byte
}
// PublishBatchOptions controls batch publish behavior.
// The zero value selects fail-fast mode with the default concurrency cap.
type PublishBatchOptions struct {
// BestEffort, if true, attempts every publish even when some fail and returns
// a *BatchError summarizing per-topic failures. Default (false) is fail-fast:
// the first failure cancels remaining in-flight publishes and returns that error.
BestEffort bool
// MaxConcurrency caps the number of in-flight publishes within this batch.
// Zero or a negative value means use defaultBatchConcurrency. The effective
// cap is never larger than the number of messages in the batch.
MaxConcurrency int
}
// BatchError aggregates per-topic errors returned when PublishBatch is called
// with BestEffort=true and at least one publish failed.
type BatchError struct {
	Errors map[string]error // topic -> error
}

// Error implements the error interface. It reports how many topics failed
// followed by their names.
//
// The names are sorted so that the message is deterministic: Go randomizes
// map iteration order, which would otherwise make the same failure set
// produce different strings from call to call (noisy logs, flaky
// comparisons).
func (e *BatchError) Error() string {
	if len(e.Errors) == 0 {
		return "batch publish: no errors"
	}
	names := make([]string, 0, len(e.Errors))
	for topic := range e.Errors {
		names = append(names, topic)
	}
	sort.Strings(names)
	return fmt.Sprintf("batch publish: %d topic(s) failed: %s", len(e.Errors), strings.Join(names, ", "))
}
// Publish publishes a message to a topic
func (m *Manager) Publish(ctx context.Context, topic string, data []byte) error {
if m.pubsub == nil {
@ -58,3 +105,90 @@ func (m *Manager) Publish(ctx context.Context, topic string, data []byte) error
return nil
}
// PublishBatch publishes every message in msgs concurrently, one Publish call
// per entry.
//
// Modes:
//   - fail-fast (default): the first publish error cancels the remaining work
//     through the errgroup context and is returned.
//   - best-effort (opts.BestEffort): every publish is attempted; if any
//     failed, a *BatchError keyed by topic is returned. When a topic appears
//     more than once, only one of its errors is kept.
//
// At most opts.MaxConcurrency publishes run at once (defaultBatchConcurrency
// when it is <= 0), never more than len(msgs). One goroutine is started per
// message; the cap is enforced by a semaphore each goroutine acquires first.
// An empty msgs slice is a no-op and returns nil.
func (m *Manager) PublishBatch(ctx context.Context, msgs []TopicMessage, opts PublishBatchOptions) error {
	total := len(msgs)
	if total == 0 {
		return nil
	}
	if m.pubsub == nil {
		return fmt.Errorf("pubsub not initialized")
	}

	limit := defaultBatchConcurrency
	if opts.MaxConcurrency > 0 {
		limit = opts.MaxConcurrency
	}
	if limit > total {
		limit = total
	}
	// slots is a counting semaphore: a send acquires, a receive releases.
	slots := make(chan struct{}, limit)

	if opts.BestEffort {
		var (
			pending  sync.WaitGroup
			mu       sync.Mutex
			failures = map[string]error{}
		)
		record := func(topic string, err error) {
			mu.Lock()
			failures[topic] = err
			mu.Unlock()
		}

		pending.Add(total)
		for i := range msgs {
			go func(entry TopicMessage) {
				defer pending.Done()
				select {
				case slots <- struct{}{}:
				case <-ctx.Done():
					// Never got a slot: report the cancellation for this topic.
					record(entry.Topic, ctx.Err())
					return
				}
				defer func() { <-slots }()
				if err := m.Publish(ctx, entry.Topic, entry.Data); err != nil {
					record(entry.Topic, err)
				}
			}(msgs[i])
		}
		pending.Wait()

		if len(failures) == 0 {
			return nil
		}
		return &BatchError{Errors: failures}
	}

	// Fail-fast path: groupCtx is cancelled as soon as any publish errors.
	group, groupCtx := errgroup.WithContext(ctx)
	for i := range msgs {
		entry := msgs[i]
		group.Go(func() error {
			select {
			case slots <- struct{}{}:
			case <-groupCtx.Done():
				return groupCtx.Err()
			}
			defer func() { <-slots }()
			return m.Publish(groupCtx, entry.Topic, entry.Data)
		})
	}
	return group.Wait()
}
// PublishSame publishes one payload to every topic in topics by building a
// batch and delegating to PublishBatch with the same opts. All batch entries
// share the same data slice (it is not copied). An empty topics slice is a
// no-op and returns nil.
func (m *Manager) PublishSame(ctx context.Context, topics []string, data []byte, opts PublishBatchOptions) error {
	if len(topics) == 0 {
		return nil
	}
	batch := make([]TopicMessage, 0, len(topics))
	for _, topic := range topics {
		batch = append(batch, TopicMessage{Topic: topic, Data: data})
	}
	return m.PublishBatch(ctx, batch, opts)
}

View File

@ -0,0 +1,196 @@
package pubsub
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
)
// TestPublishBatch_empty_slice_returns_nil checks that both a nil and an
// empty, non-nil batch are treated as a no-op.
func TestPublishBatch_empty_slice_returns_nil(t *testing.T) {
	mgr, cleanup := createTestManager(t, "test-ns")
	defer cleanup()

	for _, batch := range [][]TopicMessage{nil, {}} {
		err := mgr.PublishBatch(context.Background(), batch, PublishBatchOptions{})
		if err != nil {
			t.Fatalf("expected nil error for empty slice, got: %v", err)
		}
	}
}
// TestPublishBatch_happy_path publishes three messages to distinct topics
// with default options and expects no error.
func TestPublishBatch_happy_path(t *testing.T) {
	mgr, cleanup := createTestManager(t, "test-ns")
	defer cleanup()

	batch := []TopicMessage{
		{Topic: "a", Data: []byte("data-a")},
		{Topic: "b", Data: []byte("data-b")},
		{Topic: "c", Data: []byte("data-c")},
	}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	err := mgr.PublishBatch(ctx, batch, PublishBatchOptions{})
	if err != nil {
		t.Fatalf("PublishBatch failed: %v", err)
	}
}
// TestPublishSame_uses_same_payload sends one payload to three topics and
// expects PublishSame to succeed.
func TestPublishSame_uses_same_payload(t *testing.T) {
	mgr, cleanup := createTestManager(t, "test-ns")
	defer cleanup()

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	payload := []byte("shared")
	destinations := []string{"x", "y", "z"}

	err := mgr.PublishSame(ctx, destinations, payload, PublishBatchOptions{})
	if err != nil {
		t.Fatalf("PublishSame failed: %v", err)
	}
}
// TestPublishSame_empty_returns_nil checks that a nil topic list is a no-op.
func TestPublishSame_empty_returns_nil(t *testing.T) {
	mgr, cleanup := createTestManager(t, "test-ns")
	defer cleanup()

	var noTopics []string
	err := mgr.PublishSame(context.Background(), noTopics, []byte("x"), PublishBatchOptions{})
	if err != nil {
		t.Fatalf("expected nil for empty topics, got: %v", err)
	}
}
// TestPublishBatch_context_cancel_returns_error runs a 50-message batch with
// an already-cancelled context and requires some error to come back. A
// non-Canceled error is only logged, not treated as a failure.
func TestPublishBatch_context_cancel_returns_error(t *testing.T) {
	mgr, cleanup := createTestManager(t, "test-ns")
	defer cleanup()

	const batchSize = 50
	batch := make([]TopicMessage, 0, batchSize)
	for i := 0; i < batchSize; i++ {
		batch = append(batch, TopicMessage{Topic: fmt.Sprintf("topic-%d", i), Data: []byte("d")})
	}

	ctx, cancel := context.WithCancel(context.Background())
	cancel() // cancelled before the batch even starts

	err := mgr.PublishBatch(ctx, batch, PublishBatchOptions{})
	switch {
	case err == nil:
		t.Fatal("expected context.Canceled error, got nil")
	case !errors.Is(err, context.Canceled):
		t.Logf("got error (acceptable as long as it's an error): %v", err)
	}
}
// TestPublishBatch_concurrency_limit verifies that a batch larger than
// MaxConcurrency (8 messages, cap of 2) completes without deadlocking.
func TestPublishBatch_concurrency_limit(t *testing.T) {
	// Each Publish in a no-peer test environment waits up to 2s for mesh
	// formation, so the batch is kept small to bound wall time.
	mgr, cleanup := createTestManager(t, "test-ns")
	defer cleanup()

	const batchSize = 8
	batch := make([]TopicMessage, 0, batchSize)
	for i := 0; i < batchSize; i++ {
		batch = append(batch, TopicMessage{Topic: fmt.Sprintf("ct-%d", i), Data: []byte("d")})
	}

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	opts := PublishBatchOptions{MaxConcurrency: 2}
	if err := mgr.PublishBatch(ctx, batch, opts); err != nil {
		t.Fatalf("PublishBatch with low concurrency failed: %v", err)
	}
}
// TestPublishBatch_caps_concurrency_above_msg_count passes a MaxConcurrency
// (100) far above the batch size (3) and expects the publish to succeed,
// exercising the clamp of the concurrency cap to len(msgs).
func TestPublishBatch_caps_concurrency_above_msg_count(t *testing.T) {
	mgr, cleanup := createTestManager(t, "test-ns")
	defer cleanup()

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	batch := []TopicMessage{
		{Topic: "a", Data: []byte("1")},
		{Topic: "b", Data: []byte("2")},
		{Topic: "c", Data: []byte("3")},
	}
	opts := PublishBatchOptions{MaxConcurrency: 100}

	if err := mgr.PublishBatch(ctx, batch, opts); err != nil {
		t.Fatalf("PublishBatch failed: %v", err)
	}
}
// TestBatchError_Error_summarizes checks that the error string is non-empty
// and names every failing topic.
func TestBatchError_Error_summarizes(t *testing.T) {
	failures := map[string]error{
		"topic-a": errors.New("boom"),
		"topic-b": errors.New("kaboom"),
	}
	summary := (&BatchError{Errors: failures}).Error()

	if summary == "" {
		t.Fatal("expected non-empty error string")
	}
	// Should mention both topics.
	mentionsBoth := contains(summary, "topic-a") && contains(summary, "topic-b")
	if !mentionsBoth {
		t.Errorf("error string %q should mention both failing topics", summary)
	}
}
// TestBatchError_Error_empty_map checks that a BatchError with no recorded
// failures still renders a non-empty message.
func TestBatchError_Error_empty_map(t *testing.T) {
	var empty BatchError
	msg := empty.Error()
	if msg == "" {
		t.Fatal("expected non-empty string even for empty map")
	}
}
// contains reports whether substr occurs anywhere in s. It is a tiny
// hand-rolled scan kept so this test file does not need to import strings.
// An empty substr always matches.
func contains(s, substr string) bool {
	width := len(substr)
	// last is the final start offset at which substr could still fit; it is
	// negative (so the loop does not run) when substr is longer than s.
	last := len(s) - width
	for start := 0; start <= last; start++ {
		if s[start:start+width] == substr {
			return true
		}
	}
	return false
}
// TestPublishBatch_concurrent_publishes_thread_safe runs 8 PublishBatch calls
// of 5 messages each in parallel against one manager and fails if any batch
// returns an error. Run with -race to surface data races on internal state.
func TestPublishBatch_concurrent_publishes_thread_safe(t *testing.T) {
	mgr, cleanup := createTestManager(t, "test-ns")
	defer cleanup()

	const goroutines = 8
	const msgsPerGoroutine = 5

	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
	defer cancel()

	var (
		workers  sync.WaitGroup
		failures int64
	)
	workers.Add(goroutines)
	for id := 0; id < goroutines; id++ {
		id := id
		go func() {
			defer workers.Done()

			batch := make([]TopicMessage, 0, msgsPerGoroutine)
			for i := 0; i < msgsPerGoroutine; i++ {
				batch = append(batch, TopicMessage{
					Topic: fmt.Sprintf("g%d-t%d", id, i),
					Data:  []byte("d"),
				})
			}

			err := mgr.PublishBatch(ctx, batch, PublishBatchOptions{})
			if err == nil {
				return
			}
			atomic.AddInt64(&failures, 1)
			t.Logf("goroutine %d failed: %v", id, err)
		}()
	}
	workers.Wait()

	if failed := atomic.LoadInt64(&failures); failed > 0 {
		t.Errorf("%d concurrent batches failed", failed)
	}
}

View File

@ -0,0 +1,172 @@
package push
import (
"context"
"fmt"
"time"
"github.com/DeBrosOfficial/network/pkg/rqlite"
"github.com/DeBrosOfficial/network/pkg/secrets"
"github.com/google/uuid"
"go.uber.org/zap"
)
// SecretsKeyPurpose is the HKDF purpose string for push token encryption.
// It is passed to secrets.DeriveKey together with the cluster secret to
// derive a domain-separated AES key. Changing this value changes the derived
// key, so tokens stored under the old value would no longer decrypt.
const SecretsKeyPurpose = "push-device-tokens"
// RqliteDeviceStore is a PushDeviceStore backed by RQLite + AES-256-GCM
// at-rest encryption of the push token.
// Construct it with NewRqliteDeviceStore so that encKey is populated.
type RqliteDeviceStore struct {
// db is the RQLite client used for every query.
db rqlite.Client
// encKey is passed to secrets.Encrypt / secrets.Decrypt for the token column.
encKey []byte // derived once at construction
// logger is the "push-store" named logger.
logger *zap.Logger
}
// NewRqliteDeviceStore builds a store on top of db. The token encryption key
// is derived once from clusterSecret via secrets.DeriveKey using
// SecretsKeyPurpose; the cluster secret is the same one used for other
// at-rest encryption (see pkg/secrets).
//
// A nil logger is replaced by a no-op logger. An error is returned only when
// key derivation fails.
func NewRqliteDeviceStore(db rqlite.Client, clusterSecret string, logger *zap.Logger) (*RqliteDeviceStore, error) {
	derived, err := secrets.DeriveKey(clusterSecret, SecretsKeyPurpose)
	if err != nil {
		return nil, fmt.Errorf("derive push-device key: %w", err)
	}

	if logger == nil {
		logger = zap.NewNop()
	}

	store := &RqliteDeviceStore{
		db:     db,
		encKey: derived,
		logger: logger.Named("push-store"),
	}
	return store, nil
}
// deviceRow is the scan target for SELECT queries.
// Its fields are declared in the same order as the columns selected by
// ListForUser.
// NOTE(review): how rqlite.Client.Query maps columns onto these fields (by
// position or by name) is not visible here — confirm before reordering.
type deviceRow struct {
ID string
Namespace string
UserID string
DeviceID string
Provider string
// TokenEncrypted is the stored ciphertext; ListForUser passes it to secrets.Decrypt.
TokenEncrypted string
Platform string
AppVersion string
CreatedAt int64
UpdatedAt int64
LastSeen int64
}
// Upsert implements PushDeviceStore. It inserts dev, or updates the existing
// row with the same (namespace, user_id, device_id).
//
// Namespace, UserID, DeviceID and Provider are required; an empty Token
// yields ErrEmptyToken. The token is encrypted with the store key before it
// is written. UpdatedAt is always set to the current Unix time, and
// CreatedAt defaults to it when zero. A fresh UUID is used as the id when
// dev.ID is empty.
//
// On conflict the provider, token, platform, app version, updated_at and
// last_seen columns are overwritten (last_seen with dev.LastSeen as given,
// including zero); the existing id and created_at are kept.
func (s *RqliteDeviceStore) Upsert(ctx context.Context, dev PushDevice) error {
	switch {
	case dev.Namespace == "" || dev.UserID == "" || dev.DeviceID == "":
		return fmt.Errorf("namespace, user_id, device_id required")
	case dev.Provider == "":
		return fmt.Errorf("provider required")
	case dev.Token == "":
		return ErrEmptyToken
	}

	ciphertext, err := secrets.Encrypt(dev.Token, s.encKey)
	if err != nil {
		return fmt.Errorf("encrypt token: %w", err)
	}

	stamp := time.Now().Unix()
	dev.UpdatedAt = stamp
	if dev.CreatedAt == 0 {
		dev.CreatedAt = stamp
	}

	rowID := dev.ID
	if rowID == "" {
		rowID = uuid.New().String()
	}

	// SQLite UPSERT keyed on the migration's UNIQUE(namespace, user_id,
	// device_id) constraint. The statement text is kept verbatim.
	stmt := `
INSERT INTO push_devices
(id, namespace, user_id, device_id, provider, token_encrypted,
platform, app_version, created_at, updated_at, last_seen)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(namespace, user_id, device_id) DO UPDATE SET
provider = excluded.provider,
token_encrypted = excluded.token_encrypted,
platform = excluded.platform,
app_version = excluded.app_version,
updated_at = excluded.updated_at,
last_seen = excluded.last_seen
`
	if _, err = s.db.Exec(ctx, stmt,
		rowID, dev.Namespace, dev.UserID, dev.DeviceID, dev.Provider, ciphertext,
		dev.Platform, dev.AppVer, dev.CreatedAt, dev.UpdatedAt, dev.LastSeen,
	); err != nil {
		return fmt.Errorf("upsert push device: %w", err)
	}
	return nil
}
// Delete implements PushDeviceStore. It removes the device row with the given
// id, scoped to namespace so one namespace cannot delete another's devices.
//
// It returns an error when namespace or id is empty, when the statement
// fails, when the affected-row count cannot be read, or when no row matched
// (reported as "push device not found").
func (s *RqliteDeviceStore) Delete(ctx context.Context, namespace, id string) error {
	if namespace == "" || id == "" {
		return fmt.Errorf("namespace and id required")
	}
	query := `DELETE FROM push_devices WHERE id = ? AND namespace = ?`
	res, err := s.db.Exec(ctx, query, id, namespace)
	if err != nil {
		return fmt.Errorf("delete push device: %w", err)
	}
	// Surface a RowsAffected failure instead of discarding it: with the error
	// ignored, n would be 0 and a possibly successful delete would be
	// misreported as "not found".
	n, err := res.RowsAffected()
	if err != nil {
		return fmt.Errorf("delete push device: rows affected: %w", err)
	}
	if n == 0 {
		return fmt.Errorf("push device not found: %s", id)
	}
	return nil
}
// ListForUser implements PushDeviceStore. It returns every device registered
// for userID within namespace, with the push token decrypted.
// Callers MUST treat the returned tokens as sensitive.
//
// An empty namespace or userID yields (nil, nil) without querying. Rows whose
// token cannot be decrypted are logged and skipped rather than failing the
// whole call.
func (s *RqliteDeviceStore) ListForUser(ctx context.Context, namespace, userID string) ([]PushDevice, error) {
	if namespace == "" || userID == "" {
		return nil, nil
	}

	// COALESCE maps NULL optional columns to zero values for scanning.
	// The statement text is kept verbatim.
	stmt := `
SELECT id, namespace, user_id, device_id, provider, token_encrypted,
COALESCE(platform, ''), COALESCE(app_version, ''),
created_at, updated_at, COALESCE(last_seen, 0)
FROM push_devices
WHERE namespace = ? AND user_id = ?
`
	var records []deviceRow
	if err := s.db.Query(ctx, &records, stmt, namespace, userID); err != nil {
		return nil, fmt.Errorf("query push devices: %w", err)
	}

	devices := make([]PushDevice, 0, len(records))
	for i := range records {
		rec := records[i]
		plain, decErr := secrets.Decrypt(rec.TokenEncrypted, s.encKey)
		if decErr != nil {
			s.logger.Warn("failed to decrypt push token; skipping device",
				zap.String("device_id", rec.DeviceID),
				zap.Error(decErr))
			continue
		}
		devices = append(devices, PushDevice{
			ID:        rec.ID,
			Namespace: rec.Namespace,
			UserID:    rec.UserID,
			DeviceID:  rec.DeviceID,
			Provider:  rec.Provider,
			Token:     plain,
			Platform:  rec.Platform,
			AppVer:    rec.AppVersion,
			CreatedAt: rec.CreatedAt,
			UpdatedAt: rec.UpdatedAt,
			LastSeen:  rec.LastSeen,
		})
	}
	return devices, nil
}

View File

@ -0,0 +1,97 @@
package push
import (
"context"
"fmt"
"sync"
"go.uber.org/zap"
)
// PushDispatcher routes push messages to the matching provider for each
// of a user's registered devices.
// Create one with New; the providers map is guarded by mu.
type PushDispatcher struct {
// mu guards providers.
mu sync.RWMutex
// providers maps a provider name (PushProvider.Name) to its implementation.
providers map[string]PushProvider
// devices is the store SendToUser queries for a user's registered devices.
devices PushDeviceStore
// logger is the "push" named logger.
logger *zap.Logger
}
// New returns a dispatcher that looks devices up in the given store. It
// starts with no providers; call Register for each provider before sending.
// A nil logger is replaced by a no-op logger.
func New(devices PushDeviceStore, logger *zap.Logger) *PushDispatcher {
	base := logger
	if base == nil {
		base = zap.NewNop()
	}

	d := &PushDispatcher{
		devices: devices,
		logger:  base.Named("push"),
	}
	d.providers = make(map[string]PushProvider)
	return d
}
// Register makes a provider available to dispatch. Calling Register with
// the same name twice replaces the previous provider — useful in tests.
//
// The provider is stored under p.Name(), and the write happens under the
// dispatcher's write lock.
func (d *PushDispatcher) Register(p PushProvider) {
d.mu.Lock()
defer d.mu.Unlock()
d.providers[p.Name()] = p
}
// Provider returns the registered provider by name, or nil when no provider
// is registered under that name. The lookup takes the read lock.
func (d *PushDispatcher) Provider(name string) PushProvider {
d.mu.RLock()
defer d.mu.RUnlock()
return d.providers[name]
}
// SendToUser delivers msg to every device registered for userID in
// namespace, one device at a time, using the provider named on each device.
//
// A failure on one device is logged and does not stop delivery to the rest.
// A device whose provider is not registered is skipped and counted as a
// failure wrapping ErrUnknownProvider. The first failure encountered (in
// device order) is returned so callers can detect partial failure.
//
// A store lookup error is returned immediately, wrapped. A user with no
// registered devices is normal and yields nil.
func (d *PushDispatcher) SendToUser(
	ctx context.Context,
	namespace, userID string,
	msg PushMessage,
) error {
	devices, err := d.devices.ListForUser(ctx, namespace, userID)
	if err != nil {
		return fmt.Errorf("list devices: %w", err)
	}

	var first error
	remember := func(e error) {
		if first == nil {
			first = e
		}
	}

	for _, device := range devices {
		// Hold the read lock only for the map lookup, not during the send.
		d.mu.RLock()
		provider, registered := d.providers[device.Provider]
		d.mu.RUnlock()

		if !registered {
			d.logger.Warn("push: dropping device with unregistered provider",
				zap.String("provider", device.Provider),
				zap.String("device_id", device.DeviceID),
			)
			remember(fmt.Errorf("%w: %s", ErrUnknownProvider, device.Provider))
			continue
		}

		// Copy the message so each device gets its own token.
		outgoing := msg
		outgoing.DeviceToken = device.Token

		sendErr := provider.Send(ctx, outgoing)
		if sendErr == nil {
			continue
		}
		d.logger.Warn("push: provider send failed",
			zap.String("provider", device.Provider),
			zap.String("device_id", device.DeviceID),
			zap.Error(sendErr),
		)
		remember(sendErr)
	}
	return first
}

View File

@ -0,0 +1,149 @@
package push
import (
"context"
"errors"
"sync/atomic"
"testing"
"go.uber.org/zap"
)
// fakeProvider is a test double for PushProvider. It counts Send calls,
// remembers the token of the most recent one, and returns the configured err.
type fakeProvider struct {
	name      string
	sent      int32 // incremented atomically by Send
	lastToken string
	err       error
}

// Name returns the configured provider name.
func (f *fakeProvider) Name() string {
	return f.name
}

// Send records the call and returns the configured error (nil by default).
func (f *fakeProvider) Send(ctx context.Context, msg PushMessage) error {
	f.lastToken = msg.DeviceToken
	atomic.AddInt32(&f.sent, 1)
	return f.err
}
// fakeStore is an in-memory PushDeviceStore for tests. When err is set,
// Upsert and ListForUser fail with it.
type fakeStore struct {
	devices []PushDevice
	err     error
}

// Upsert appends dev (no de-duplication) unless err is set.
func (s *fakeStore) Upsert(ctx context.Context, dev PushDevice) error {
	if s.err == nil {
		s.devices = append(s.devices, dev)
	}
	return s.err
}

// Delete is a no-op that always succeeds.
func (s *fakeStore) Delete(ctx context.Context, ns, id string) error {
	return nil
}

// ListForUser returns the stored devices matching ns and userID. The result
// is a non-nil, possibly empty slice.
func (s *fakeStore) ListForUser(ctx context.Context, ns, userID string) ([]PushDevice, error) {
	if s.err != nil {
		return nil, s.err
	}
	matches := []PushDevice{}
	for i := range s.devices {
		if dev := s.devices[i]; dev.Namespace == ns && dev.UserID == userID {
			matches = append(matches, dev)
		}
	}
	return matches, nil
}
// TestSendToUser_no_devices_returns_nil checks that sending to a user with
// no registered devices is not an error.
func TestSendToUser_no_devices_returns_nil(t *testing.T) {
	dispatcher := New(&fakeStore{}, zap.NewNop())

	err := dispatcher.SendToUser(context.Background(), "ns", "u", PushMessage{Title: "x"})
	if err != nil {
		t.Fatalf("expected nil for no devices, got: %v", err)
	}
}
func TestSendToUser_routes_to_correct_provider(t *testing.T) {
store := &fakeStore{devices: []PushDevice{
{Namespace: "ns", UserID: "u", Provider: "ntfy", Token: "ntfy-tok"},
{Namespace: "ns", UserID: "u", Provider: "expo", Token: "expo-tok"},
}}
ntfy := &fakeProvider{name: "ntfy"}
expo := &fakeProvider{name: "expo"}
d := New(store, zap.NewNop())
d.Register(ntfy)
d.Register(expo)
if err := d.SendToUser(context.Background(), "ns", "u", PushMessage{Title: "hi"}); err != nil {
t.Fatalf("SendToUser: %v", err)
}
if atomic.LoadInt32(&ntfy.sent) != 1 || ntfy.lastToken != "ntfy-tok" {
t.Errorf("ntfy provider not called correctly: sent=%d token=%s", ntfy.sent, ntfy.lastToken)
}
if atomic.LoadInt32(&expo.sent) != 1 || expo.lastToken != "expo-tok" {
t.Errorf("expo provider not called correctly: sent=%d token=%s", expo.sent, expo.lastToken)
}
}
func TestSendToUser_unknown_provider_returns_error_continues(t *testing.T) {
store := &fakeStore{devices: []PushDevice{
{Namespace: "ns", UserID: "u", Provider: "ghost", Token: "tok"},
{Namespace: "ns", UserID: "u", Provider: "ntfy", Token: "real"},
}}
ntfy := &fakeProvider{name: "ntfy"}
d := New(store, zap.NewNop())
d.Register(ntfy)
err := d.SendToUser(context.Background(), "ns", "u", PushMessage{})
if err == nil {
t.Fatal("expected error for unknown provider")
}
if !errors.Is(err, ErrUnknownProvider) {
t.Errorf("expected ErrUnknownProvider, got %v", err)
}
// ntfy should still have been called.
if atomic.LoadInt32(&ntfy.sent) != 1 {
t.Error("ntfy should have been called for the second device")
}
}
func TestSendToUser_provider_failure_returned_but_other_devices_still_processed(t *testing.T) {
store := &fakeStore{devices: []PushDevice{
{Namespace: "ns", UserID: "u", Provider: "expo", Token: "tok-1"},
{Namespace: "ns", UserID: "u", Provider: "ntfy", Token: "tok-2"},
}}
expoErr := errors.New("expo down")
expo := &fakeProvider{name: "expo", err: expoErr}
ntfy := &fakeProvider{name: "ntfy"}
d := New(store, zap.NewNop())
d.Register(expo)
d.Register(ntfy)
err := d.SendToUser(context.Background(), "ns", "u", PushMessage{})
if !errors.Is(err, expoErr) {
t.Errorf("expected expo error, got %v", err)
}
if atomic.LoadInt32(&ntfy.sent) != 1 {
t.Error("ntfy should have been called even though expo failed")
}
}
// TestSendToUser_store_error_propagated checks that a device-store failure
// is surfaced (wrapped) by SendToUser.
func TestSendToUser_store_error_propagated(t *testing.T) {
	storeErr := errors.New("store boom")
	dispatcher := New(&fakeStore{err: storeErr}, zap.NewNop())

	err := dispatcher.SendToUser(context.Background(), "ns", "u", PushMessage{})
	if propagated := err != nil && errors.Is(err, storeErr); !propagated {
		t.Errorf("expected store error, got %v", err)
	}
}
func TestRegister_replaces_existing_provider(t *testing.T) {
d := New(&fakeStore{}, zap.NewNop())
a := &fakeProvider{name: "ntfy"}
b := &fakeProvider{name: "ntfy"}
d.Register(a)
d.Register(b)
if d.Provider("ntfy") != b {
t.Error("expected second Register to replace the first")
}
}

View File

@ -0,0 +1,160 @@
// Package expo wraps the Expo push relay as a push.PushProvider.
//
// This is a thin port of the legacy gateway.PushNotificationService —
// behaviour preserved, surface adapted to the provider abstraction.
//
// Long term Expo is intended to be replaced with direct APNs (iOS) +
// ntfy (Android). This provider exists so the gateway can keep using
// Expo while the migration happens.
package expo
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
"github.com/DeBrosOfficial/network/pkg/push"
"go.uber.org/zap"
)
// expoAPIURL is the Expo push endpoint that Send POSTs each message to.
const expoAPIURL = "https://exp.host/--/api/v2/push/send"
// Config holds Expo provider settings. The zero value is usable: no access
// token and the default timeout.
type Config struct {
// AccessToken is an optional Expo access token. When set, it's sent
// as a Bearer token, which Expo uses for higher-priority delivery
// and to attribute the send to your account.
AccessToken string
// Timeout bounds each Send call. Zero or a negative value selects 10
// seconds (matching the previous PushNotificationService default).
Timeout time.Duration
}
// Provider is the Expo push.PushProvider implementation.
// Create one with New.
type Provider struct {
// accessToken, when non-empty, is sent as "Authorization: Bearer <token>".
accessToken string
// httpClient performs the POST; its Timeout comes from Config.Timeout.
httpClient *http.Client
// logger is the "expo" named logger.
logger *zap.Logger
}
// New builds an Expo Provider from cfg. A nil logger is replaced by a no-op
// logger, and a zero or negative cfg.Timeout falls back to 10 seconds. Each
// Provider gets its own http.Client.
func New(cfg Config, logger *zap.Logger) *Provider {
	const fallbackTimeout = 10 * time.Second

	if logger == nil {
		logger = zap.NewNop()
	}

	client := &http.Client{Timeout: fallbackTimeout}
	if cfg.Timeout > 0 {
		client.Timeout = cfg.Timeout
	}

	return &Provider{
		accessToken: cfg.AccessToken,
		httpClient:  client,
		logger:      logger.Named("expo"),
	}
}
// Name implements push.PushProvider. It always returns "expo".
func (p *Provider) Name() string { return "expo" }
// expoMessage matches the wire format Expo expects.
// Every field except To is omitted from the JSON when it holds its zero value.
type expoMessage struct {
// To is the recipient's Expo push token.
To string `json:"to"`
Title string `json:"title,omitempty"`
Body string `json:"body,omitempty"`
Data map[string]interface{} `json:"data,omitempty"`
Sound string `json:"sound,omitempty"`
Badge int `json:"badge,omitempty"`
// Priority is set by Send to "default" or "high".
Priority string `json:"priority,omitempty"`
MutableContent bool `json:"mutableContent,omitempty"`
ChannelID string `json:"channelId,omitempty"`
}
// expoTicket is the per-message response.
// Send treats any non-empty Status other than "ok" as a failure and includes
// Message in the returned error.
type expoTicket struct {
Status string `json:"status"`
Message string `json:"message,omitempty"`
}
// expoResponse is the response envelope Send decodes: a "data" array of
// tickets.
type expoResponse struct {
Data []expoTicket `json:"data"`
}
// Send delivers one push message through the Expo relay.
//
// msg.DeviceToken is the Expo push token and must be non-empty, otherwise
// push.ErrEmptyToken is returned. push.PriorityHigh is sent as "high"; every
// other priority is sent as "default". An empty sound is sent as "default".
//
// An error is returned when the request cannot be built or sent, when the
// HTTP status is not 200, when the response cannot be parsed, or when any
// returned ticket has a non-empty status other than "ok".
func (p *Provider) Send(ctx context.Context, msg push.PushMessage) error {
	if msg.DeviceToken == "" {
		return push.ErrEmptyToken
	}

	payload := expoMessage{
		To:             msg.DeviceToken,
		Title:          msg.Title,
		Body:           msg.Body,
		Data:           msg.Data,
		Sound:          msg.Sound,
		Badge:          msg.Badge,
		Priority:       "default",
		MutableContent: true, // for iOS Notification Service Extension
		ChannelID:      msg.Channel,
	}
	if msg.Priority == push.PriorityHigh {
		payload.Priority = "high"
	}
	if payload.Sound == "" {
		payload.Sound = "default"
	}

	encoded, err := json.Marshal(payload)
	if err != nil {
		return fmt.Errorf("expo: marshal: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, expoAPIURL, bytes.NewReader(encoded))
	if err != nil {
		return fmt.Errorf("expo: build request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")
	if p.accessToken != "" {
		req.Header.Set("Authorization", "Bearer "+p.accessToken)
	}

	resp, err := p.httpClient.Do(req)
	if err != nil {
		return fmt.Errorf("expo: post: %w", err)
	}
	defer resp.Body.Close()

	// Read at most 16 KiB of the response body.
	raw, err := io.ReadAll(io.LimitReader(resp.Body, 16<<10))
	if err != nil {
		return fmt.Errorf("expo: read response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("expo: http %d: %s", resp.StatusCode, string(raw))
	}

	var parsed expoResponse
	if objErr := json.Unmarshal(raw, &parsed); objErr != nil {
		// Older Expo responses sometimes return a bare array; try that
		// before giving up, and report the original error if it fails too.
		var bare []expoTicket
		if json.Unmarshal(raw, &bare) != nil {
			return fmt.Errorf("expo: parse response: %w", objErr)
		}
		parsed.Data = bare
	}

	for _, ticket := range parsed.Data {
		if ticket.Status == "" || ticket.Status == "ok" {
			continue
		}
		return fmt.Errorf("expo: ticket status %q: %s", ticket.Status, ticket.Message)
	}
	return nil
}

View File

@ -0,0 +1,118 @@
package expo
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/DeBrosOfficial/network/pkg/push"
)
// roundTripFunc adapts a plain function to http.RoundTripper, so tests can
// stub the Expo provider's HTTP transport and inspect outgoing requests
// without contacting exp.host.
type roundTripFunc func(req *http.Request) (*http.Response, error)

// RoundTrip implements http.RoundTripper by invoking the function itself.
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	resp, err := f(req)
	return resp, err
}
// newTestProvider returns a Provider built from a zero Config and a nil
// logger whose HTTP client uses rt as its transport, so every request is
// handled by the supplied function.
func newTestProvider(rt roundTripFunc) *Provider {
p := New(Config{}, nil)
p.httpClient.Transport = rt
return p
}
// TestSend_empty_token_returns_ErrEmptyToken checks that a message without a
// device token is rejected with push.ErrEmptyToken.
func TestSend_empty_token_returns_ErrEmptyToken(t *testing.T) {
	provider := New(Config{}, nil)
	message := push.PushMessage{Body: "x"}

	if err := provider.Send(context.Background(), message); err != push.ErrEmptyToken {
		t.Errorf("expected ErrEmptyToken, got %v", err)
	}
}
// TestSend_happy_path sends a high-priority message through a stubbed
// transport that answers with an "ok" ticket, then checks the Authorization
// header and the "to" and "priority" fields of the JSON payload.
func TestSend_happy_path(t *testing.T) {
	var (
		sentPayload map[string]interface{}
		sentAuth    string
	)
	provider := newTestProvider(func(req *http.Request) (*http.Response, error) {
		sentAuth = req.Header.Get("Authorization")
		raw, _ := io.ReadAll(req.Body)
		_ = json.Unmarshal(raw, &sentPayload)

		rec := httptest.NewRecorder()
		rec.WriteHeader(200)
		_, _ = rec.WriteString(`{"data":[{"status":"ok"}]}`)
		return rec.Result(), nil
	})
	provider.accessToken = "secret-token"

	message := push.PushMessage{
		DeviceToken: "ExponentPushToken[abc]",
		Title:       "T",
		Body:        "B",
		Priority:    push.PriorityHigh,
	}
	if err := provider.Send(context.Background(), message); err != nil {
		t.Fatalf("Send failed: %v", err)
	}

	if sentAuth != "Bearer secret-token" {
		t.Errorf("auth header wrong: %s", sentAuth)
	}
	if sentPayload["to"] != "ExponentPushToken[abc]" {
		t.Errorf("to wrong: %v", sentPayload["to"])
	}
	if sentPayload["priority"] != "high" {
		t.Errorf("priority wrong: %v", sentPayload["priority"])
	}
}
// TestSend_ticket_error_returns_error checks that an HTTP 200 response whose
// ticket carries status "error" makes Send return an error.
func TestSend_ticket_error_returns_error(t *testing.T) {
	const ticketFailure = `{"data":[{"status":"error","message":"DeviceNotRegistered"}]}`

	provider := newTestProvider(func(*http.Request) (*http.Response, error) {
		rec := httptest.NewRecorder()
		rec.WriteHeader(200)
		_, _ = rec.WriteString(ticketFailure)
		return rec.Result(), nil
	})

	msg := push.PushMessage{DeviceToken: "x", Body: "y"}
	if err := provider.Send(context.Background(), msg); err == nil {
		t.Fatal("expected error for ticket failure")
	}
}
func TestSend_http_error_returns_error(t *testing.T) {
p := newTestProvider(func(req *http.Request) (*http.Response, error) {
resp := httptest.NewRecorder()
resp.WriteHeader(500)
_, _ = resp.WriteString(`upstream broken`)
return resp.Result(), nil
})
err := p.Send(context.Background(), push.PushMessage{DeviceToken: "x", Body: "y"})
if err == nil {
t.Fatal("expected error for HTTP 500")
}
}
// TestSend_normal_priority_maps_to_default checks that push.PriorityNormal is
// sent to Expo as priority "default".
func TestSend_normal_priority_maps_to_default(t *testing.T) {
	var payload map[string]interface{}
	provider := newTestProvider(func(req *http.Request) (*http.Response, error) {
		raw, _ := io.ReadAll(req.Body)
		_ = json.Unmarshal(raw, &payload)

		rec := httptest.NewRecorder()
		rec.WriteHeader(200)
		_, _ = rec.WriteString(`{"data":[{"status":"ok"}]}`)
		return rec.Result(), nil
	})

	msg := push.PushMessage{DeviceToken: "x", Body: "y", Priority: push.PriorityNormal}
	err := provider.Send(context.Background(), msg)
	if err != nil {
		t.Fatal(err)
	}
	if prio := payload["priority"]; prio != "default" {
		t.Errorf("expected priority=default, got %v", prio)
	}
}
// TestName checks that the Expo provider identifies itself as "expo".
func TestName(t *testing.T) {
	provider := New(Config{}, nil)
	if name := provider.Name(); name != "expo" {
		t.Error("expected Name=expo")
	}
}

View File

@ -0,0 +1,132 @@
// Package ntfy implements a push.PushProvider backed by an ntfy server.
//
// ntfy delivers notifications via plain HTTP POST to <baseURL>/<topic>.
// The message Body is sent as the request body, and the remaining
// PushMessage fields are mapped to request headers:
// - Title -> "Title"
// - Priority -> "Priority" ("high" or "default")
// - Channel -> "Tags"
// - Badge -> "X-Badge" (only when greater than zero)
// - Data -> base64-encoded JSON in "X-Data"
//
// See https://docs.ntfy.sh/publish/ for details.
package ntfy
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/DeBrosOfficial/network/pkg/push"
"go.uber.org/zap"
)
// Config holds per-provider settings consumed by New.
type Config struct {
// BaseURL is the ntfy HTTP endpoint (e.g. "http://localhost:8080" or
// "https://push.example.com"). Trailing slashes are trimmed by New.
BaseURL string
// AuthToken is an optional per-namespace bearer token. When non-empty it
// is sent as "Authorization: Bearer <token>"; leave empty to disable
// authentication.
AuthToken string
// Timeout bounds each Send call via the HTTP client timeout. Zero or a
// negative value selects 5 seconds.
Timeout time.Duration
}
// Provider is the ntfy push.PushProvider implementation. Construct it with
// New so the HTTP client and logger are initialised.
type Provider struct {
// baseURL is the configured endpoint with trailing slashes removed; Send
// fails when it is empty.
baseURL string
// authToken, when non-empty, is sent as a Bearer Authorization header.
authToken string
// httpClient carries the per-request timeout chosen in New.
httpClient *http.Client
// logger is the caller's logger scoped to the "ntfy" name.
logger *zap.Logger
}
// New builds a Provider from cfg. A nil logger is replaced by a no-op
// logger, a non-positive cfg.Timeout falls back to 5 seconds, and trailing
// slashes are stripped from cfg.BaseURL.
func New(cfg Config, logger *zap.Logger) *Provider {
	const defaultTimeout = 5 * time.Second

	if logger == nil {
		logger = zap.NewNop()
	}

	client := &http.Client{Timeout: cfg.Timeout}
	if client.Timeout <= 0 {
		client.Timeout = defaultTimeout
	}

	provider := &Provider{
		baseURL:    strings.TrimRight(cfg.BaseURL, "/"),
		authToken:  cfg.AuthToken,
		httpClient: client,
		logger:     logger.Named("ntfy"),
	}
	return provider
}
// Name implements push.PushProvider and reports the identifier "ntfy".
func (p *Provider) Name() string {
	const providerName = "ntfy"
	return providerName
}
// Send delivers a push notification to the device's ntfy topic.
//
// msg.DeviceToken is used as the topic path below the configured base URL
// and msg.Body becomes the POST body. Title, Priority, Channel, Badge and
// Data are sent as the "Title", "Priority", "Tags", "X-Badge" and "X-Data"
// (base64-encoded JSON) headers when set; the configured auth token, if any,
// is sent as a Bearer Authorization header.
//
// It returns push.ErrEmptyToken for an empty DeviceToken. It returns other
// errors when the base URL is not configured, the token contains a "." or
// ".." path segment, the data cannot be marshalled, the request cannot be
// built or sent, or the server answers with a status of 400 or above (the
// error then includes up to 512 bytes of the response body).
func (p *Provider) Send(ctx context.Context, msg push.PushMessage) error {
	if msg.DeviceToken == "" {
		return push.ErrEmptyToken
	}
	if p.baseURL == "" {
		return fmt.Errorf("ntfy: base URL not configured")
	}

	// URL-escape each path segment of the device token. ntfy topics can be
	// hierarchical (e.g. "ns/myapp/user-1"), so '/' separators are preserved
	// while every other special character is escaped. url.PathEscape leaves
	// "." and ".." untouched, so those segments are rejected explicitly:
	// otherwise a token such as "../x" could climb out of the topic path.
	// The token itself is deliberately kept out of the error message.
	segments := strings.Split(msg.DeviceToken, "/")
	for i, seg := range segments {
		if seg == "." || seg == ".." {
			return fmt.Errorf("ntfy: device token contains a relative path segment")
		}
		segments[i] = url.PathEscape(seg)
	}
	endpointURL := p.baseURL + "/" + strings.Join(segments, "/")

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, strings.NewReader(msg.Body))
	if err != nil {
		return fmt.Errorf("ntfy: build request: %w", err)
	}

	if msg.Title != "" {
		req.Header.Set("Title", msg.Title)
	}
	// Any other priority value (including the empty string) sends no
	// Priority header at all.
	switch msg.Priority {
	case push.PriorityHigh:
		req.Header.Set("Priority", "high")
	case push.PriorityNormal:
		req.Header.Set("Priority", "default")
	}
	if msg.Channel != "" {
		// ntfy uses "Tags" for both visual emoji and operator-defined tags.
		req.Header.Set("Tags", msg.Channel)
	}
	if msg.Badge > 0 {
		req.Header.Set("X-Badge", fmt.Sprintf("%d", msg.Badge))
	}
	if len(msg.Data) > 0 {
		encoded, err := json.Marshal(msg.Data)
		if err != nil {
			return fmt.Errorf("ntfy: marshal data: %w", err)
		}
		req.Header.Set("X-Data", base64.StdEncoding.EncodeToString(encoded))
	}
	if p.authToken != "" {
		req.Header.Set("Authorization", "Bearer "+p.authToken)
	}

	resp, err := p.httpClient.Do(req)
	if err != nil {
		return fmt.Errorf("ntfy: post: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode >= 400 {
		// Best effort: a failed read only shortens the excerpt in the error.
		excerpt, _ := io.ReadAll(io.LimitReader(resp.Body, 512))
		return fmt.Errorf("ntfy: http %d: %s", resp.StatusCode, strings.TrimSpace(string(excerpt)))
	}

	// Drain (a bounded amount of) the body to allow connection reuse.
	_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4096))
	return nil
}

View File

@ -0,0 +1,191 @@
package ntfy
import (
"context"
"encoding/base64"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/DeBrosOfficial/network/pkg/push"
)
// TestSend_happy_path posts one message to a local test server and checks
// the request path, body, and the Title, Priority and Authorization headers.
func TestSend_happy_path(t *testing.T) {
	var seen struct {
		path, body, title, priority, auth string
	}
	handler := func(w http.ResponseWriter, r *http.Request) {
		seen.path = r.URL.Path
		seen.title = r.Header.Get("Title")
		seen.priority = r.Header.Get("Priority")
		seen.auth = r.Header.Get("Authorization")
		raw, _ := io.ReadAll(r.Body)
		seen.body = string(raw)
		w.WriteHeader(http.StatusOK)
	}
	server := httptest.NewServer(http.HandlerFunc(handler))
	defer server.Close()

	provider := New(Config{BaseURL: server.URL, AuthToken: "secret"}, nil)
	msg := push.PushMessage{
		DeviceToken: "ns/myapp/user-1",
		Title:       "Hello",
		Body:        "World",
		Priority:    push.PriorityHigh,
	}
	if err := provider.Send(context.Background(), msg); err != nil {
		t.Fatalf("Send failed: %v", err)
	}

	if seen.path != "/ns/myapp/user-1" {
		t.Errorf("expected path /ns/myapp/user-1, got %s", seen.path)
	}
	if seen.title != "Hello" {
		t.Errorf("expected Title=Hello, got %s", seen.title)
	}
	if seen.priority != "high" {
		t.Errorf("expected Priority=high, got %s", seen.priority)
	}
	if seen.auth != "Bearer secret" {
		t.Errorf("expected Authorization=Bearer secret, got %s", seen.auth)
	}
	if seen.body != "World" {
		t.Errorf("expected body=World, got %s", seen.body)
	}
}
func TestSend_includes_data_header_when_data_set(t *testing.T) {
var gotData string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotData = r.Header.Get("X-Data")
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
p := New(Config{BaseURL: srv.URL}, nil)
err := p.Send(context.Background(), push.PushMessage{
DeviceToken: "topic",
Body: "x",
Data: map[string]interface{}{"call_id": "abc-123"},
})
if err != nil {
t.Fatalf("Send: %v", err)
}
decoded, err := base64.StdEncoding.DecodeString(gotData)
if err != nil {
t.Fatalf("X-Data not valid base64: %v", err)
}
var got map[string]interface{}
if err := json.Unmarshal(decoded, &got); err != nil {
t.Fatalf("X-Data not valid JSON: %v", err)
}
if got["call_id"] != "abc-123" {
t.Errorf("data round-trip failed: got %v", got)
}
}
// TestSend_no_data_no_data_header checks that a message without Data carries
// no "X-Data" header.
func TestSend_no_data_no_data_header(t *testing.T) {
	var header string
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		header = r.Header.Get("X-Data")
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()

	provider := New(Config{BaseURL: server.URL}, nil)
	msg := push.PushMessage{DeviceToken: "t", Body: "x"}
	err := provider.Send(context.Background(), msg)
	if err != nil {
		t.Fatal(err)
	}
	if header != "" {
		t.Errorf("expected no X-Data header, got %q", header)
	}
}
// TestSend_no_auth_header_when_token_empty checks that no Authorization
// header is sent when the provider has no auth token configured.
func TestSend_no_auth_header_when_token_empty(t *testing.T) {
	var authorization string
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		authorization = r.Header.Get("Authorization")
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()

	provider := New(Config{BaseURL: server.URL}, nil)
	msg := push.PushMessage{DeviceToken: "t", Body: "x"}
	err := provider.Send(context.Background(), msg)
	if err != nil {
		t.Fatal(err)
	}
	if authorization != "" {
		t.Errorf("expected no Authorization header, got %q", authorization)
	}
}
// TestSend_4xx_returns_error_with_body_excerpt checks that a 403 answer
// yields an error mentioning both the status code and the response body.
func TestSend_4xx_returns_error_with_body_excerpt(t *testing.T) {
	reject := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusForbidden)
		_, _ = w.Write([]byte("forbidden topic"))
	}
	server := httptest.NewServer(http.HandlerFunc(reject))
	defer server.Close()

	provider := New(Config{BaseURL: server.URL}, nil)
	err := provider.Send(context.Background(), push.PushMessage{DeviceToken: "t", Body: "x"})
	if err == nil {
		t.Fatal("expected error for 403")
	}

	text := err.Error()
	hasStatus := strings.Contains(text, "403")
	if !hasStatus || !strings.Contains(text, "forbidden") {
		t.Errorf("error should mention status and body, got: %v", err)
	}
}
// TestSend_empty_token_returns_ErrEmptyToken checks that a message without a
// DeviceToken is answered with push.ErrEmptyToken.
func TestSend_empty_token_returns_ErrEmptyToken(t *testing.T) {
	provider := New(Config{BaseURL: "http://example.invalid"}, nil)
	msg := push.PushMessage{Body: "x"}

	switch err := provider.Send(context.Background(), msg); {
	case err == nil:
		t.Fatal("expected error for empty token")
	case err != push.ErrEmptyToken:
		t.Errorf("expected ErrEmptyToken, got %v", err)
	}
}
// TestSend_short_timeout_returns_error points a provider with a 100ms timeout
// at a handler that stalls for up to 2s and expects Send to fail well within
// one second.
func TestSend_short_timeout_returns_error(t *testing.T) {
	release := make(chan struct{})
	stall := func(http.ResponseWriter, *http.Request) {
		select {
		case <-release:
		case <-time.After(2 * time.Second):
		}
	}
	server := httptest.NewServer(http.HandlerFunc(stall))
	// Deferred calls run last-in first-out: unblock the handler first, then
	// close the server.
	defer server.Close()
	defer close(release)

	provider := New(Config{BaseURL: server.URL, Timeout: 100 * time.Millisecond}, nil)

	began := time.Now()
	err := provider.Send(context.Background(), push.PushMessage{DeviceToken: "t", Body: "x"})
	took := time.Since(began)

	if err == nil {
		t.Error("expected timeout error")
	}
	if took > 1*time.Second {
		t.Errorf("expected fast timeout, took %v", took)
	}
}
// TestSend_no_baseURL_returns_error checks that Send fails when the provider
// was built without a BaseURL.
func TestSend_no_baseURL_returns_error(t *testing.T) {
	provider := New(Config{}, nil)
	msg := push.PushMessage{DeviceToken: "t", Body: "x"}
	if err := provider.Send(context.Background(), msg); err == nil {
		t.Fatal("expected error for missing base URL")
	}
}
// TestName checks that the ntfy provider identifies itself as "ntfy".
func TestName(t *testing.T) {
	provider := New(Config{BaseURL: "http://x"}, nil)
	if name := provider.Name(); name != "ntfy" {
		t.Errorf("expected Name=ntfy, got %s", provider.Name())
	}
}

91
core/pkg/push/types.go Normal file
View File

@ -0,0 +1,91 @@
// Package push provides a generic push-notification abstraction for Orama.
//
// Apps register devices with a provider name ("ntfy", "expo", "apns", ...)
// and a provider-specific token. The PushDispatcher routes outbound push
// messages to the matching provider so call sites stay backend-agnostic.
//
// Long-term the platform aims to drop Expo in favour of direct APNs +
// ntfy. The abstraction makes that swap a configuration change rather
// than a code change.
package push
import (
"context"
"errors"
)
// PushPriority signals delivery urgency to the provider.
// Providers that don't support priorities ignore the value.
type PushPriority string
const (
// PriorityNormal is regular urgency; the ntfy provider sends it as the
// "default" priority.
PriorityNormal PushPriority = "normal"
// PriorityHigh is elevated urgency; the ntfy provider sends it as "high".
PriorityHigh PushPriority = "high"
)
// PushMessage is the provider-agnostic message format.
//
// DeviceToken is the provider-specific identifier (e.g. an ntfy topic,
// an Expo push token, an APNs device token). The PushDispatcher fills
// it in per-device before calling Send.
type PushMessage struct {
// DeviceToken identifies the target device; providers return
// ErrEmptyToken when it is empty.
DeviceToken string
// Title is the notification headline; may be empty.
Title string
// Body is the notification text.
Body string
// Data is an optional free-form payload; it must be JSON-marshallable
// (the ntfy provider ships it as base64-encoded JSON).
Data map[string]interface{}
// Badge is the badge count; the ntfy provider only sends values > 0.
Badge int
// Sound is presumably a provider-specific sound name — TODO confirm; the
// ntfy provider does not read it.
Sound string
Channel string // "messages", "calls", etc — provider may map to its own channel concept
// Priority is the delivery urgency; see PushPriority.
Priority PushPriority
}
// PushProvider is implemented by each backend (ntfy, expo, apns).
type PushProvider interface {
// Name returns the provider identifier (e.g. "ntfy", "expo"); it is the
// value PushDevice.Provider is matched against.
Name() string
// Send delivers a single push. Returning an error counts as a delivery
// failure for that device; the dispatcher logs it and continues.
Send(ctx context.Context, msg PushMessage) error
}
// PushDevice represents a registered push target for a user.
//
// Token is plaintext in this struct — encryption happens at the storage
// layer. Callers who load Devices from the store must treat tokens as
// sensitive material (don't log them).
type PushDevice struct {
// ID is the registration's identifier; PushDeviceStore.Delete removes a
// device by this value.
ID string
// Namespace scopes the registration; store lookups and deletes are
// namespace-scoped.
Namespace string
// UserID is the owning user within Namespace.
UserID string
DeviceID string // app-provided
Provider string // matches PushProvider.Name()
// Token is the provider-specific device token, in plaintext (see above).
Token string
Platform string // "ios" | "android" | "web"
// AppVer is presumably the client app version — TODO confirm format.
AppVer string
CreatedAt int64 // unix seconds
// UpdatedAt and LastSeen are presumably unix seconds as well, like
// CreatedAt — TODO confirm.
UpdatedAt int64
LastSeen int64
}
// PushDeviceStore persists per-user device registrations. Every operation
// is scoped to a namespace, either via the PushDevice itself or an explicit
// argument.
type PushDeviceStore interface {
// Upsert registers or updates a device. The Token is encrypted by the
// implementation before being written to durable storage.
Upsert(ctx context.Context, dev PushDevice) error
// Delete removes a single device by ID, scoped to the namespace.
Delete(ctx context.Context, namespace, id string) error
// ListForUser returns all devices for a user within a namespace.
// Tokens in the returned slice are decrypted, so treat them as
// sensitive material.
ListForUser(ctx context.Context, namespace, userID string) ([]PushDevice, error)
}
// Sentinel errors. The ntfy provider returns ErrEmptyToken unwrapped, so it
// can be matched with errors.Is (or direct comparison).
var (
// ErrUnknownProvider is returned by the dispatcher when a device
// references a provider that isn't registered.
ErrUnknownProvider = errors.New("push: unknown provider")
// ErrEmptyToken is returned by providers when called with an empty
// DeviceToken.
ErrEmptyToken = errors.New("push: empty device token")
)

View File

@ -0,0 +1,295 @@
// Package aggregator buffers PubSub trigger events per
// (namespace, function, trigger) and flushes them as a single batched
// invocation. It's used by the PubSub trigger dispatcher when a trigger
// declares aggregation_window_ms > 0.
//
// State is local to each node — buffers are not replicated. This is by
// design: aggregation is intended for high-frequency, lossy event streams
// (presence, VAD signals, metrics). Crash recovery is not provided; an
// orderly shutdown flushes pending buffers via Shutdown().
package aggregator
import (
"context"
"encoding/json"
"sync"
"time"
"go.uber.org/zap"
)
// DefaultMaxBatchSize is the size threshold applied when a BufferRequest's
// MaxBatchSize is zero or negative.
const DefaultMaxBatchSize = 100
// Event is one buffered message, mirroring the dispatcher's PubSubEvent
// shape but kept local to avoid an import cycle. The aggregator stores and
// serialises events as-is; it does not interpret any field.
type Event struct {
// Topic is the PubSub topic the message was published on.
Topic string `json:"topic"`
// Data is the raw message payload, embedded verbatim in the batch JSON.
Data json.RawMessage `json:"data"`
Namespace string `json:"namespace"`
// TriggerDepth is presumably the trigger-chain depth used for loop
// protection by the dispatcher — TODO confirm.
TriggerDepth int `json:"trigger_depth"`
// Timestamp unit is not established in this package — TODO confirm
// against the dispatcher's PubSubEvent.
Timestamp int64 `json:"timestamp"`
}
// BatchedPayload is what the function receives on a flush.
// `Batched: true` lets a function differentiate single vs. batched mode
// by parsing this discriminator first.
type BatchedPayload struct {
// Batched is always true in payloads produced by the aggregator.
Batched bool `json:"batched"`
// Events holds the buffered events in the order they were added.
Events []Event `json:"events"`
}
// FlushFn is invoked when a buffer flushes. It receives the marshalled
// BatchedPayload and a context with a sane timeout. The aggregator does
// not retry on flush errors — that's the invoker's responsibility.
// Each invocation runs on its own goroutine, outside the aggregator's lock.
type FlushFn func(ctx context.Context, payload []byte)
// FlushReason annotates why a flush happened. Useful for metrics.
type FlushReason string
const (
// FlushReasonTimer marks a flush caused by the window timer firing.
FlushReasonTimer FlushReason = "timer"
// FlushReasonSize marks a flush caused by reaching the batch-size limit.
FlushReasonSize FlushReason = "size"
// FlushReasonShutdown marks a flush performed by Shutdown.
FlushReasonShutdown FlushReason = "shutdown"
)
// FlushFnWithReason is like FlushFn but also receives the reason.
// Internal use; FlushFn is the simple public form. Buffer wraps the
// caller's FlushFn in this form, and that wrapper ignores the reason.
type FlushFnWithReason func(ctx context.Context, payload []byte, reason FlushReason)
// bufferKey identifies a single in-memory buffer. It is comparable and is
// used directly as the key of Aggregator.buffers.
type bufferKey struct {
Namespace string
FunctionID string
TriggerID string
}
// bufferEntry is the mutable state of one buffer. All fields are read and
// written with Aggregator.mu held.
type bufferEntry struct {
// events are the messages accumulated in the current window.
events []Event
// timer is the pending window timer; nil when none is armed.
timer *time.Timer
// windowMs and maxBatch are fixed when the entry is created, from the
// request that created it.
windowMs int
maxBatch int
// flushFn is the reason-aware wrapper around the creating request's FlushFn.
flushFn FlushFnWithReason
}
// Aggregator buffers events per (namespace, function, trigger) and flushes
// either when the window timer fires or when MaxBatch is reached.
// Construct it with New; the zero value has a nil buffers map and logger.
type Aggregator struct {
// mu guards buffers and the entries stored in it.
mu sync.Mutex
// buffers holds one entry per key with a pending batch; entries are
// removed when they flush.
buffers map[bufferKey]*bufferEntry
logger *zap.Logger
// flushTimeout bounds the context handed to each flush function.
flushTimeout time.Duration
// inflight tracks dispatched flush goroutines so Shutdown can wait
// for them to finish (or time out) before returning.
inflight sync.WaitGroup
}
// New returns an empty Aggregator. flushTimeout bounds the context handed to
// each flush function; a non-positive value selects 60 seconds. A nil logger
// is replaced by a no-op logger.
func New(logger *zap.Logger, flushTimeout time.Duration) *Aggregator {
	const defaultFlushTimeout = 60 * time.Second

	if logger == nil {
		logger = zap.NewNop()
	}
	if flushTimeout <= 0 {
		flushTimeout = defaultFlushTimeout
	}

	agg := &Aggregator{
		buffers:      make(map[bufferKey]*bufferEntry),
		logger:       logger.Named("aggregator"),
		flushTimeout: flushTimeout,
	}
	return agg
}
// BufferRequest carries everything needed to add an event.
//
// Namespace, FunctionID and TriggerID select the buffer. WindowMs,
// MaxBatchSize and FlushFn only take effect on the request that creates a
// buffer; while that buffer is pending, the values on later requests for the
// same key are ignored.
type BufferRequest struct {
Namespace string
FunctionID string
TriggerID string
// WindowMs is the aggregation window in milliseconds; must be > 0.
WindowMs int
// MaxBatchSize triggers an immediate flush once this many events are
// buffered; zero or negative selects DefaultMaxBatchSize.
MaxBatchSize int
FlushFn FlushFn // simple public form; internally promoted to FlushFnWithReason
// Event is the message to append to the buffer.
Event Event
}
// Buffer adds an event to the matching buffer. Returns immediately —
// the function is invoked later, asynchronously, when the window or
// size threshold fires.
//
// The first event for a key creates the buffer and arms its window timer.
// WindowMs, MaxBatchSize and FlushFn are taken from that first request;
// values on later requests for the same pending buffer are ignored.
//
// If WindowMs <= 0, this method panics with a programming-error message
// to surface misuse: callers should not buffer non-aggregating triggers.
func (a *Aggregator) Buffer(req BufferRequest) {
	if req.WindowMs <= 0 {
		// Aggregator should never be called for non-aggregating triggers.
		// Panicking here makes the caller bug obvious during development.
		panic("aggregator: Buffer called with WindowMs <= 0")
	}

	maxBatch := req.MaxBatchSize
	if maxBatch <= 0 {
		maxBatch = DefaultMaxBatchSize
	}
	key := bufferKey{Namespace: req.Namespace, FunctionID: req.FunctionID, TriggerID: req.TriggerID}

	a.mu.Lock()
	defer a.mu.Unlock()

	entry, ok := a.buffers[key]
	if !ok {
		// Promote the user-facing FlushFn into the reason-aware variant.
		// We capture req.FlushFn so the whole window uses the same callback.
		userFn := req.FlushFn
		entry = &bufferEntry{
			events:   make([]Event, 0, maxBatch),
			windowMs: req.WindowMs,
			maxBatch: maxBatch,
			flushFn: func(ctx context.Context, payload []byte, reason FlushReason) {
				if userFn != nil {
					userFn(ctx, payload)
				}
			},
		}
		a.buffers[key] = entry
	}
	entry.events = append(entry.events, req.Event)

	// Start the window timer on the first event of a new window.
	if entry.timer == nil {
		// Bind the callback to this exact entry. Timer.Stop cannot recall a
		// callback that has already started, so a timer belonging to an
		// earlier, already-flushed bucket may still run after a new bucket
		// was created under the same key; the identity check in
		// flushByTimer makes such a stale callback a no-op.
		k, owner := key, entry
		entry.timer = time.AfterFunc(time.Duration(entry.windowMs)*time.Millisecond, func() {
			a.flushByTimer(k, owner)
		})
	}

	// Size-triggered flush.
	if len(entry.events) >= entry.maxBatch {
		a.flushLocked(key, entry, FlushReasonSize)
	}
}

// flushByTimer is invoked by time.AfterFunc; it acquires the lock then
// flushes the bucket stored under key, but only if that bucket is still
// owner — the entry whose window timer scheduled this call. A bucket that
// was already flushed (by size or Shutdown), or replaced by a newer bucket
// for the same key, is left alone so the newer bucket keeps its full window.
func (a *Aggregator) flushByTimer(key bufferKey, owner *bufferEntry) {
	a.mu.Lock()
	defer a.mu.Unlock()

	entry, ok := a.buffers[key]
	if !ok || entry != owner {
		// Already flushed and removed, or superseded by a newer bucket.
		return
	}
	if len(entry.events) == 0 {
		// Defensive: empty buffer — drop it so the map stays bounded.
		delete(a.buffers, key)
		return
	}
	a.flushLocked(key, entry, FlushReasonTimer)
}
// flushLocked must be called with a.mu held. It stops the bucket's window
// timer, removes the bucket from the map, marshals the buffered events into
// a BatchedPayload and hands the result to the bucket's flush function on a
// new goroutine, so the caller never blocks on the function invocation.
//
// The bucket is removed on every path (empty, marshal failure, success),
// which keeps the buffers map bounded over the lifetime of the process. A
// later event for the same (namespace, function, trigger) tuple makes
// Buffer create a fresh entry.
func (a *Aggregator) flushLocked(key bufferKey, entry *bufferEntry, reason FlushReason) {
	if t := entry.timer; t != nil {
		t.Stop()
		entry.timer = nil
	}

	// Every exit below drops the bucket, so do it once up front.
	delete(a.buffers, key)

	batch := entry.events
	if len(batch) == 0 {
		// Nothing buffered — no invocation needed.
		return
	}

	fields := []zap.Field{
		zap.String("namespace", key.Namespace),
		zap.String("function_id", key.FunctionID),
		zap.String("trigger_id", key.TriggerID),
		zap.Int("batch_size", len(batch)),
	}

	payload, err := json.Marshal(BatchedPayload{Batched: true, Events: batch})
	if err != nil {
		// There's no point retrying with the same data; the batch is dropped.
		a.logger.Error("failed to marshal batched payload", append(fields, zap.Error(err))...)
		return
	}
	a.logger.Debug("aggregator flush", append(fields, zap.String("reason", string(reason)))...)

	// Register with the WaitGroup before spawning so Shutdown can wait for
	// this invocation.
	run, limit := entry.flushFn, a.flushTimeout
	a.inflight.Add(1)
	go func() {
		defer a.inflight.Done()
		ctx, cancel := context.WithTimeout(context.Background(), limit)
		defer cancel()
		run(ctx, payload, reason)
	}()
}
// Shutdown drains all non-empty buffers and waits for the resulting flush
// invocations to finish, bounded by `wait`. Callers should pass a wait
// long enough to cover one function invocation (e.g. 5–10 seconds) but
// short enough that a misbehaving function can't delay process exit.
//
// Returns true if all in-flight flushes completed before the deadline,
// false on timeout (in which case some events are effectively lost).
// With wait <= 0 the flushes are dispatched but not awaited, and the
// result is true.
func (a *Aggregator) Shutdown(wait time.Duration) bool {
	a.mu.Lock()
	// Deleting the current key while ranging over a map is permitted in Go;
	// flushLocked stops the bucket's timer and removes it from the map.
	for key, entry := range a.buffers {
		if entry == nil {
			continue
		}
		a.flushLocked(key, entry, FlushReasonShutdown)
	}
	a.mu.Unlock()

	if wait <= 0 {
		return true
	}

	drained := make(chan struct{})
	go func() {
		defer close(drained)
		a.inflight.Wait()
	}()

	deadline := time.NewTimer(wait)
	defer deadline.Stop()

	select {
	case <-drained:
		return true
	case <-deadline.C:
		a.logger.Warn("aggregator shutdown timed out; some buffered events may be lost")
		return false
	}
}
// Stats returns how many buffers are currently pending and how many events
// they hold in total. Useful for metrics.
func (a *Aggregator) Stats() (numBuffers, totalEvents int) {
	a.mu.Lock()
	defer a.mu.Unlock()

	for _, entry := range a.buffers {
		numBuffers++
		totalEvents += len(entry.events)
	}
	return numBuffers, totalEvents
}

View File

@ -0,0 +1,307 @@
package aggregator
import (
"context"
"encoding/json"
"sync"
"sync/atomic"
"testing"
"time"
"go.uber.org/zap"
)
// TestBuffer_panics_on_zero_window checks that Buffer panics when called
// with WindowMs <= 0.
func TestBuffer_panics_on_zero_window(t *testing.T) {
	agg := New(zap.NewNop(), time.Second)
	req := BufferRequest{
		Namespace:  "ns",
		FunctionID: "fn",
		TriggerID:  "tr",
		WindowMs:   0,
		FlushFn:    func(context.Context, []byte) {},
		Event:      Event{Topic: "t"},
	}

	defer func() {
		if recover() == nil {
			t.Fatal("expected panic when WindowMs <= 0")
		}
	}()
	agg.Buffer(req)
}
// TestBuffer_flushes_on_timer buffers three events in a 100ms window and
// expects a single timer-driven flush carrying all three.
func TestBuffer_flushes_on_timer(t *testing.T) {
	agg := New(zap.NewNop(), 5*time.Second)

	var (
		mu       sync.Mutex
		received []Event
	)
	fired := make(chan struct{})
	onFlush := func(_ context.Context, payload []byte) {
		var batch BatchedPayload
		if err := json.Unmarshal(payload, &batch); err != nil {
			t.Errorf("unmarshal: %v", err)
		}
		mu.Lock()
		received = append(received, batch.Events...)
		mu.Unlock()
		close(fired)
	}

	req := BufferRequest{
		Namespace:  "ns",
		FunctionID: "fn",
		TriggerID:  "tr",
		WindowMs:   100, // short window so test runs fast
		FlushFn:    onFlush,
		Event:      Event{Topic: "presence:user", Data: json.RawMessage(`"e"`)},
	}
	for n := 0; n < 3; n++ {
		agg.Buffer(req)
	}

	select {
	case <-fired:
	case <-time.After(2 * time.Second):
		t.Fatal("flush did not fire within 2s")
	}

	mu.Lock()
	defer mu.Unlock()
	if count := len(received); count != 3 {
		t.Errorf("expected 3 buffered events, got %d", count)
	}
}
// TestBuffer_flushes_on_max_batch_size fills a buffer up to MaxBatchSize and
// expects exactly one size-driven flush containing the whole batch, long
// before the 30s window could elapse.
func TestBuffer_flushes_on_max_batch_size(t *testing.T) {
	agg := New(zap.NewNop(), 5*time.Second)

	var flushes, lastSize int32
	fired := make(chan struct{})
	onFlush := func(_ context.Context, payload []byte) {
		var batch BatchedPayload
		_ = json.Unmarshal(payload, &batch)
		atomic.AddInt32(&flushes, 1)
		atomic.StoreInt32(&lastSize, int32(len(batch.Events)))
		// Close only once, even if a second flush were to happen.
		select {
		case <-fired:
		default:
			close(fired)
		}
	}

	req := BufferRequest{
		Namespace:    "ns",
		FunctionID:   "fn",
		TriggerID:    "tr",
		WindowMs:     30_000, // long enough that the timer won't fire
		MaxBatchSize: 5,
		FlushFn:      onFlush,
		Event:        Event{Topic: "t"},
	}
	for n := 0; n < 5; n++ {
		agg.Buffer(req)
	}

	select {
	case <-fired:
	case <-time.After(2 * time.Second):
		t.Fatal("max-batch flush did not fire")
	}

	if n := atomic.LoadInt32(&flushes); n != 1 {
		t.Errorf("expected 1 flush, got %d", n)
	}
	if size := atomic.LoadInt32(&lastSize); size != 5 {
		t.Errorf("expected batch size 5, got %d", size)
	}
}
// TestBuffer_separate_keys_independent buffers events under two trigger IDs
// and checks that each trigger's flush only receives its own events.
func TestBuffer_separate_keys_independent(t *testing.T) {
	agg := New(zap.NewNop(), 5*time.Second)

	var mu sync.Mutex
	perTrigger := map[string]int{}
	counter := func(name string) FlushFn {
		return func(_ context.Context, payload []byte) {
			var batch BatchedPayload
			_ = json.Unmarshal(payload, &batch)
			mu.Lock()
			perTrigger[name] += len(batch.Events)
			mu.Unlock()
		}
	}

	inputs := []struct{ trigger, name, topic string }{
		{"tr-A", "A", "a"},
		{"tr-B", "B", "b"},
		{"tr-A", "A", "a2"},
	}
	for _, in := range inputs {
		agg.Buffer(BufferRequest{
			Namespace:  "ns",
			FunctionID: "fn",
			TriggerID:  in.trigger,
			WindowMs:   100,
			FlushFn:    counter(in.name),
			Event:      Event{Topic: in.topic},
		})
	}

	// Give both 100ms windows ample time to fire.
	time.Sleep(500 * time.Millisecond)

	mu.Lock()
	defer mu.Unlock()
	if got := perTrigger["A"]; got != 2 {
		t.Errorf("A: expected 2 events, got %d", got)
	}
	if got := perTrigger["B"]; got != 1 {
		t.Errorf("B: expected 1 event, got %d", got)
	}
}
// TestShutdown_flushes_all_buffers fills two buffers (different trigger IDs)
// whose 30s windows cannot elapse during the test, then checks that Shutdown
// reports success and has produced exactly one flush per buffer.
//
// Shutdown waits for the dispatched flush goroutines, so the counter can be
// checked right after it returns true; no polling is needed. The counter is
// read with an atomic load because the flush goroutines update it atomically.
func TestShutdown_flushes_all_buffers(t *testing.T) {
	a := New(zap.NewNop(), 2*time.Second)

	var flushed int32
	flush := func(ctx context.Context, payload []byte) {
		atomic.AddInt32(&flushed, 1)
	}

	for i := 0; i < 4; i++ {
		a.Buffer(BufferRequest{
			Namespace: "ns", FunctionID: "fn", TriggerID: "tr",
			WindowMs: 30_000,
			FlushFn:  flush,
			Event:    Event{Topic: "t"},
		})
	}
	// Different trigger key — should produce a separate flush.
	a.Buffer(BufferRequest{
		Namespace: "ns", FunctionID: "fn", TriggerID: "other",
		WindowMs: 30_000,
		FlushFn:  flush,
		Event:    Event{Topic: "t2"},
	})

	if !a.Shutdown(2 * time.Second) {
		t.Fatal("Shutdown timed out waiting for in-flight flushes")
	}
	if got := atomic.LoadInt32(&flushed); got != 2 {
		t.Fatalf("expected 2 flushes from Shutdown, got %d", got)
	}
}
// TestShutdown_skips_empty_buffers drains a buffer through a size-triggered
// flush (MaxBatchSize 1) and then checks that Shutdown does not invoke the
// flush function a second time.
func TestShutdown_skips_empty_buffers(t *testing.T) {
	agg := New(zap.NewNop(), 2*time.Second)

	var flushes int32
	onFlush := func(context.Context, []byte) {
		atomic.AddInt32(&flushes, 1)
	}

	// A single event with MaxBatchSize 1 flushes immediately by size.
	agg.Buffer(BufferRequest{
		Namespace:    "ns",
		FunctionID:   "fn",
		TriggerID:    "tr",
		WindowMs:     30_000,
		MaxBatchSize: 1,
		FlushFn:      onFlush,
		Event:        Event{Topic: "t"},
	})

	// Wait for the size-triggered flush to drain.
	giveUp := time.Now().Add(2 * time.Second)
	for atomic.LoadInt32(&flushes) < 1 {
		if time.Now().After(giveUp) {
			t.Fatal("size flush didn't fire")
		}
		time.Sleep(5 * time.Millisecond)
	}

	// Now the buffer is empty. Shutdown should not flush again.
	agg.Shutdown(2 * time.Second)
	time.Sleep(200 * time.Millisecond)
	if total := atomic.LoadInt32(&flushes); total != 1 {
		t.Errorf("Shutdown flushed an empty buffer: total flushes %d", total)
	}
}
// TestStats_reports_buffered_state buffers three events across two trigger
// IDs and checks the buffer and event counts reported by Stats.
func TestStats_reports_buffered_state(t *testing.T) {
	agg := New(zap.NewNop(), 2*time.Second)
	noop := func(context.Context, []byte) {}

	for _, trigger := range []string{"a", "a", "b"} {
		agg.Buffer(BufferRequest{
			Namespace:  "ns",
			FunctionID: "fn",
			TriggerID:  trigger,
			WindowMs:   30_000,
			FlushFn:    noop,
			Event:      Event{Topic: "t"},
		})
	}

	buffers, events := agg.Stats()
	if buffers != 2 {
		t.Errorf("expected 2 buffers, got %d", buffers)
	}
	if events != 3 {
		t.Errorf("expected 3 buffered events, got %d", events)
	}
}
// TestBuffer_concurrent_writes_no_race hammers a single buffer from several
// goroutines. It makes no assertions of its own; it is meant to be run with
// -race, which should not report any data race.
func TestBuffer_concurrent_writes_no_race(t *testing.T) {
	const (
		writers   = 8
		perWriter = 50
	)
	agg := New(zap.NewNop(), 2*time.Second)
	req := BufferRequest{
		Namespace:  "ns",
		FunctionID: "fn",
		TriggerID:  "tr",
		WindowMs:   200,
		FlushFn:    func(context.Context, []byte) {},
		Event:      Event{Topic: "t"},
	}

	var wg sync.WaitGroup
	wg.Add(writers)
	for w := 0; w < writers; w++ {
		go func() {
			defer wg.Done()
			for n := 0; n < perWriter; n++ {
				agg.Buffer(req)
			}
		}()
	}
	wg.Wait()

	// Drain.
	agg.Shutdown(2 * time.Second)
}
// TestBuffer_payload_includes_batched_true_and_topic checks the shape of a
// flushed payload: Batched is true and the single buffered event keeps its
// topic.
func TestBuffer_payload_includes_batched_true_and_topic(t *testing.T) {
	agg := New(zap.NewNop(), 2*time.Second)

	payloads := make(chan BatchedPayload, 1)
	onFlush := func(_ context.Context, raw []byte) {
		var batch BatchedPayload
		if err := json.Unmarshal(raw, &batch); err != nil {
			t.Errorf("unmarshal: %v", err)
		}
		payloads <- batch
	}

	agg.Buffer(BufferRequest{
		Namespace:  "ns",
		FunctionID: "fn",
		TriggerID:  "tr",
		WindowMs:   50,
		FlushFn:    onFlush,
		Event:      Event{Topic: "presence:user-1", Data: json.RawMessage(`{"x":1}`)},
	})

	var batch BatchedPayload
	select {
	case batch = <-payloads:
	case <-time.After(2 * time.Second):
		t.Fatal("flush did not fire")
	}

	if !batch.Batched {
		t.Error("payload should have Batched=true")
	}
	if len(batch.Events) != 1 || batch.Events[0].Topic != "presence:user-1" {
		t.Errorf("unexpected events: %+v", batch.Events)
	}
}

View File

@ -328,6 +328,8 @@ func (e *Engine) registerHostModule(ctx context.Context) error {
NewFunctionBuilder().WithFunc(e.hCacheIncrBy).Export("cache_incr_by").
NewFunctionBuilder().WithFunc(e.hHTTPFetch).Export("http_fetch").
NewFunctionBuilder().WithFunc(e.hPubSubPublish).Export("pubsub_publish").
NewFunctionBuilder().WithFunc(e.hPubSubPublishBatch).Export("pubsub_publish_batch").
NewFunctionBuilder().WithFunc(e.hPushSend).Export("push_send").
NewFunctionBuilder().WithFunc(e.hLogInfo).Export("log_info").
NewFunctionBuilder().WithFunc(e.hLogError).Export("log_error").
Instantiate(ctx)
@ -517,6 +519,46 @@ func (e *Engine) hPubSubPublish(ctx context.Context, mod api.Module, topicPtr, t
return 1 // Success
}
// hPubSubPublishBatch is the WASM-callable wrapper for PubSubPublishBatch.
//
// msgsPtr/msgsLen locate, in guest memory, a JSON array of
// {topic, data_base64} objects, which is passed to the host service as-is.
// Returns 1 on success; 0 if the guest memory cannot be read or the host
// service reports an error (which is logged).
func (e *Engine) hPubSubPublishBatch(ctx context.Context, mod api.Module, msgsPtr, msgsLen uint32) uint32 {
	const (
		failure uint32 = 0
		success uint32 = 1
	)

	batchJSON, read := e.executor.ReadFromGuest(mod, msgsPtr, msgsLen)
	if !read {
		return failure
	}

	err := e.hostServices.PubSubPublishBatch(ctx, batchJSON)
	if err == nil {
		return success
	}
	e.logger.Error("host function pubsub_publish_batch failed", zap.Error(err))
	return failure
}
// hPushSend is the WASM-callable wrapper for PushSend.
//
// Inputs:
//   - userIDPtr/userIDLen — UTF-8 user ID to push to (within the function's
//     own namespace; the namespace is server-side trusted)
//   - msgPtr/msgLen — JSON payload matching hostfunctions.PushSendArgs
//
// Returns 1 on success; 0 if either guest buffer cannot be read or the host
// service reports an error (which is logged together with the user ID).
func (e *Engine) hPushSend(ctx context.Context, mod api.Module,
	userIDPtr, userIDLen, msgPtr, msgLen uint32) uint32 {
	rawUserID, ok := e.executor.ReadFromGuest(mod, userIDPtr, userIDLen)
	if !ok {
		return 0
	}
	payload, ok := e.executor.ReadFromGuest(mod, msgPtr, msgLen)
	if !ok {
		return 0
	}

	target := string(rawUserID)
	err := e.hostServices.PushSend(ctx, target, payload)
	if err == nil {
		return 1
	}
	e.logger.Error("host function push_send failed",
		zap.String("user_id", target),
		zap.Error(err))
	return 0
}
func (e *Engine) hLogInfo(ctx context.Context, mod api.Module, ptr, size uint32) {
msg, ok := e.executor.ReadFromGuest(mod, ptr, size)
if ok {

View File

@ -90,6 +90,14 @@ func (m *mockHostServices) PubSubPublish(ctx context.Context, topic string, data
return nil
}
func (m *mockHostServices) PubSubPublishBatch(ctx context.Context, msgsJSON []byte) error {
return nil
}
// PushSend is a test stub: it ignores its arguments and always reports
// success.
func (m *mockHostServices) PushSend(ctx context.Context, userID string, msgJSON []byte) error {
return nil
}
func (m *mockHostServices) WSSend(ctx context.Context, clientID string, data []byte) error {
return nil
}

View File

@ -5,6 +5,7 @@ import (
"github.com/DeBrosOfficial/network/pkg/ipfs"
"github.com/DeBrosOfficial/network/pkg/pubsub"
"github.com/DeBrosOfficial/network/pkg/push"
"github.com/DeBrosOfficial/network/pkg/rqlite"
"github.com/DeBrosOfficial/network/pkg/serverless"
"github.com/DeBrosOfficial/network/pkg/tlsutil"
@ -13,6 +14,10 @@ import (
)
// NewHostFunctions creates a new HostFunctions instance.
//
// pushDispatcher may be nil when push isn't configured on this gateway —
// in that case PushSend hostfunc returns nil (silent no-op) so functions
// remain portable across deployments with/without push.
func NewHostFunctions(
db rqlite.Client,
cacheClient olriclib.Client,
@ -20,6 +25,7 @@ func NewHostFunctions(
pubsubAdapter *pubsub.ClientAdapter,
wsManager serverless.WebSocketManager,
secrets serverless.SecretsManager,
pushDispatcher *push.PushDispatcher,
cfg HostFunctionsConfig,
logger *zap.Logger,
) *HostFunctions {
@ -29,15 +35,16 @@ func NewHostFunctions(
}
return &HostFunctions{
db: db,
cacheClient: cacheClient,
storage: storage,
ipfsAPIURL: cfg.IPFSAPIURL,
pubsub: pubsubAdapter,
wsManager: wsManager,
secrets: secrets,
httpClient: tlsutil.NewHTTPClient(httpTimeout),
logger: logger,
logs: make([]serverless.LogEntry, 0),
db: db,
cacheClient: cacheClient,
storage: storage,
ipfsAPIURL: cfg.IPFSAPIURL,
pubsub: pubsubAdapter,
wsManager: wsManager,
secrets: secrets,
pushDispatcher: pushDispatcher,
httpClient: tlsutil.NewHTTPClient(httpTimeout),
logger: logger,
logs: make([]serverless.LogEntry, 0),
}
}

View File

@ -2,8 +2,11 @@ package hostfunctions
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"github.com/DeBrosOfficial/network/pkg/pubsub"
"github.com/DeBrosOfficial/network/pkg/serverless"
)
@ -21,6 +24,64 @@ func (h *HostFunctions) PubSubPublish(ctx context.Context, topic string, data []
return nil
}
// pubSubBatchEntry mirrors the JSON shape accepted by PubSubPublishBatch:
// one element of the input array.
type pubSubBatchEntry struct {
// Topic is the destination topic; PubSubPublishBatch rejects an empty value.
Topic string `json:"topic"`
// DataB64 is the message body, decoded with base64.StdEncoding by
// PubSubPublishBatch before publishing.
DataB64 string `json:"data_base64"`
}
// PubSubPublishBatch validates a JSON batch of messages and hands it to the
// pubsub adapter in a single PublishBatch call.
//
// Input is JSON: [{"topic":"...","data_base64":"..."}, ...] with at most
// pubsub.MaxBatchSize entries. An empty array is a successful no-op.
//
// Validation is all-or-nothing: an empty topic or an undecodable
// data_base64 in any entry rejects the whole call before anything is
// published. Every failure is returned as a *serverless.HostFunctionError
// with Function "pubsub_publish_batch".
//
// PublishBatch is invoked with zero-value options. Existing notes describe
// that default as fail-fast (first publish error returned) with no
// best-effort flag exposed to WASM — TODO confirm against pkg/pubsub.
// Callers needing partial-failure tolerance should split the work into
// chunks they can retry independently.
func (h *HostFunctions) PubSubPublishBatch(ctx context.Context, msgsJSON []byte) error {
	// fail wraps a cause in the host-function error type used by this call.
	fail := func(cause error) error {
		return &serverless.HostFunctionError{Function: "pubsub_publish_batch", Cause: cause}
	}

	if h.pubsub == nil {
		return fail(fmt.Errorf("pubsub not available"))
	}

	var batch []pubSubBatchEntry
	if err := json.Unmarshal(msgsJSON, &batch); err != nil {
		return fail(fmt.Errorf("invalid json: %w", err))
	}

	switch count := len(batch); {
	case count == 0:
		return nil
	case count > pubsub.MaxBatchSize:
		return fail(fmt.Errorf("too many messages: max %d per batch", pubsub.MaxBatchSize))
	}

	outgoing := make([]pubsub.TopicMessage, 0, len(batch))
	for idx := range batch {
		entry := batch[idx]
		if entry.Topic == "" {
			return fail(fmt.Errorf("entry %d: empty topic", idx))
		}
		body, decodeErr := base64.StdEncoding.DecodeString(entry.DataB64)
		if decodeErr != nil {
			return fail(fmt.Errorf("entry %d (topic %q): bad base64: %w", idx, entry.Topic, decodeErr))
		}
		outgoing = append(outgoing, pubsub.TopicMessage{Topic: entry.Topic, Data: body})
	}

	if err := h.pubsub.PublishBatch(ctx, outgoing, pubsub.PublishBatchOptions{}); err != nil {
		return fail(err)
	}
	return nil
}
// WSSend sends data to a specific WebSocket client.
func (h *HostFunctions) WSSend(ctx context.Context, clientID string, data []byte) error {
if h.wsManager == nil {

View File

@ -0,0 +1,103 @@
package hostfunctions
import (
"context"
"encoding/json"
"fmt"
"github.com/DeBrosOfficial/network/pkg/push"
"github.com/DeBrosOfficial/network/pkg/serverless"
)
// PushSendArgs is the JSON payload format the WASM caller marshals into
// the `msgJSON` argument of PushSend. Mirrors push.PushMessage minus the
// device-token (which is filled in per-device by the dispatcher).
//
// Every field is optional in the JSON encoding; PushSend copies each one
// into the corresponding push.PushMessage field.
type PushSendArgs struct {
Title string `json:"title,omitempty"`
Body string `json:"body,omitempty"`
Channel string `json:"channel,omitempty"`
Priority string `json:"priority,omitempty"` // "high" | "normal" | ""; PushSend treats any value other than "high" as normal
Badge int `json:"badge,omitempty"`
Sound string `json:"sound,omitempty"`
// Data is passed through to the push message unchanged.
Data map[string]interface{} `json:"data,omitempty"`
}
// MaxPushSendArgsBytes caps the JSON argument passed to PushSend at 16 KiB.
// Push payloads are small by construction (APNs caps at 4KB, ntfy/Expo similar).
const MaxPushSendArgsBytes = 16 * 1024
// PushSend implements serverless.HostServices.PushSend.
//
// It decodes msgJSON (a PushSendArgs document of at most
// MaxPushSendArgsBytes bytes) and passes the resulting push.PushMessage to
// the dispatcher's SendToUser for userID.
//
// The namespace is never taken from the caller: it is read, under
// invCtxLock, from the current invocation context, so a function cannot
// address users outside the namespace it is running in. A missing
// invocation context or empty namespace is an error.
//
// When no dispatcher is configured the call returns nil without validating
// anything (silent no-op), which keeps functions portable across gateways
// with and without push. All other failures are returned as a
// *serverless.HostFunctionError with Function "push_send".
func (h *HostFunctions) PushSend(ctx context.Context, userID string, msgJSON []byte) error {
	if h.pushDispatcher == nil {
		// Push isn't configured on this gateway — deliberately not an error.
		return nil
	}

	// fail wraps a cause in the host-function error type used by this call.
	fail := func(cause error) error {
		return &serverless.HostFunctionError{Function: "push_send", Cause: cause}
	}

	switch {
	case userID == "":
		return fail(fmt.Errorf("user_id required"))
	case len(msgJSON) > MaxPushSendArgsBytes:
		return fail(fmt.Errorf("msg too large: max %d bytes", MaxPushSendArgsBytes))
	}

	var args PushSendArgs
	if err := json.Unmarshal(msgJSON, &args); err != nil {
		return fail(fmt.Errorf("invalid json: %w", err))
	}

	// The namespace is server-side state, not WASM input.
	namespace := func() string {
		h.invCtxLock.RLock()
		defer h.invCtxLock.RUnlock()
		if h.invCtx == nil {
			return ""
		}
		return h.invCtx.Namespace
	}()
	if namespace == "" {
		return fail(fmt.Errorf("no namespace in invocation context"))
	}

	// Only the literal "high" raises the priority; "normal", "" and any
	// unrecognised value all map to normal.
	priority := push.PriorityNormal
	if args.Priority == "high" {
		priority = push.PriorityHigh
	}

	sendErr := h.pushDispatcher.SendToUser(ctx, namespace, userID, push.PushMessage{
		Title:    args.Title,
		Body:     args.Body,
		Channel:  args.Channel,
		Priority: priority,
		Badge:    args.Badge,
		Sound:    args.Sound,
		Data:     args.Data,
	})
	if sendErr != nil {
		return fail(sendErr)
	}
	return nil
}

View File

@ -7,6 +7,7 @@ import (
"github.com/DeBrosOfficial/network/pkg/ipfs"
"github.com/DeBrosOfficial/network/pkg/pubsub"
"github.com/DeBrosOfficial/network/pkg/push"
"github.com/DeBrosOfficial/network/pkg/rqlite"
"github.com/DeBrosOfficial/network/pkg/serverless"
olriclib "github.com/olric-data/olric"
@ -32,6 +33,10 @@ type HostFunctions struct {
httpClient *http.Client
logger *zap.Logger
// pushDispatcher may be nil when push isn't configured for this gateway.
// In that case PushSend returns nil silently — see hostfunctions/push.go.
pushDispatcher *push.PushDispatcher
// Current invocation context (set per-execution)
invCtx *serverless.InvocationContext
invCtxLock sync.RWMutex

View File

@ -176,6 +176,14 @@ func (m *MockHostServices) PubSubPublish(ctx context.Context, topic string, data
return nil
}
// PubSubPublishBatch is a mock implementation: it ignores its arguments and
// always reports success.
func (m *MockHostServices) PubSubPublishBatch(ctx context.Context, msgsJSON []byte) error {
return nil
}
// PushSend is a mock implementation: it ignores its arguments and always
// reports success.
func (m *MockHostServices) PushSend(ctx context.Context, userID string, msgJSON []byte) error {
return nil
}
func (m *MockHostServices) WSSend(ctx context.Context, clientID string, data []byte) error {
return nil
}

View File

@ -6,14 +6,12 @@ import (
"time"
"github.com/DeBrosOfficial/network/pkg/serverless"
"github.com/DeBrosOfficial/network/pkg/serverless/aggregator"
olriclib "github.com/olric-data/olric"
"go.uber.org/zap"
)
const (
// triggerCacheDMap is the Olric DMap name for caching trigger lookups.
triggerCacheDMap = "pubsub_triggers"
// maxTriggerDepth prevents infinite loops when triggered functions publish
// back to the same topic via the HTTP API.
maxTriggerDepth = 5
@ -37,6 +35,7 @@ type PubSubDispatcher struct {
store *PubSubTriggerStore
invoker *serverless.Invoker
olricClient olriclib.Client // may be nil (cache disabled)
aggregator *aggregator.Aggregator
logger *zap.Logger
}
@ -51,10 +50,17 @@ func NewPubSubDispatcher(
store: store,
invoker: invoker,
olricClient: olricClient,
aggregator: aggregator.New(logger, dispatchTimeout),
logger: logger,
}
}
// Aggregator returns the aggregator this dispatcher routes buffered
// (windowed) trigger events through. It is exposed so callers such as the
// gateway lifecycle can flush pending buffers on shutdown.
func (d *PubSubDispatcher) Aggregator() *aggregator.Aggregator {
return d.aggregator
}
// Dispatch looks up all triggers registered for the given topic+namespace and
// invokes matching functions asynchronously. Each invocation runs in its own
// goroutine and does not block the caller.
@ -82,7 +88,7 @@ func (d *PubSubDispatcher) Dispatch(ctx context.Context, namespace, topic string
return
}
// Build the event payload once for all invocations
// Build the per-event payload once for non-aggregating dispatches.
event := PubSubEvent{
Topic: topic,
Data: json.RawMessage(data),
@ -90,11 +96,6 @@ func (d *PubSubDispatcher) Dispatch(ctx context.Context, namespace, topic string
TriggerDepth: depth + 1,
Timestamp: time.Now().Unix(),
}
eventJSON, err := json.Marshal(event)
if err != nil {
d.logger.Error("Failed to marshal PubSub event", zap.Error(err))
return
}
d.logger.Debug("Dispatching PubSub triggers",
zap.String("namespace", namespace),
@ -103,94 +104,81 @@ func (d *PubSubDispatcher) Dispatch(ctx context.Context, namespace, topic string
zap.Int("depth", depth),
)
var (
eventJSON []byte
marshalErr error
)
for _, match := range matches {
if match.AggregationWindowMs > 0 {
d.bufferEvent(match, event)
continue
}
// Lazily marshal — non-aggregating triggers need eventJSON.
if eventJSON == nil && marshalErr == nil {
eventJSON, marshalErr = json.Marshal(event)
if marshalErr != nil {
d.logger.Error("Failed to marshal PubSub event", zap.Error(marshalErr))
continue
}
}
if marshalErr != nil {
continue
}
go d.invokeFunction(match, eventJSON)
}
}
// InvalidateCache removes the cached trigger lookup for a namespace+topic.
// Call this when triggers are added or removed.
func (d *PubSubDispatcher) InvalidateCache(ctx context.Context, namespace, topic string) {
if d.olricClient == nil {
return
}
dm, err := d.olricClient.NewDMap(triggerCacheDMap)
if err != nil {
d.logger.Debug("Failed to get trigger cache DMap for invalidation", zap.Error(err))
return
}
key := cacheKey(namespace, topic)
if _, err := dm.Delete(ctx, key); err != nil {
d.logger.Debug("Failed to invalidate trigger cache", zap.String("key", key), zap.Error(err))
}
// bufferEvent hands one event to the aggregator instead of invoking the
// function immediately. The trigger's window and batch-size settings travel
// with the request, together with a flush callback that invokes the
// function with the batched payload the aggregator supplies. A failed
// aggregated invocation is logged as a warning and otherwise dropped.
func (d *PubSubDispatcher) bufferEvent(match TriggerMatch, event PubSubEvent) {
	// flush is invoked by the aggregator with the batched payload.
	flush := func(ctx context.Context, payload []byte) {
		_, invokeErr := d.invoker.Invoke(ctx, &serverless.InvokeRequest{
			Namespace:    match.Namespace,
			FunctionName: match.FunctionName,
			Input:        payload,
			TriggerType:  serverless.TriggerTypePubSub,
		})
		if invokeErr == nil {
			return
		}
		d.logger.Warn("Aggregated PubSub invocation failed",
			zap.String("function", match.FunctionName),
			zap.String("trigger_id", match.TriggerID),
			zap.Error(invokeErr),
		)
	}

	buffered := aggregator.Event{
		Topic:        event.Topic,
		Data:         event.Data,
		Namespace:    event.Namespace,
		TriggerDepth: event.TriggerDepth,
		Timestamp:    event.Timestamp,
	}

	d.aggregator.Buffer(aggregator.BufferRequest{
		Namespace:    match.Namespace,
		FunctionID:   match.FunctionID,
		TriggerID:    match.TriggerID,
		WindowMs:     match.AggregationWindowMs,
		MaxBatchSize: match.AggregationMaxBatchSize,
		Event:        buffered,
		FlushFn:      flush,
	})
}
// getMatches returns the trigger matches for a topic+namespace, using Olric cache when available.
// InvalidateCache does nothing: the dispatcher no longer caches trigger
// lookups, so there is nothing to invalidate. The method is kept, with its
// original signature, only so existing callers still compile.
func (d *PubSubDispatcher) InvalidateCache(ctx context.Context, namespace, topic string) {}
// getMatches returns the trigger matches for a topic+namespace.
//
// Caching note: an earlier revision cached results in Olric keyed by
// (namespace, topic). With wildcard triggers the cache becomes
// inconsistent — a single trigger Add/Remove invalidates an unbounded
// number of resolved-topic keys. The cache was removed; re-introducing
// it requires a generation-counter (or TTL) scheme that handles
// wildcard pattern changes.
func (d *PubSubDispatcher) getMatches(ctx context.Context, namespace, topic string) ([]TriggerMatch, error) {
// Try cache first
if d.olricClient != nil {
if matches, ok := d.getCached(ctx, namespace, topic); ok {
return matches, nil
}
}
// Cache miss — query database
matches, err := d.store.GetByTopicAndNamespace(ctx, topic, namespace)
if err != nil {
return nil, err
}
// Populate cache
if d.olricClient != nil && matches != nil {
d.setCache(ctx, namespace, topic, matches)
}
return matches, nil
return d.store.GetByTopicAndNamespace(ctx, topic, namespace)
}
// getCached attempts to retrieve trigger matches from Olric cache.
func (d *PubSubDispatcher) getCached(ctx context.Context, namespace, topic string) ([]TriggerMatch, bool) {
dm, err := d.olricClient.NewDMap(triggerCacheDMap)
if err != nil {
return nil, false
}
key := cacheKey(namespace, topic)
result, err := dm.Get(ctx, key)
if err != nil {
return nil, false
}
data, err := result.Byte()
if err != nil {
return nil, false
}
var matches []TriggerMatch
if err := json.Unmarshal(data, &matches); err != nil {
return nil, false
}
return matches, true
}
// setCache stores trigger matches in Olric cache.
func (d *PubSubDispatcher) setCache(ctx context.Context, namespace, topic string, matches []TriggerMatch) {
dm, err := d.olricClient.NewDMap(triggerCacheDMap)
if err != nil {
return
}
data, err := json.Marshal(matches)
if err != nil {
return
}
key := cacheKey(namespace, topic)
_ = dm.Put(ctx, key, data)
}
// invokeFunction invokes a single function for a trigger match.
func (d *PubSubDispatcher) invokeFunction(match TriggerMatch, eventJSON []byte) {
@ -224,7 +212,3 @@ func (d *PubSubDispatcher) invokeFunction(match TriggerMatch, eventJSON []byte)
)
}
// cacheKey returns the Olric cache key for a namespace+topic pair.
func cacheKey(namespace, topic string) string {
return "triggers:" + namespace + ":" + topic
}

View File

@ -0,0 +1,138 @@
package triggers
import (
"fmt"
"strings"
)
// MaxPatternLength is the largest glob pattern, in bytes, that
// ValidatePattern accepts.
const MaxPatternLength = 256

// ValidatePattern reports whether p is an acceptable glob pattern.
//
// It returns an error when p is empty, when p is longer than
// MaxPatternLength bytes, or when its '[' and ']' characters are not
// balanced (a ']' with no open '[', or a '[' that is never closed).
// Positions in error messages are byte offsets into p. A nil result means
// the pattern passed all three checks.
func ValidatePattern(p string) error {
	switch size := len(p); {
	case size == 0:
		return fmt.Errorf("empty pattern")
	case size > MaxPatternLength:
		return fmt.Errorf("pattern too long: %d > %d", size, MaxPatternLength)
	}

	// depth counts '[' seen so far that have not been closed.
	depth := 0
	for pos := 0; pos < len(p); pos++ {
		if p[pos] == '[' {
			depth++
			continue
		}
		if p[pos] != ']' {
			continue
		}
		depth--
		if depth < 0 {
			return fmt.Errorf("unmatched ']' at position %d", pos)
		}
	}

	if depth == 0 {
		return nil
	}
	return fmt.Errorf("unmatched '['")
}
// IsWildcard reports whether p contains at least one glob metacharacter
// ('*', '?' or '['). A pattern without any of them can only match a topic
// that is identical to it.
func IsWildcard(p string) bool {
	return strings.IndexAny(p, "*?[") >= 0
}
// PatternMatches reports whether topic matches pattern under Orama's glob
// semantics:
//   - '*' matches zero or more characters EXCEPT ':'
//   - '**' (any run of two or more '*') matches zero or more characters
//     INCLUDING ':' (deep wildcard)
//   - '?' matches exactly one character (any, including ':')
//   - '[abc]', '[a-z]', '[!abc]' / '[^abc]' character classes
//
// The match is always evaluated in full, so the function is correct on its
// own and not only as a post-filter: a pattern containing '**' is really
// matched against the topic, and a single '*' in the same pattern is still
// confined to one ':'-separated segment.
//
// Matching is byte-wise: '?' and character classes consume one byte, so a
// multi-byte UTF-8 character in the topic needs one '?' per byte.
//
// NOTE(review): the trigger store pre-filters rows with SQLite GLOB, which
// (per the SQLite documentation) negates classes with '^' rather than '!'
// and matches '?' per character — confirm both sides agree for the
// patterns you rely on.
func PatternMatches(pattern, topic string) bool {
	return strictGlobMatch(pattern, topic)
}

// strictGlobMatch implements the glob matching documented on
// PatternMatches: a lone '*' does NOT cross ':', a run of two or more '*'
// does.
//
// It is a backtracking matcher with memoised failures: every
// (pattern offset, input offset) pair is explored at most once, so the work
// is bounded by len(pattern)*len(s) and every '*' in the pattern — not just
// the most recent one — is allowed to grow. That matters because a
// segment-bounded '*' cannot always absorb the characters an earlier
// wildcard should have taken (e.g. "*?*c" against "ab:c").
//
// An unterminated '[' never matches.
func strictGlobMatch(pattern, s string) bool {
	width := len(s) + 1
	// failed[pi*width+si] is set once (pi, si) is proven not to match.
	failed := make([]bool, (len(pattern)+1)*width)

	var match func(pi, si int) bool
	match = func(pi, si int) bool {
		if pi == len(pattern) {
			return si == len(s)
		}
		memo := pi*width + si
		if failed[memo] {
			return false
		}

		matched := false
		switch pc := pattern[pi]; pc {
		case '*':
			// Collapse the whole run of stars; two or more make it deep.
			next := pi + 1
			for next < len(pattern) && pattern[next] == '*' {
				next++
			}
			deep := next-pi >= 2
			// Either the wildcard ends here, or it swallows one more byte
			// (never a ':' unless deep) and we retry from the same token.
			matched = match(next, si) ||
				(si < len(s) && (deep || s[si] != ':') && match(pi, si+1))
		case '?':
			matched = si < len(s) && match(pi+1, si+1)
		case '[':
			end := strings.IndexByte(pattern[pi+1:], ']')
			if end >= 0 && si < len(s) && matchClass(pattern[pi+1:pi+1+end], s[si]) {
				matched = match(pi+end+2, si+1)
			}
		default:
			matched = si < len(s) && pc == s[si] && match(pi+1, si+1)
		}

		if !matched {
			failed[memo] = true
		}
		return matched
	}
	return match(0, 0)
}

// matchClass reports whether c matches the SQLite-style character class body
// (the text between '[' and ']').
//
// A leading '!' or '^' negates the class. "x-y" denotes the inclusive byte
// range x..y; a '-' that is first or last in the body is a literal '-'.
// An empty body matches nothing; a body consisting only of the negation
// marker matches every byte.
func matchClass(class string, c byte) bool {
	if class == "" {
		return false
	}
	negate := false
	if class[0] == '!' || class[0] == '^' {
		negate = true
		class = class[1:]
	}
	for i := 0; i < len(class); i++ {
		lo := class[i]
		if i+2 < len(class) && class[i+1] == '-' {
			if hi := class[i+2]; lo <= c && c <= hi {
				return !negate
			}
			i += 2 // skip the '-' and the range's upper bound
			continue
		}
		if lo == c {
			return !negate
		}
	}
	return negate
}

View File

@ -0,0 +1,162 @@
package triggers
import (
"strings"
"testing"
)
// TestValidatePattern_empty_returns_error checks that the empty pattern is
// rejected.
func TestValidatePattern_empty_returns_error(t *testing.T) {
	err := ValidatePattern("")
	if err != nil {
		return
	}
	t.Error("expected error for empty pattern")
}
// TestValidatePattern_too_long_returns_error checks that a pattern one byte
// over MaxPatternLength is rejected.
func TestValidatePattern_too_long_returns_error(t *testing.T) {
	oversized := strings.Repeat("a", MaxPatternLength+1)
	err := ValidatePattern(oversized)
	if err != nil {
		return
	}
	t.Error("expected error for over-long pattern")
}
// TestValidatePattern_unbalanced_brackets_returns_error checks that patterns
// with a stray '[' or ']' are rejected.
func TestValidatePattern_unbalanced_brackets_returns_error(t *testing.T) {
	for _, pattern := range []string{"a[b", "a]b", "[a[b]", "a]"} {
		if ValidatePattern(pattern) != nil {
			continue
		}
		t.Errorf("expected error for %q", pattern)
	}
}
// TestValidatePattern_valid_patterns_no_error checks that representative
// well-formed patterns are accepted.
func TestValidatePattern_valid_patterns_no_error(t *testing.T) {
	for _, pattern := range []string{"foo", "foo:*", "foo:**", "*.bar", "[abc]xyz", "[!a]b", "?abc"} {
		err := ValidatePattern(pattern)
		if err == nil {
			continue
		}
		t.Errorf("expected no error for %q, got: %v", pattern, err)
	}
}
// TestIsWildcard checks metacharacter detection for plain and wildcard
// patterns.
func TestIsWildcard(t *testing.T) {
	expectations := map[string]bool{
		"foo":         false,
		"foo:bar":     false,
		"foo:*":       true,
		"foo?bar":     true,
		"[abc]xyz":    true,
		"foo:**":      true,
		"a:b:c:d:e:f": false,
	}
	for pattern, want := range expectations {
		got := IsWildcard(pattern)
		if got == want {
			continue
		}
		t.Errorf("IsWildcard(%q) = %v, want %v", pattern, got, want)
	}
}
// TestPatternMatches_exact checks patterns without metacharacters: they
// match only the identical topic.
func TestPatternMatches_exact(t *testing.T) {
	for _, tc := range []struct {
		pattern, topic string
		want           bool
	}{
		{"foo", "foo", true},
		{"foo", "bar", false},
		{"foo:bar", "foo:bar", true},
		{"foo:bar", "foo:baz", false},
	} {
		got := PatternMatches(tc.pattern, tc.topic)
		if got == tc.want {
			continue
		}
		t.Errorf("PatternMatches(%q, %q) = %v, want %v", tc.pattern, tc.topic, got, tc.want)
	}
}
// TestPatternMatches_single_star_segment_bounded checks that a single '*'
// matches within one ':'-separated segment and never across a ':'.
func TestPatternMatches_single_star_segment_bounded(t *testing.T) {
	for _, tc := range []struct {
		pattern, topic string
		want           bool
	}{
		// Within one segment (including the empty segment).
		{"presence:*", "presence:user-1", true},
		{"presence:*", "presence:user-2", true},
		{"presence:*", "presence:", true},
		// Never across a ':' separator.
		{"presence:*", "presence:user:device", false},
		{"a:*:b", "a:x:b", true},
		{"a:*:b", "a:x:y:b", false},
		// Literal prefix must still match.
		{"presence:*", "calls:invite", false},
	} {
		got := PatternMatches(tc.pattern, tc.topic)
		if got == tc.want {
			continue
		}
		t.Errorf("PatternMatches(%q, %q) = %v, want %v", tc.pattern, tc.topic, got, tc.want)
	}
}
// TestPatternMatches_double_star_crosses_segments checks that '**' matches
// across ':' separators.
func TestPatternMatches_double_star_crosses_segments(t *testing.T) {
	for _, tc := range []struct {
		pattern, topic string
		want           bool
	}{
		{"notify:**", "notify:user-1", true},
		{"notify:**", "notify:user:device:1", true},
		{"**", "anything:goes:here", true},
	} {
		got := PatternMatches(tc.pattern, tc.topic)
		if got == tc.want {
			continue
		}
		t.Errorf("PatternMatches(%q, %q) = %v, want %v", tc.pattern, tc.topic, got, tc.want)
	}
}
// TestPatternMatches_question_mark checks that '?' consumes exactly one
// character — neither zero nor two.
func TestPatternMatches_question_mark(t *testing.T) {
	for _, tc := range []struct {
		pattern, topic string
		want           bool
	}{
		{"a?c", "abc", true},
		{"a?c", "axc", true},
		{"a?c", "ac", false},
		{"a?c", "abbc", false},
	} {
		got := PatternMatches(tc.pattern, tc.topic)
		if got == tc.want {
			continue
		}
		t.Errorf("PatternMatches(%q, %q) = %v, want %v", tc.pattern, tc.topic, got, tc.want)
	}
}
// TestPatternMatches_character_class checks plain ('[abc]') and negated
// ('[!a]') character classes.
func TestPatternMatches_character_class(t *testing.T) {
	for _, tc := range []struct {
		pattern, topic string
		want           bool
	}{
		{"[abc]xyz", "axyz", true},
		{"[abc]xyz", "bxyz", true},
		{"[abc]xyz", "dxyz", false},
		{"[!a]bc", "xbc", true},
		{"[!a]bc", "abc", false},
	} {
		got := PatternMatches(tc.pattern, tc.topic)
		if got == tc.want {
			continue
		}
		t.Errorf("PatternMatches(%q, %q) = %v, want %v", tc.pattern, tc.topic, got, tc.want)
	}
}
// TestPatternMatches_trailing_star_with_remaining_chars checks a trailing
// '*': it may match nothing or the rest of the segment, but not a ':'.
func TestPatternMatches_trailing_star_with_remaining_chars(t *testing.T) {
	for _, tc := range []struct {
		pattern, topic string
		want           bool
	}{
		{"foo*", "foo", true},            // '*' matches zero characters
		{"foo*", "foobar", true},         // '*' matches the rest of the segment
		{"foo*", "foobar:baz", false},    // ':' breaks single '*'
	} {
		got := PatternMatches(tc.pattern, tc.topic)
		if got == tc.want {
			continue
		}
		t.Errorf("PatternMatches(%q, %q) = %v, want %v", tc.pattern, tc.topic, got, tc.want)
	}
}

View File

@ -16,30 +16,43 @@ import (
// TriggerMatch contains the fields needed to dispatch a trigger invocation.
// It's the result of JOINing function_pubsub_triggers with functions.
//
// Topic is the *resolved* topic that the published message was sent to,
// not the pattern stored in the trigger. This lets aggregating functions
// see which concrete topic each event came from.
//
// AggregationWindowMs > 0 indicates the dispatcher should buffer events
// instead of invoking the function per event.
type TriggerMatch struct {
TriggerID string
FunctionID string
FunctionName string
Namespace string
Topic string
TriggerID string
FunctionID string
FunctionName string
Namespace string
Topic string
AggregationWindowMs int
AggregationMaxBatchSize int
}
// triggerRow maps to the function_pubsub_triggers table for query scanning.
type triggerRow struct {
ID string
FunctionID string
Topic string
Enabled bool
CreatedAt time.Time
ID string
FunctionID string
TopicPattern string
Enabled bool
CreatedAt time.Time
AggregationWindowMs int
AggregationMaxBatchSize int
}
// triggerMatchRow maps to the JOIN query result for scanning.
type triggerMatchRow struct {
TriggerID string
FunctionID string
FunctionName string
Namespace string
Topic string
TriggerID string
FunctionID string
FunctionName string
Namespace string
TopicPattern string
AggregationWindowMs int
AggregationMaxBatchSize int
}
// PubSubTriggerStore manages PubSub trigger persistence in RQLite.
@ -57,30 +70,60 @@ func NewPubSubTriggerStore(db rqlite.Client, logger *zap.Logger) *PubSubTriggerS
}
// Add registers a new PubSub trigger for a function.
// `topicPattern` may be an exact topic or a SQLite GLOB pattern (e.g. "presence:*").
// Returns the trigger ID.
func (s *PubSubTriggerStore) Add(ctx context.Context, functionID, topic string) (string, error) {
//
// For backward compatibility, aggregation defaults to disabled (windowMs=0).
// Use AddWithAggregation to opt in.
func (s *PubSubTriggerStore) Add(ctx context.Context, functionID, topicPattern string) (string, error) {
return s.AddWithAggregation(ctx, functionID, topicPattern, 0, 0)
}
// AddWithAggregation registers a trigger with optional aggregation.
// - aggregationWindowMs = 0 disables aggregation (per-event invocation, default).
// - aggregationMaxBatchSize = 0 uses the default (100) when aggregation is enabled.
func (s *PubSubTriggerStore) AddWithAggregation(
ctx context.Context,
functionID, topicPattern string,
aggregationWindowMs, aggregationMaxBatchSize int,
) (string, error) {
if functionID == "" {
return "", fmt.Errorf("function ID required")
}
if topic == "" {
return "", fmt.Errorf("topic required")
if err := ValidatePattern(topicPattern); err != nil {
return "", fmt.Errorf("invalid topic pattern: %w", err)
}
if aggregationWindowMs < 0 || aggregationWindowMs > 60_000 {
return "", fmt.Errorf("aggregation_window_ms must be between 0 and 60000")
}
if aggregationMaxBatchSize < 0 || aggregationMaxBatchSize > 1000 {
return "", fmt.Errorf("aggregation_max_batch_size must be between 0 and 1000")
}
if aggregationWindowMs > 0 && aggregationMaxBatchSize == 0 {
aggregationMaxBatchSize = 100
}
id := uuid.New().String()
now := time.Now()
// Write both `topic` (legacy) and `topic_pattern` (new). Keeping `topic`
// populated lets old binaries running concurrently during a rolling
// upgrade continue reading triggers. A future migration drops `topic`.
query := `
INSERT INTO function_pubsub_triggers (id, function_id, topic, enabled, created_at)
VALUES (?, ?, ?, TRUE, ?)
INSERT INTO function_pubsub_triggers (id, function_id, topic, topic_pattern, enabled, created_at, aggregation_window_ms, aggregation_max_batch_size)
VALUES (?, ?, ?, ?, TRUE, ?, ?, ?)
`
if _, err := s.db.Exec(ctx, query, id, functionID, topic, now); err != nil {
if _, err := s.db.Exec(ctx, query, id, functionID, topicPattern, topicPattern, now, aggregationWindowMs, aggregationMaxBatchSize); err != nil {
return "", fmt.Errorf("failed to add pubsub trigger: %w", err)
}
s.logger.Info("PubSub trigger added",
zap.String("trigger_id", id),
zap.String("function_id", functionID),
zap.String("topic", topic),
zap.String("topic_pattern", topicPattern),
zap.Bool("wildcard", IsWildcard(topicPattern)),
zap.Int("aggregation_window_ms", aggregationWindowMs),
zap.Int("aggregation_max_batch_size", aggregationMaxBatchSize),
)
return id, nil
@ -129,7 +172,7 @@ func (s *PubSubTriggerStore) ListByFunction(ctx context.Context, functionID stri
}
query := `
SELECT id, function_id, topic, enabled, created_at
SELECT id, function_id, topic_pattern, enabled, created_at, aggregation_window_ms, aggregation_max_batch_size
FROM function_pubsub_triggers
WHERE function_id = ?
`
@ -142,18 +185,22 @@ func (s *PubSubTriggerStore) ListByFunction(ctx context.Context, functionID stri
triggers := make([]serverless.PubSubTrigger, len(rows))
for i, row := range rows {
triggers[i] = serverless.PubSubTrigger{
ID: row.ID,
FunctionID: row.FunctionID,
Topic: row.Topic,
Enabled: row.Enabled,
ID: row.ID,
FunctionID: row.FunctionID,
Topic: row.TopicPattern,
Enabled: row.Enabled,
AggregationWindowMs: row.AggregationWindowMs,
AggregationMaxBatchSize: row.AggregationMaxBatchSize,
}
}
return triggers, nil
}
// GetByTopicAndNamespace returns all enabled triggers for a topic within a namespace.
// Only returns triggers for active functions.
// GetByTopicAndNamespace returns all enabled triggers whose topic_pattern
// matches `topic` within the namespace. Patterns are SQLite GLOB; the
// post-filter enforces stricter segment-aware semantics.
// Only triggers for active functions are returned.
func (s *PubSubTriggerStore) GetByTopicAndNamespace(ctx context.Context, topic, namespace string) ([]TriggerMatch, error) {
if topic == "" || namespace == "" {
return nil, nil
@ -161,10 +208,12 @@ func (s *PubSubTriggerStore) GetByTopicAndNamespace(ctx context.Context, topic,
query := `
SELECT t.id AS trigger_id, t.function_id AS function_id,
f.name AS function_name, f.namespace AS namespace, t.topic AS topic
f.name AS function_name, f.namespace AS namespace, t.topic_pattern AS topic_pattern,
t.aggregation_window_ms AS aggregation_window_ms,
t.aggregation_max_batch_size AS aggregation_max_batch_size
FROM function_pubsub_triggers t
JOIN functions f ON t.function_id = f.id
WHERE t.topic = ? AND f.namespace = ? AND t.enabled = TRUE AND f.status = 'active'
WHERE ? GLOB t.topic_pattern AND f.namespace = ? AND t.enabled = TRUE AND f.status = 'active'
`
var rows []triggerMatchRow
@ -172,15 +221,21 @@ func (s *PubSubTriggerStore) GetByTopicAndNamespace(ctx context.Context, topic,
return nil, fmt.Errorf("failed to query triggers for topic: %w", err)
}
matches := make([]TriggerMatch, len(rows))
for i, row := range rows {
matches[i] = TriggerMatch{
TriggerID: row.TriggerID,
FunctionID: row.FunctionID,
FunctionName: row.FunctionName,
Namespace: row.Namespace,
Topic: row.Topic,
matches := make([]TriggerMatch, 0, len(rows))
for _, row := range rows {
// Post-filter to enforce strict segment boundaries on '*'.
if !PatternMatches(row.TopicPattern, topic) {
continue
}
matches = append(matches, TriggerMatch{
TriggerID: row.TriggerID,
FunctionID: row.FunctionID,
FunctionName: row.FunctionName,
Namespace: row.Namespace,
Topic: topic, // resolved topic, not the pattern
AggregationWindowMs: row.AggregationWindowMs,
AggregationMaxBatchSize: row.AggregationMaxBatchSize,
})
}
return matches, nil

View File

@ -40,13 +40,6 @@ func TestDispatcher_DepthLimit(t *testing.T) {
d.Dispatch(context.Background(), "ns", "topic", []byte("data"), maxTriggerDepth+1)
}
func TestCacheKey(t *testing.T) {
key := cacheKey("my-namespace", "my-topic")
if key != "triggers:my-namespace:my-topic" {
t.Errorf("unexpected cache key: %s", key)
}
}
func TestPubSubEvent_Marshal(t *testing.T) {
event := PubSubEvent{
Topic: "chat",

View File

@ -288,11 +288,21 @@ type DBTrigger struct {
}
// PubSubTrigger represents a pubsub trigger.
//
// Topic may be an exact topic name or a SQLite GLOB pattern (e.g.
// "presence:*"). See pkg/serverless/triggers/pattern.go for matching rules.
//
// AggregationWindowMs > 0 enables event buffering: the dispatcher accumulates
// events for at most that many milliseconds (or until AggregationMaxBatchSize
// events have been collected, whichever comes first), then invokes the
// function once with a batched payload of type BatchedPubSubEvent.
type PubSubTrigger struct {
ID string `json:"id"`
FunctionID string `json:"function_id"`
Topic string `json:"topic"`
Enabled bool `json:"enabled"`
ID string `json:"id"`
FunctionID string `json:"function_id"`
Topic string `json:"topic"`
Enabled bool `json:"enabled"`
AggregationWindowMs int `json:"aggregation_window_ms,omitempty"`
AggregationMaxBatchSize int `json:"aggregation_max_batch_size,omitempty"`
}
// Timer represents a one-time scheduled execution.
@ -337,6 +347,14 @@ type HostServices interface {
// PubSub operations
PubSubPublish(ctx context.Context, topic string, data []byte) error
PubSubPublishBatch(ctx context.Context, msgsJSON []byte) error
// Push notifications. Sends to all of `userID`'s registered devices in
// the function's namespace. `msgJSON` is the JSON-encoded PushSendArgs
// shape (see hostfunctions.PushSend). Returns nil if push is not
// configured (silent no-op) so functions can be portable across
// namespaces with/without push enabled.
PushSend(ctx context.Context, userID string, msgJSON []byte) error
// WebSocket operations (only valid in WS context)
WSSend(ctx context.Context, clientID string, data []byte) error

106
core/pkg/sniproxy/router.go Normal file
View File

@ -0,0 +1,106 @@
package sniproxy
import (
"strings"
"sync"
)
// Backend describes where to forward a connection: a dial network plus a
// dial address, with an optional human-readable name.
type Backend struct {
// Name is for logs/metrics only. Optional.
Name string
// Network is the dial network ("tcp", "tcp4", "tcp6"). Default "tcp".
// NOTE(review): the default is applied by whoever dials, not by this
// type — confirm in the proxy's dial path.
Network string
// Addr is the dial target ("127.0.0.1:5349").
Addr string
}
// Route maps an SNI value (or wildcard pattern) to a Backend.
//
// Match semantics:
// - "example.com" matches exactly "example.com"
// - "*.example.com" matches any single-label subdomain ("a.example.com"
// but not "a.b.example.com" — single-label like DNS wildcards)
//
// Match is compared case-insensitively. When several routes could match the
// same SNI, the one that appears first in the routing table wins.
type Route struct {
Match string
Backend Backend
}
// Router holds a routing table that can be swapped as a whole while
// concurrent lookups are in flight. All access to the table and the
// fallback is guarded by mu: lookups take the read lock, Replace takes the
// write lock, so a lookup never observes a half-updated table.
type Router struct {
mu sync.RWMutex // guards routes and fallback
routes []Route // evaluated in order; first match wins
fallback Backend // returned when no route matches
}
// NewRouter returns a Router whose routing table is empty, so every lookup
// resolves to fallback until Replace installs routes.
func NewRouter(fallback Backend) *Router {
	r := new(Router)
	r.fallback = fallback
	return r
}
// Pick resolves an SNI value to a backend. The SNI is lowercased before
// matching; routes are tried in table order and the first match wins. An
// empty SNI, or one that matches no route, yields the fallback backend.
func (r *Router) Pick(sni string) Backend {
	r.mu.RLock()
	defer r.mu.RUnlock()

	if sni == "" {
		return r.fallback
	}

	name := strings.ToLower(sni)
	for i := range r.routes {
		if candidate := r.routes[i]; matchSNI(candidate.Match, name) {
			return candidate.Backend
		}
	}
	return r.fallback
}
// Replace atomically swaps the routing table. The new routes replace the
// old ones in their entirety; partial updates are not supported.
//
// The routes slice is copied before it is installed, so the caller remains
// free to reuse or mutate its slice afterwards without affecting (or racing
// with) concurrent lookups.
func (r *Router) Replace(routes []Route, fallback Backend) {
	// Copy outside the critical section to keep the write lock short.
	// Appending to a nil slice preserves nil for an empty input.
	snapshot := append([]Route(nil), routes...)

	r.mu.Lock()
	defer r.mu.Unlock()
	r.routes = snapshot
	r.fallback = fallback
}
// Routes reports the current routing table for introspection. The result
// is a fresh slice: modifying it does not affect the router.
func (r *Router) Routes() []Route {
	r.mu.RLock()
	defer r.mu.RUnlock()

	snapshot := make([]Route, 0, len(r.routes))
	return append(snapshot, r.routes...)
}
// Fallback reports the backend currently used when no route matches.
func (r *Router) Fallback() Backend {
	r.mu.RLock()
	current := r.fallback
	r.mu.RUnlock()
	return current
}
// matchSNI reports whether sni satisfies pattern under the Route matching
// rules. The pattern is lowercased here; sni is expected to be lowercase
// already. A pattern of the form "*.suffix" matches exactly one non-empty,
// dot-free label in front of ".suffix"; any other pattern must equal sni.
func matchSNI(pattern, sni string) bool {
	want := strings.ToLower(pattern)
	if want == sni {
		return true
	}
	if !strings.HasPrefix(want, "*.") {
		return false
	}

	// Drop the "*" but keep the dot: "*.example.com" -> ".example.com".
	tail := want[1:]
	if len(sni) <= len(tail) || !strings.HasSuffix(sni, tail) {
		// Either the suffix is absent or nothing is left for the label.
		return false
	}

	// The wildcard stands for a single label, so it may not contain dots.
	label := sni[:len(sni)-len(tail)]
	return !strings.Contains(label, ".")
}

View File

@ -0,0 +1,113 @@
package sniproxy
import (
"sync"
"testing"
)
// TestRouter_pick_exact_match checks that an SNI equal to a route's Match
// resolves to that route's backend rather than the fallback.
func TestRouter_pick_exact_match(t *testing.T) {
	fallback := Backend{Name: "fallback", Addr: "127.0.0.1:9000"}
	turn := Backend{Name: "turn", Addr: "127.0.0.1:5349"}

	router := NewRouter(fallback)
	router.Replace([]Route{{Match: "turn.example.com", Backend: turn}}, fallback)

	if picked := router.Pick("turn.example.com"); picked.Addr != "127.0.0.1:5349" {
		t.Errorf("expected turn backend, got %+v", picked)
	}
}
// TestRouter_pick_unmatched_returns_fallback checks that both an SNI with no
// matching route and an empty SNI resolve to the fallback backend.
func TestRouter_pick_unmatched_returns_fallback(t *testing.T) {
	caddy := Backend{Name: "caddy", Addr: "127.0.0.1:8443"}
	router := NewRouter(caddy)
	router.Replace([]Route{
		{Match: "turn.example.com", Backend: Backend{Addr: "127.0.0.1:5349"}},
	}, caddy)

	picked := router.Pick("api.example.com")
	if picked != caddy {
		t.Errorf("expected fallback, got %+v", picked)
	}

	picked = router.Pick("")
	if picked != caddy {
		t.Errorf("expected fallback for empty SNI, got %+v", picked)
	}
}
// TestRouter_pick_case_insensitive checks that a mixed-case route pattern
// still matches a lowercase SNI.
func TestRouter_pick_case_insensitive(t *testing.T) {
	fallback := Backend{Addr: "127.0.0.1:8443"}
	target := Backend{Addr: "127.0.0.1:5349"}

	router := NewRouter(fallback)
	router.Replace([]Route{{Match: "Turn.Example.Com", Backend: target}}, fallback)

	picked := router.Pick("turn.example.com")
	if picked.Addr != "127.0.0.1:5349" {
		t.Errorf("expected case-insensitive match, got %+v", picked)
	}
}
// TestRouter_pick_wildcard_subdomain checks the "*.example.com" semantics:
// exactly one extra label matches; nested labels, the bare domain and
// unrelated domains fall through to the fallback.
func TestRouter_pick_wildcard_subdomain(t *testing.T) {
	fallback := Backend{Addr: "127.0.0.1:8443"}
	wild := Backend{Name: "wild", Addr: "127.0.0.1:5349"}

	router := NewRouter(fallback)
	router.Replace([]Route{{Match: "*.example.com", Backend: wild}}, fallback)

	tests := []struct {
		sni     string
		matches bool
	}{
		{sni: "a.example.com", matches: true},
		{sni: "foo.example.com", matches: true},
		{sni: "a.b.example.com", matches: false}, // multi-label not allowed
		{sni: "example.com", matches: false},     // bare domain doesn't match *.example.com
		{sni: "other.com", matches: false},
	}
	for _, tc := range tests {
		if matched := router.Pick(tc.sni) == wild; matched != tc.matches {
			t.Errorf("Pick(%q): want match=%v, got match=%v", tc.sni, tc.matches, matched)
		}
	}
}
// TestRouter_replace_atomic hammers Pick from several reader goroutines
// while the test goroutine repeatedly calls Replace. It has no assertions of
// its own; it exists to be run with -race so that any unsynchronized access
// to the routing table is reported.
func TestRouter_replace_atomic(t *testing.T) {
	fallback := Backend{Addr: "fb"}
	router := NewRouter(fallback)
	router.Replace([]Route{{Match: "a.com", Backend: Backend{Addr: "1"}}}, fallback)

	quit := make(chan struct{})
	var readers sync.WaitGroup

	// Readers: spin on Pick until told to stop.
	readers.Add(4)
	for n := 0; n < 4; n++ {
		go func() {
			defer readers.Done()
			for {
				select {
				case <-quit:
					return
				default:
				}
				_ = router.Pick("a.com")
			}
		}()
	}

	// Writer: swap the table many times while the readers are running.
	for n := 0; n < 200; n++ {
		router.Replace([]Route{{Match: "a.com", Backend: Backend{Addr: "x"}}}, fallback)
	}

	close(quit)
	readers.Wait()
}
// TestRouter_routes_returns_copy checks that mutating the slice returned by
// Routes does not leak into the router's own table.
func TestRouter_routes_returns_copy(t *testing.T) {
	router := NewRouter(Backend{})
	installed := []Route{{Match: "a", Backend: Backend{Addr: "1"}}}
	router.Replace(installed, Backend{})

	snapshot := router.Routes()
	snapshot[0].Match = "mutated"

	if current := router.Routes(); current[0].Match != "a" {
		t.Error("Routes() should return a defensive copy")
	}
}

177
core/pkg/sniproxy/server.go Normal file
View File

@ -0,0 +1,177 @@
package sniproxy
import (
"context"
"errors"
"io"
"net"
"sync"
"time"
"go.uber.org/zap"
)
// Config tunes the proxy server.
//
// Every field is optional: NewServer replaces any zero or negative value
// with the default noted on the field.
type Config struct {
// ClientHelloTimeout bounds the wait for a parseable ClientHello.
// Values <= 0 select 5 seconds.
ClientHelloTimeout time.Duration
// BackendDialTimeout bounds backend connect time. Values <= 0 select
// 5 seconds.
BackendDialTimeout time.Duration
// MaxConcurrentConns caps total in-flight connections to prevent
// resource exhaustion. Values <= 0 select 10000.
MaxConcurrentConns int
}
// Server is a TCP-level SNI router. Create via NewServer, then call
// Serve(listener) in a goroutine. Close cancels in-flight connections.
type Server struct {
router *Router // consulted once per connection to choose a backend
cfg Config // configuration with defaults already applied by NewServer
logger *zap.Logger // never nil; NewServer substitutes a no-op logger
gate chan struct{} // bounded semaphore for concurrent connections
wg sync.WaitGroup // counts running connection handlers; Close waits on it
ctx context.Context // cancelled by Close to signal shutdown
cancel context.CancelFunc // cancels ctx
}
// NewServer builds a Server around router. Non-positive fields in cfg are
// replaced by defaults (5s ClientHello timeout, 5s backend dial timeout,
// 10000 concurrent connections), and a nil logger is replaced by a no-op
// logger. The returned server is idle until Serve is called.
func NewServer(router *Router, cfg Config, logger *zap.Logger) *Server {
	const (
		defaultTimeout  = 5 * time.Second
		defaultMaxConns = 10000
	)

	if cfg.ClientHelloTimeout <= 0 {
		cfg.ClientHelloTimeout = defaultTimeout
	}
	if cfg.BackendDialTimeout <= 0 {
		cfg.BackendDialTimeout = defaultTimeout
	}
	if cfg.MaxConcurrentConns <= 0 {
		cfg.MaxConcurrentConns = defaultMaxConns
	}
	if logger == nil {
		logger = zap.NewNop()
	}

	srv := &Server{
		router: router,
		cfg:    cfg,
		logger: logger.Named("sniproxy"),
		gate:   make(chan struct{}, cfg.MaxConcurrentConns),
	}
	srv.ctx, srv.cancel = context.WithCancel(context.Background())
	return srv
}
// Serve accepts connections from ln until ln.Accept returns a permanent
// error or Close is called. Serve always returns a non-nil error; after
// Close it returns the server context's error (context.Canceled).
//
// Serve closes ln when the server is shut down. Without that, a goroutine
// blocked in Accept would never notice the shutdown and would keep admitting
// new connections after Close had returned. Closing ln again from the caller
// is harmless.
//
// Each accepted connection is handled on its own goroutine. When
// MaxConcurrentConns handlers are already running, further connections are
// closed immediately instead of being queued.
func (s *Server) Serve(ln net.Listener) error {
	// Watcher: wake up the Accept below on shutdown. It exits as soon as
	// Serve returns, so it never outlives this call.
	served := make(chan struct{})
	defer close(served)
	go func() {
		select {
		case <-s.ctx.Done():
			_ = ln.Close() // only the side effect of unblocking Accept matters
		case <-served:
		}
	}()

	for {
		conn, err := ln.Accept()
		if err != nil {
			// Check for shutdown via cancelled ctx.
			if s.ctx.Err() != nil {
				return s.ctx.Err()
			}
			// Timeouts are transient: back off briefly so we don't busy-loop.
			var ne net.Error
			if errors.As(err, &ne) && ne.Timeout() {
				time.Sleep(50 * time.Millisecond)
				continue
			}
			return err
		}

		// A connection can still be handed out in the window between the
		// context being cancelled and the listener being closed. Don't start
		// proxying it.
		if s.ctx.Err() != nil {
			conn.Close()
			return s.ctx.Err()
		}

		select {
		case s.gate <- struct{}{}:
		default:
			s.logger.Warn("max concurrent connections reached, dropping",
				zap.Int("limit", s.cfg.MaxConcurrentConns),
				zap.String("remote", conn.RemoteAddr().String()),
			)
			conn.Close()
			continue
		}

		s.wg.Add(1)
		go func(c net.Conn) {
			defer s.wg.Done()
			defer func() { <-s.gate }() // release the concurrency slot
			s.handle(c)
		}(conn)
	}
}
// Close cancels in-flight connections and waits for handlers to drain.
//
// It cancels the server context first, which is the shutdown signal observed
// by Serve and by the connection handlers, and then blocks until every
// handler goroutine registered on the wait group has returned.
func (s *Server) Close() {
s.cancel()
s.wg.Wait()
}
// handle processes a single accepted connection: peek SNI, dial backend,
// replay peeked bytes, then bidirectional copy.
//
// The connection is dropped (closed without forwarding anything) when the
// ClientHello cannot be parsed, when the router yields a backend with an
// empty Addr, or when the backend cannot be reached. handle always closes
// conn before returning.
//
// The backend dial is bound to the server context as well as to
// BackendDialTimeout, so a shutdown aborts a pending dial instead of waiting
// for the timeout to expire.
func (s *Server) handle(conn net.Conn) {
	defer conn.Close()

	sni, peeked, err := PeekClientHello(conn, s.cfg.ClientHelloTimeout)
	if err != nil {
		s.logger.Debug("ClientHello peek failed",
			zap.String("remote", conn.RemoteAddr().String()),
			zap.Error(err),
		)
		return
	}

	backend := s.router.Pick(sni)
	if backend.Addr == "" {
		s.logger.Warn("no backend for SNI",
			zap.String("sni", sni),
			zap.String("remote", conn.RemoteAddr().String()),
		)
		return
	}

	network := backend.Network
	if network == "" {
		network = "tcp"
	}

	dialer := net.Dialer{Timeout: s.cfg.BackendDialTimeout}
	upstream, err := dialer.DialContext(s.ctx, network, backend.Addr)
	if err != nil {
		if s.ctx.Err() != nil {
			// The dial was interrupted by shutdown; that says nothing about
			// the backend's health, so don't report it as a failure.
			return
		}
		s.logger.Warn("backend dial failed",
			zap.String("sni", sni),
			zap.String("backend", backend.Addr),
			zap.Error(err),
		)
		return
	}
	defer upstream.Close()

	// Replay peeked bytes (the ClientHello + anything else buffered).
	if len(peeked) > 0 {
		if _, err := upstream.Write(peeked); err != nil {
			s.logger.Debug("replay to backend failed",
				zap.String("sni", sni),
				zap.Error(err),
			)
			return
		}
	}

	// Bidirectional copy. We close both connections when either side
	// finishes OR when the server is shutting down, so handle() can't
	// hang forever on a half-stuck peer. The channel is buffered for both
	// senders so neither copy goroutine can block on its final send.
	done := make(chan struct{}, 2)
	go func() {
		_, _ = io.Copy(upstream, conn)
		done <- struct{}{}
	}()
	go func() {
		_, _ = io.Copy(conn, upstream)
		done <- struct{}{}
	}()

	select {
	case <-done:
	case <-s.ctx.Done():
	}

	// Force both sides closed so any copy still running fails out of its
	// blocked Read/Write.
	upstream.Close()
	conn.Close()
	<-done // wait for at least one copier to exit; the other exits on the closed conns
}

View File

@ -0,0 +1,143 @@
package sniproxy
import (
"bufio"
"crypto/tls"
"errors"
"io"
"net"
"testing"
"time"
"go.uber.org/zap"
)
// startEchoBackend starts a loopback TCP server that stands in for a real
// backend. For every accepted connection it performs a single read of up to
// 1024 bytes (bounded by a 2-second deadline), publishes a copy of whatever
// it received on the returned channel, and closes the connection. Nothing is
// written back to the client. The accept loop ends when the returned
// listener is closed.
func startEchoBackend(t *testing.T) (net.Listener, <-chan []byte) {
	t.Helper()

	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}

	received := make(chan []byte, 4)
	record := func(c net.Conn) {
		defer c.Close()
		_ = c.SetReadDeadline(time.Now().Add(2 * time.Second))
		scratch := make([]byte, 1024)
		n, _ := c.Read(scratch)
		received <- append([]byte(nil), scratch[:n]...)
	}

	go func() {
		for {
			accepted, acceptErr := listener.Accept()
			if acceptErr != nil {
				return
			}
			go record(accepted)
		}
	}()

	return listener, received
}
// TestServer_routes_TLS_to_correct_backend runs the proxy in front of two
// recording backends and checks that a ClientHello for a routed SNI reaches
// the routed backend while a ClientHello for any other SNI reaches the
// fallback.
func TestServer_routes_TLS_to_correct_backend(t *testing.T) {
	turnLn, turnGot := startEchoBackend(t)
	defer turnLn.Close()
	caddyLn, caddyGot := startEchoBackend(t)
	defer caddyLn.Close()

	fallback := Backend{Network: "tcp", Addr: caddyLn.Addr().String()}
	turn := Backend{Network: "tcp", Addr: turnLn.Addr().String()}
	router := NewRouter(fallback)
	router.Replace([]Route{{Match: "turn.example.com", Backend: turn}}, fallback)

	srv := NewServer(router, Config{}, zap.NewNop())
	defer srv.Close()

	frontLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}
	defer frontLn.Close()
	go func() { _ = srv.Serve(frontLn) }()

	front := frontLn.Addr().String()
	dialAndStartTLS(t, front, "turn.example.com")  // routed to turnLn
	dialAndStartTLS(t, front, "other.example.com") // falls through to caddyLn

	// expectBytes waits for one delivery from a backend and fails the test
	// if it is empty or does not arrive within three seconds.
	expectBytes := func(seen <-chan []byte, emptyMsg, timeoutMsg string) {
		select {
		case data := <-seen:
			if len(data) == 0 {
				t.Error(emptyMsg)
			}
		case <-time.After(3 * time.Second):
			t.Error(timeoutMsg)
		}
	}

	expectBytes(turnGot, "turn backend received empty bytes", "turn backend did not receive bytes")
	expectBytes(caddyGot, "caddy fallback received empty bytes", "caddy fallback did not receive bytes")
}
// dialAndStartTLS connects to addr and, on a background goroutine, begins a
// TLS handshake carrying the given SNI so that a ClientHello goes out on the
// wire. It returns as soon as the TCP connection is established; the
// handshake outcome is ignored because the callers only need the proxy to
// forward the bytes. The connection is closed when the handshake attempt
// ends (at most two seconds later).
func dialAndStartTLS(t *testing.T, addr, sni string) {
	t.Helper()

	raw, err := net.Dial("tcp", addr)
	if err != nil {
		t.Fatal(err)
	}

	clientCfg := &tls.Config{ServerName: sni, InsecureSkipVerify: true}
	go func() {
		defer raw.Close()
		_ = raw.SetDeadline(time.Now().Add(2 * time.Second))
		// Expected to fail: the backend behind the proxy isn't TLS.
		_ = tls.Client(raw, clientCfg).Handshake()
	}()
}
func TestServer_no_backend_drops_connection(t *testing.T) {
router := NewRouter(Backend{}) // empty fallback, empty Addr -> dropped
srv := NewServer(router, Config{}, zap.NewNop())
defer srv.Close()
frontLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer frontLn.Close()
go func() { _ = srv.Serve(frontLn) }()
conn, err := net.Dial("tcp", frontLn.Addr().String())
if err != nil {
t.Fatal(err)
}
defer conn.Close()
c := tls.Client(conn, &tls.Config{ServerName: "x.example.com", InsecureSkipVerify: true})
// Handshake should fail because connection is closed by proxy.
go func() { _ = c.Handshake() }()
// Reader should see EOF quickly.
_ = conn.SetDeadline(time.Now().Add(2 * time.Second))
br := bufio.NewReader(conn)
_, err = br.ReadByte()
if err == nil {
t.Error("expected connection drop")
}
if !errors.Is(err, io.EOF) {
// "use of closed network connection" is also fine.
t.Logf("acceptable read error: %v", err)
}
}

235
core/pkg/sniproxy/sni.go Normal file
View File

@ -0,0 +1,235 @@
// Package sniproxy provides a TCP-level Server Name Indication (SNI) router.
//
// The router peeks at the unencrypted TLS ClientHello on each accepted
// connection, extracts the SNI host name, and forwards the raw stream to
// a backend. It does NOT terminate TLS — encrypted bytes pass through
// verbatim. This lets one TCP port serve multiple TLS-speaking backends
// (HTTPS for the gateway, TURNS for stealth WebRTC, etc.) without
// sharing private keys with the proxy.
//
// Design goals:
// - Zero TLS material on the proxy
// - Bounded ClientHello read (no slowloris)
// - Backend dial timeout
// - Per-IP rate limiting
//
// SNI parsing follows RFC 5246 §7.4.1.2 (TLS record + ClientHello) and
// RFC 6066 §3 (server_name extension).
package sniproxy
import (
"bufio"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"strings"
"time"
)
// ErrNoSNI is returned when the ClientHello has no server_name extension,
// or when that extension carries no host_name entry. Compare with
// errors.Is.
var ErrNoSNI = errors.New("sniproxy: ClientHello has no SNI")
// MaxClientHelloBytes bounds how many bytes we'll read while looking for
// the SNI. TLS ClientHello records are typically 200–500 bytes; this is
// a generous cap that still defends against memory abuse. The cap covers
// the 5-byte record header plus the record body.
const MaxClientHelloBytes = 16 * 1024
// PeekClientHello reads the first TLS record from conn, which must be a
// handshake record containing a ClientHello no larger than
// MaxClientHelloBytes, and extracts the SNI host name from it.
//
// It returns the host name (lowercased) together with every byte that was
// read off conn in the process. Those bytes have been consumed from the
// connection, so the caller must replay them to the backend before relaying
// the rest of the stream. On error both the name and the byte slice are
// empty.
//
// A positive readTimeout is applied as a read deadline for the duration of
// the call and cleared again before returning, so a client that stalls
// mid-ClientHello fails fast instead of pinning the goroutine.
func PeekClientHello(conn net.Conn, readTimeout time.Duration) (string, []byte, error) {
	const (
		recordHeaderLen     = 5  // content_type (1) + version (2) + length (2)
		recordTypeHandshake = 22 // TLS content_type for handshake messages
	)

	if readTimeout > 0 {
		_ = conn.SetReadDeadline(time.Now().Add(readTimeout))
		defer conn.SetReadDeadline(time.Time{})
	}

	// The buffer is sized to the cap so that a whole maximum-size record can
	// be peeked without being consumed.
	buffered := bufio.NewReaderSize(conn, MaxClientHelloBytes)

	hdr, err := buffered.Peek(recordHeaderLen)
	if err != nil {
		return "", nil, fmt.Errorf("read tls record header: %w", err)
	}
	if contentType := hdr[0]; contentType != recordTypeHandshake {
		return "", nil, fmt.Errorf("not a TLS handshake record (type=%d)", contentType)
	}

	bodyLen := int(binary.BigEndian.Uint16(hdr[3:5]))
	if bodyLen < 1 || bodyLen > MaxClientHelloBytes-recordHeaderLen {
		return "", nil, fmt.Errorf("invalid record length %d", bodyLen)
	}

	record, err := buffered.Peek(recordHeaderLen + bodyLen)
	if err != nil {
		return "", nil, fmt.Errorf("read tls record body: %w", err)
	}

	name, err := parseSNI(record[recordHeaderLen:])
	if err != nil {
		return "", nil, err
	}

	// Peeking left everything in the buffer. Pull out all of it, including
	// any bytes that arrived after the record, so nothing is lost when the
	// caller switches to reading conn directly.
	replay := make([]byte, buffered.Buffered())
	if _, err := io.ReadFull(buffered, replay); err != nil {
		return "", nil, fmt.Errorf("drain peeked bytes: %w", err)
	}
	return name, replay, nil
}
// parseSNI walks a ClientHello handshake message (the TLS record payload,
// without the 5-byte record header) and returns the host name carried by its
// server_name extension.
//
// It returns ErrNoSNI when the message has no extensions block or no
// server_name extension, io.ErrUnexpectedEOF when the message ends in the
// middle of a field, and a descriptive error when the message is not a
// ClientHello or its extensions block claims more bytes than are present.
func parseSNI(body []byte) (string, error) {
	cur := newReader(body)

	// Handshake type (1 byte) — must be 1 (ClientHello).
	msgType, err := cur.readByte()
	if err != nil {
		return "", err
	}
	if msgType != 1 {
		return "", fmt.Errorf("not a ClientHello (handshake type %d)", msgType)
	}

	// Helpers for stepping over fields whose content is irrelevant here.
	skip := func(n int) error {
		_, err := cur.readBytes(n)
		return err
	}
	skipVec8 := func() error { // vector with a 1-byte length prefix
		n, err := cur.readByte()
		if err != nil {
			return err
		}
		return skip(int(n))
	}
	skipVec16 := func() error { // vector with a 2-byte length prefix
		n, err := cur.readUint16()
		if err != nil {
			return err
		}
		return skip(int(n))
	}

	preamble := []func() error{
		func() error { return skip(3) },      // handshake length
		func() error { return skip(2 + 32) }, // client_version + random
		skipVec8,                             // session_id
		skipVec16,                            // cipher_suites
		skipVec8,                             // compression_methods
	}
	for _, step := range preamble {
		if err := step(); err != nil {
			return "", err
		}
	}

	// Extensions length (2). Optional — TLS 1.0 ClientHello can skip it.
	if cur.remaining() < 2 {
		return "", ErrNoSNI
	}
	extBytes, err := cur.readUint16()
	if err != nil {
		return "", err
	}
	if int(extBytes) > cur.remaining() {
		return "", fmt.Errorf("extensions truncated")
	}

	for limit := cur.pos + int(extBytes); cur.pos < limit; {
		kind, err := cur.readUint16()
		if err != nil {
			return "", err
		}
		size, err := cur.readUint16()
		if err != nil {
			return "", err
		}
		payload, err := cur.readBytes(int(size))
		if err != nil {
			return "", err
		}
		// server_name extension is type 0; the first one found decides.
		if kind == 0 {
			return parseServerName(payload)
		}
	}
	return "", ErrNoSNI
}
// parseServerName decodes the payload of a server_name extension and
// returns the first entry whose name type is host_name (0), lowercased.
//
// It returns ErrNoSNI when the list holds no host_name entry,
// io.ErrUnexpectedEOF when the payload ends in the middle of a field, and a
// descriptive error when the list length exceeds the payload.
func parseServerName(data []byte) (string, error) {
	cur := newReader(data)

	// server_name_list length (2).
	listBytes, err := cur.readUint16()
	if err != nil {
		return "", err
	}
	if int(listBytes) > cur.remaining() {
		return "", fmt.Errorf("server_name list truncated")
	}

	// Each entry is: name_type (1), name length (2), name bytes.
	for limit := cur.pos + int(listBytes); cur.pos < limit; {
		kind, err := cur.readByte()
		if err != nil {
			return "", err
		}
		size, err := cur.readUint16()
		if err != nil {
			return "", err
		}
		raw, err := cur.readBytes(int(size))
		if err != nil {
			return "", err
		}
		if kind != 0 { // not a host_name entry
			continue
		}
		return strings.ToLower(string(raw)), nil
	}
	return "", ErrNoSNI
}
// reader is a forward-only cursor over a byte slice, used by parseSNI and
// parseServerName to decode length-prefixed TLS fields. Every read either
// succeeds and advances pos, or fails with io.ErrUnexpectedEOF and leaves
// pos untouched.
type reader struct {
	buf []byte // data being decoded; never modified
	pos int    // offset of the next unread byte
}

// newReader returns a cursor positioned at the start of buf.
func newReader(buf []byte) *reader {
	cur := new(reader)
	cur.buf = buf
	return cur
}

// remaining reports how many bytes are left to read.
func (r *reader) remaining() int { return len(r.buf) - r.pos }

// readByte consumes and returns the next byte.
func (r *reader) readByte() (byte, error) {
	if r.remaining() < 1 {
		return 0, io.ErrUnexpectedEOF
	}
	next := r.buf[r.pos]
	r.pos++
	return next, nil
}

// readUint16 consumes the next two bytes as a big-endian unsigned integer.
func (r *reader) readUint16() (uint16, error) {
	if r.remaining() < 2 {
		return 0, io.ErrUnexpectedEOF
	}
	raw := r.buf[r.pos : r.pos+2]
	r.pos += 2
	return binary.BigEndian.Uint16(raw), nil
}

// readBytes consumes the next n bytes. The returned slice aliases the
// underlying buffer rather than copying it.
func (r *reader) readBytes(n int) ([]byte, error) {
	if n > r.remaining() {
		return nil, io.ErrUnexpectedEOF
	}
	start := r.pos
	r.pos += n
	return r.buf[start:r.pos], nil
}

View File

@ -0,0 +1,173 @@
package sniproxy
import (
"crypto/tls"
"errors"
"io"
"net"
"sync"
"testing"
"time"
)
// dialAndPeek drives one client/server exchange over ln: a background
// client starts a TLS handshake with the given SNI, a background server
// accepts the connection and runs PeekClientHello on it. The return values
// are those of PeekClientHello, an Accept error, or a "test timeout" error
// if neither arrives within five seconds.
func dialAndPeek(t *testing.T, ln net.Listener, sni string) (string, []byte, error) {
	t.Helper()

	type peekOutcome struct {
		name   string
		replay []byte
		err    error
	}
	outcomes := make(chan peekOutcome, 1)

	// Server side: accept once, peek SNI.
	go func() {
		accepted, err := ln.Accept()
		if err != nil {
			outcomes <- peekOutcome{err: err}
			return
		}
		defer accepted.Close()

		var out peekOutcome
		out.name, out.replay, out.err = PeekClientHello(accepted, 2*time.Second)
		outcomes <- out
	}()

	// Client side: the handshake cannot complete (nobody answers it), and it
	// doesn't need to; sending the ClientHello is enough. Running it on its
	// own goroutine keeps the test from blocking on Handshake.
	go func() {
		raw, err := net.Dial("tcp", ln.Addr().String())
		if err != nil {
			return
		}
		defer raw.Close()
		_ = raw.SetDeadline(time.Now().Add(2 * time.Second))

		clientCfg := &tls.Config{
			ServerName:         sni,
			InsecureSkipVerify: true,
		}
		_ = tls.Client(raw, clientCfg).Handshake()
	}()

	deadline := time.After(5 * time.Second)
	select {
	case out := <-outcomes:
		return out.name, out.replay, out.err
	case <-deadline:
		return "", nil, errors.New("test timeout")
	}
}
// TestPeekClientHello_returns_sni checks that a real ClientHello yields the
// SNI it was sent with and a non-empty set of bytes to replay.
func TestPeekClientHello_returns_sni(t *testing.T) {
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}
	defer listener.Close()

	name, replay, err := dialAndPeek(t, listener, "example.com")
	switch {
	case err != nil:
		t.Fatalf("PeekClientHello: %v", err)
	case name != "example.com":
		t.Errorf("expected sni=example.com, got %q", name)
	}
	if len(replay) == 0 {
		t.Error("expected non-empty peeked bytes")
	}
}
// TestPeekClientHello_lowercases_sni checks that a mixed-case SNI comes back
// lowercased.
func TestPeekClientHello_lowercases_sni(t *testing.T) {
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}
	defer listener.Close()

	name, _, peekErr := dialAndPeek(t, listener, "Example.COM")
	if peekErr != nil {
		t.Fatal(peekErr)
	}
	if want := "example.com"; name != want {
		t.Errorf("expected lowercase, got %q", name)
	}
}
// TestPeekClientHello_non_tls_returns_error checks that plaintext (here an
// HTTP request line) is rejected instead of being parsed as TLS.
func TestPeekClientHello_non_tls_returns_error(t *testing.T) {
	client, server := net.Pipe()
	defer client.Close()
	defer server.Close()

	// Send something that isn't a TLS handshake record, then hang up.
	payload := []byte("GET / HTTP/1.1\r\n\r\n")
	go func() {
		_, _ = client.Write(payload)
		_ = client.Close()
	}()

	if _, _, err := PeekClientHello(server, 1*time.Second); err == nil {
		t.Fatal("expected error for non-TLS bytes")
	}
}
// TestPeekClientHello_short_record_returns_error checks that a stream which
// ends before a full 5-byte record header has arrived produces an error.
func TestPeekClientHello_short_record_returns_error(t *testing.T) {
	client, server := net.Pipe()
	defer client.Close()
	defer server.Close()

	// One byte, then close — too short for record header.
	go func() {
		_, _ = client.Write([]byte{22})
		_ = client.Close()
	}()

	_, _, err := PeekClientHello(server, 1*time.Second)
	if err == nil {
		t.Fatal("expected error for short record")
	}

	// EOF or read error is acceptable.
	notEOF := !errors.Is(err, io.EOF)
	if notEOF && err.Error() == "" {
		t.Logf("error: %v", err)
	}
}
// TestPeekClientHello_concurrent_safe has several clients send a ClientHello
// at the same time and checks that each accepted connection still parses to
// the expected SNI, i.e. that PeekClientHello keeps no state between calls.
// The clients run concurrently; the accepted connections are peeked one
// after another.
func TestPeekClientHello_concurrent_safe(t *testing.T) {
	const clients = 4

	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}
	defer listener.Close()

	var dialers sync.WaitGroup
	dialers.Add(clients)
	for n := 0; n < clients; n++ {
		go func() {
			defer dialers.Done()
			raw, dialErr := net.Dial("tcp", listener.Addr().String())
			if dialErr != nil {
				return
			}
			defer raw.Close()
			_ = raw.SetDeadline(time.Now().Add(2 * time.Second))
			clientCfg := &tls.Config{ServerName: "x.example.com", InsecureSkipVerify: true}
			_ = tls.Client(raw, clientCfg).Handshake()
		}()
	}

	for n := 0; n < clients; n++ {
		accepted, acceptErr := listener.Accept()
		if acceptErr != nil {
			t.Fatal(acceptErr)
		}
		name, _, peekErr := PeekClientHello(accepted, 2*time.Second)
		accepted.Close()

		if peekErr != nil {
			t.Errorf("peek %d: %v", n, peekErr)
		}
		if name != "x.example.com" {
			t.Errorf("peek %d: got %q", n, name)
		}
	}

	dialers.Wait()
}

BIN
core/sni-router Executable file

Binary file not shown.

View File

@ -0,0 +1,38 @@
[Unit]
Description=Orama SNI Router (TLS-level :443 → backend forwarder)
Documentation=https://github.com/DeBrosOfficial/network
After=network.target
Before=caddy.service
PartOf=orama-node.service
[Service]
Type=simple
WorkingDirectory=/opt/orama
EnvironmentFile=-/opt/orama/.orama/data/sni-router.env
ExecStart=/opt/orama/bin/orama-sni-router --config sni-router.yaml
# Bind privileged ports (:80, :443) without running as root.
AmbientCapabilities=CAP_NET_BIND_SERVICE
CapabilityBoundingSet=CAP_NET_BIND_SERVICE
User=orama
Group=orama
NoNewPrivileges=yes
ProtectSystem=strict
ProtectHome=yes
PrivateTmp=yes
LimitNOFILE=65536
TimeoutStopSec=15s
KillMode=mixed
KillSignal=SIGTERM
Restart=on-failure
RestartSec=5s
StandardOutput=journal
StandardError=journal
SyslogIdentifier=orama-sni-router
[Install]
WantedBy=multi-user.target